Source code for corrct.physics.attenuation

"""
Incident beam and emitted radiation attenuation support.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

import concurrent.futures as cf
import multiprocessing as mp
from collections.abc import Sequence
from typing import Callable, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes._axes import Axes
from numpy.typing import ArrayLike, DTypeLike, NDArray
from tqdm.auto import tqdm

from corrct import _projector_backends as prj_backends
from corrct import models
from corrct.physics.xraylib_helper import get_compound, get_compound_cross_section

# Worker threads used to compute the attenuation maps: grows with the logarithm of the CPU count.
num_threads = round(np.log2(1 + mp.cpu_count()))

# Shorthand aliases for NumPy arrays of floating-point and integer values.
NDArrayFloat = NDArray[np.floating]
NDArrayInt = NDArray[np.integer]

# Multiplicative factor converting micrometers into centimeters.
CONVERT_UM_TO_CM = 1e-4


class AttenuationVolume:
    """Attenuation volume computation class.

    Computes, for every rotation angle and every detector angle, the attenuation maps of the
    incident beam and/or of the emitted photons, starting from the corresponding local
    attenuation volumes.
    """

    incident_local: Union[NDArrayFloat, None]
    emitted_local: Union[NDArrayFloat, None]

    angles_rot_rad: NDArrayFloat
    angles_det_rad: NDArrayFloat

    dtype: DTypeLike

    vol_shape_zyx: NDArray
    maps: NDArray

    def __init__(
        self,
        incident_local: Union[NDArrayFloat, None],
        emitted_local: Union[NDArrayFloat, None],
        angles_rot_rad: ArrayLike,
        angles_det_rad: Union[NDArrayFloat, ArrayLike, float] = np.pi / 2,
        dtype: DTypeLike = np.float32,
    ):
        """
        Initialize the AttenuationVolume class.

        Parameters
        ----------
        incident_local : NDArrayFloat | None
            Local attenuation volume for the incident beam, or None to skip it.
        emitted_local : NDArrayFloat | None
            Local attenuation volume for the emitted photons, or None to skip it.
        angles_rot_rad : ArrayLike
            Rotation angles (in radians).
        angles_det_rad : NDArrayFloat | ArrayLike | float, optional
            Detector angle(s) (in radians), added to each rotation angle when computing the
            emitted photons attenuation. The default is pi / 2.
        dtype : DTypeLike, optional
            Data type of the computed maps. The default is np.float32.

        Raises
        ------
        ValueError
            In case no volumes were passed, if they differed in shape, or if they are not 2D or 3D.
        """
        self.incident_local = incident_local
        self.emitted_local = emitted_local

        self.angles_rot_rad = np.array(angles_rot_rad, ndmin=1)
        self.angles_det_rad = np.array(angles_det_rad, ndmin=1)

        self.dtype = dtype

        if self.incident_local is not None:
            self.vol_shape_zyx = np.array(self.incident_local.shape)

            # Compare the shape tuples directly: an element-wise array comparison is not
            # well-defined when the two volumes have a different number of dimensions.
            if self.emitted_local is not None and tuple(self.incident_local.shape) != tuple(self.emitted_local.shape):
                raise ValueError(
                    f"Incident volume shape ({self.incident_local.shape}) does not"
                    + f" match the emitted volume shape ({self.emitted_local.shape})"
                )
        elif self.emitted_local is not None:
            self.vol_shape_zyx = np.array(self.emitted_local.shape)
        else:
            raise ValueError("No attenuation volumes were given.")

        self.vol_shape_zyx = np.array(self.vol_shape_zyx, ndmin=1)

        num_dims = len(self.vol_shape_zyx)
        if num_dims not in [2, 3]:
            raise ValueError(f"Maps can only be 2D or 3D Arrays. A {num_dims}-dimensional was passed ({self.vol_shape_zyx}).")

    def _compute_attenuation_angle_in(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the incident beam attenuation for one rotation angle.

        A leading singleton axis is added, so that the result broadcasts over all detector angles.
        """
        return prj_backends.compute_attenuation(local_att, angle_rad, invert=False)[None, ...]

    def _compute_attenuation_angle_out(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the emitted photons attenuation for one rotation angle and every detector angle.

        Requires `self.maps` to be allocated already, as its shape is used for the output.
        """
        angle_det = angle_rad + self.angles_det_rad
        atts = np.empty(self.maps.shape[1:], dtype=self.dtype)
        for ii, a in enumerate(angle_det):
            atts[ii, ...] = prj_backends.compute_attenuation(local_att, a, invert=True)
        return atts

    def compute_maps(self, use_multithreading: bool = True, verbose: bool = True) -> None:
        """
        Compute the correction maps for each angle.

        The resulting maps have shape [num_rot_angles, num_det_angles, *vol_shape_zyx], and are the
        product of the incident and emitted attenuations (whichever are available).

        Parameters
        ----------
        use_multithreading : bool, optional
            Use multi-threading for computing the attenuation maps. The default is True.
        verbose : bool, optional
            Show verbose output. The default is True.
        """
        num_rot_angles = len(self.angles_rot_rad)
        self.maps = np.ones([num_rot_angles, len(self.angles_det_rad), *self.vol_shape_zyx], dtype=self.dtype)

        def process_angles(
            func: Callable[[NDArray, float], NDArray], att_vol: NDArrayFloat, angles: NDArrayFloat, description: str
        ) -> None:
            # Multiply into `self.maps` the attenuation computed by `func` for each angle.
            if use_multithreading:
                with cf.ThreadPoolExecutor(max_workers=num_threads) as executor:
                    futures_to_angle = {executor.submit(func, att_vol, a): (ii, a) for ii, a in enumerate(angles)}

                    try:
                        for f in tqdm(
                            cf.as_completed(futures_to_angle),
                            desc=description,
                            disable=(not verbose),
                            total=len(angles),
                        ):
                            ii, a = futures_to_angle[f]
                            try:
                                # Results are accumulated in the calling thread only.
                                self.maps[ii, ...] *= f.result()
                            except ValueError as exc:
                                raise RuntimeError(f"Angle {a} (#{ii}) generated an exception") from exc
                    except BaseException:
                        # Also catches KeyboardInterrupt: cancel the pending work before re-raising.
                        print("Shutting down..", end="", flush=True)
                        executor.shutdown(cancel_futures=True)
                        print("\b\b: Done.")
                        raise
            else:
                for ii, a in enumerate(tqdm(angles, desc=description, disable=(not verbose))):
                    self.maps[ii, ...] *= func(att_vol, a)

        if self.incident_local is not None:
            description = "Computing attenuation maps for incident beam"
            process_angles(
                self._compute_attenuation_angle_in, self.incident_local, angles=self.angles_rot_rad, description=description
            )

        if self.emitted_local is not None:
            description = "Computing attenuation maps for emitted photons"
            process_angles(
                self._compute_attenuation_angle_out, self.emitted_local, angles=self.angles_rot_rad, description=description
            )

    def plot_map(
        self,
        ax: Axes,
        rot_ind: int,
        det_ind: int = 0,
        slice_ind: Optional[int] = None,
        axes: Union[Sequence[int], NDArrayInt] = (-2, -1),
    ) -> Sequence[float]:
        """
        Plot the requested attenuation map.

        When the slice is orthogonal to the first (z) axis, arrows indicating the incident beam
        (red) and the emitted photons (black) directions are also drawn.

        Parameters
        ----------
        ax : matplotlib axes
            The axes where to plot.
        rot_ind : int
            Rotation angle index.
        det_ind : int, optional
            Detector angle index. The default is 0.
        slice_ind : Optional[int], optional
            Volume slice index (for 3D volumes). The default is None.
        axes : Sequence[int] | NDArray, optional
            Axes of the slice. The default is (-2, -1).

        Returns
        -------
        Sequence[float]
            The extent of the axes plot (min-max coords), as (left, right, bottom, top).

        Raises
        ------
        ValueError
            In case a slice index is not passed for a 3D volume.
        """
        att_map = np.squeeze(self.get_maps(rot_ind=rot_ind, det_ind=det_ind))
        other_dim = np.squeeze(np.delete(np.arange(-3, 0), axes))

        if len(att_map.shape) == 3:
            if slice_ind is None:
                raise ValueError("Slice index is needed for 3D volumes. None was passed.")

            att_map = np.take(att_map, slice_ind, axis=int(other_dim))

        slice_shape = self.vol_shape_zyx[list(axes)]

        # imshow's extent is (left, right, bottom, top): the horizontal range must follow the image
        # columns and the vertical range its rows. Coordinates refer to pixel centers.
        num_rows, num_cols = att_map.shape[-2:]
        extent = (-(num_cols - 1) / 2, (num_cols - 1) / 2, -(num_rows - 1) / 2, (num_rows - 1) / 2)

        ax.imshow(att_map, extent=extent)

        if other_dim == -3:
            arrow_length = np.linalg.norm(slice_shape) / np.pi
            arrow_args = dict(
                width=arrow_length / 25,
                head_width=arrow_length / 8,
                head_length=arrow_length / 6,
                length_includes_head=True,
            )

            prj_geom = models.ProjectionGeometry.get_default_parallel()

            beam_i_geom = prj_geom.rotate(-self.angles_rot_rad[rot_ind])
            beam_e_geom = prj_geom.rotate(-(self.angles_rot_rad[rot_ind] + self.angles_det_rad[det_ind]))

            beam_i_dir = beam_i_geom.src_pos_xyz[0, :2] * arrow_length
            beam_i_orig = -beam_i_dir

            beam_e_dir = beam_e_geom.src_pos_xyz[0, :2] * arrow_length
            beam_e_orig = np.array([0, 0])

            ax.arrow(*beam_i_orig, *beam_i_dir, **arrow_args, fc="r", ec="r")
            ax.arrow(*beam_e_orig, *beam_e_dir, **arrow_args, fc="k", ec="k")

        return extent

    def get_maps(
        self,
        roi: Optional[ArrayLike] = None,
        rot_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
        det_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
    ) -> NDArray:
        """
        Return the attenuation maps.

        Scalar indices keep the corresponding axis (as a length-1 axis), so that the returned
        array is always [rotations, detectors, *volume].

        Parameters
        ----------
        roi : ArrayLike, optional
            The region-of-interest to select. The default is None.
        rot_ind : int, optional
            A specific rotation index, if only one is to be desired. The default is None.
        det_ind : int, optional
            A specific detector index, if only one is to be desired. The default is None.

        Returns
        -------
        NDArray
            The attenuation maps.

        Raises
        ------
        NotImplementedError
            When a region of interest is requested.
        """
        if roi is not None:
            raise NotImplementedError("Extracting a region of interest is not supported, yet.")

        maps = self.maps

        if rot_ind is not None:
            # NumPy integers must also become slices: otherwise the rotation axis would be dropped,
            # and the detector indexing below would act on the wrong axis.
            if isinstance(rot_ind, (int, np.integer)):
                rot_ind = slice(rot_ind, rot_ind + 1, 1)
            maps = maps[rot_ind, ...]
        if det_ind is not None:
            if isinstance(det_ind, (int, np.integer)):
                det_ind = slice(det_ind, det_ind + 1, 1)
            maps = maps[:, det_ind, ...]

        return maps

    def get_projector_args(
        self,
        roi: Optional[ArrayLike] = None,
        rot_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
        det_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
    ) -> dict[str, NDArray]:
        """
        Return the projector arguments.

        Parameters
        ----------
        roi : ArrayLike, optional
            The region-of-interest to select. The default is None.
        rot_ind : int, optional
            A specific rotation index, if only one is to be desired. The default is None.
        det_ind : int, optional
            A specific detector index, if only one is to be desired. The default is None.

        Returns
        -------
        dict[str, NDArray]
            A dictionary containing the attenuation maps and the detector angle.
        """
        if det_ind is None:
            det_angles = self.angles_det_rad
        else:
            det_angles = self.angles_det_rad[det_ind]
        return dict(att_maps=self.get_maps(roi=roi, rot_ind=rot_ind, det_ind=det_ind), angles_detectors_rad=det_angles)
def get_linear_attenuation_coefficient(
    compound: Union[str, dict], energy_keV: float, pixel_size_um: float, density: Union[float, None] = None
) -> float:
    """Compute the linear attenuation coefficient for given compound, energy, and pixel size.

    The returned value is scaled by the pixel size, i.e. it is expressed per pixel.

    Parameters
    ----------
    compound : Union[str, dict]
        The compound for which we compute the linear attenuation coefficient
    energy_keV : float
        The energy of the photons
    pixel_size_um : float
        The pixel size in microns
    density : Union[float, None], optional
        The density of the compound (if different from the default value), by default None

    Returns
    -------
    float
        The linear attenuation coefficient
    """
    if isinstance(compound, str):
        compound = get_compound(compound)
    if density is not None:
        # Work on a shallow copy: the passed (or looked-up) compound description must not be modified.
        compound = dict(compound)
        compound["density"] = density

    cmp_cs = get_compound_cross_section(compound, energy_keV)
    return pixel_size_um * CONVERT_UM_TO_CM * compound["density"] * cmp_cs
def plot_emission_line_attenuation(
    compound: Union[str, dict],
    thickness_um: float,
    mean_energy_keV: float,
    fwhm_keV: float,
    line_shape: str = "lorentzian",
    num_points: int = 201,
) -> None:
    """Plot spectral attenuation of a given line.

    The line profile is sampled on `num_points` energies around `mean_energy_keV`, and the
    transmittance through `thickness_um` of `compound` is computed at each of them. The line,
    the transmittance, and the attenuated line are plotted. Finally, the measured (attenuated)
    intensity is printed against the one predicted from the transmittance at the central
    sample only, followed by the mean energies of the attenuated and of the incident lines.

    Parameters
    ----------
    compound : Union[str, dict]
        Compound to consider
    thickness_um : float
        Thickness of the compound (in microns)
    mean_energy_keV : float
        Average energy of the line
    fwhm_keV : float
        Full-width half-maximum of the line
    line_shape : str, optional
        Shape of the line, by default "lorentzian".
        Options are: "gaussian" | "lorentzian" | "sech**2".
    num_points : int, optional
        number of discretization points, by default 201

    Raises
    ------
    ValueError
        When an unsupported line is chosen.
    """
    shape_name = line_shape.lower()

    # Normalized energy offsets, stretched below by a shape-dependent multiple of the FWHM.
    xc = np.linspace(-0.5, 0.5, num_points)
    if shape_name == "gaussian":
        xc *= fwhm_keV * 3
        yg = np.exp(-4 * np.log(2) * (xc**2) / (fwhm_keV**2))
    elif shape_name == "lorentzian":
        xc *= fwhm_keV * 13
        hwhm_keV = fwhm_keV / 2
        yg = hwhm_keV / (xc**2 + hwhm_keV**2)
    elif shape_name == "sech**2":
        # doi: 10.1364/ol.20.001160
        xc *= fwhm_keV * 4
        tau = fwhm_keV / (2 * np.arccosh(np.sqrt(2)))
        yg = 1 / np.cosh(xc / tau) ** 2
    else:
        raise ValueError(f"Unknown beam shape: {shape_name}")

    nrgs_keV = xc + mean_energy_keV

    if isinstance(compound, str):
        compound = get_compound(compound)

    # Transmittance of the given thickness at each sampled energy.
    atts = np.empty_like(yg)
    for ii, nrg in enumerate(nrgs_keV):
        cmp_cs = get_compound_cross_section(compound, nrg)
        atts[ii] = np.exp(-thickness_um * CONVERT_UM_TO_CM * compound["density"] * cmp_cs)

    yg = yg / np.max(yg)

    fig, axs_line = plt.subplots(1, 1)
    pl_line = axs_line.plot(nrgs_keV, yg, label="$I_0$", color="C0")
    axs_line.tick_params(axis="y", labelcolor="C0")
    axs_atts = axs_line.twinx()
    pl_atts = axs_atts.plot(nrgs_keV, atts, label="$\\mu (E)$", color="C1")
    pl_line_att = axs_atts.plot(nrgs_keV, yg * atts, label="$I_m$", color="C2")
    axs_atts.tick_params(axis="y", labelcolor="C1")
    all_pls = pl_line + pl_atts + pl_line_att
    axs_atts.legend(all_pls, [str(pl.get_label()) for pl in all_pls])
    axs_line.grid()
    fig.tight_layout()

    # Prediction that only uses the transmittance at the central sample, versus the full spectrum.
    att_center = atts[num_points // 2]
    I_lin = np.sum(yg * att_center)
    I_meas = yg.dot(atts)
    print(f"Expected intensity: {I_lin}, measured: {I_meas} ({I_meas / I_lin:%})")

    # Spectrum-weighted mean energies: each set of weights is normalized by its own sum.
    mean_nrg_meas_keV = nrgs_keV.dot(yg * atts) / I_meas
    mean_nrg_line_keV = nrgs_keV.dot(yg / np.sum(yg))
    print(f"Mean energy {mean_nrg_meas_keV}, {mean_nrg_line_keV}")
def plot_transmittance_decay(
    compounds: Union[str, dict, Sequence[Union[str, dict]]],
    mean_energy_keV: float,
    thickness_range_um: tuple[float, float, int] = (0.0, 10.0, 101),
) -> None:
    """Plot transmittance decay curve(s) for the given compound(s) at a given energy and thickness range.

    Parameters
    ----------
    compounds : str | dict | Sequence[str | dict]
        The compound(s) description. Dictionaries are expected to provide at least the
        "density" and "name" entries (both are read here).
    mean_energy_keV : float
        The mean photon energy
    thickness_range_um : tuple[float, float, int], optional
        The thickness range as (start, end, num_points), by default (0.0, 10.0, 101)
    """
    if isinstance(compounds, (str, dict)):
        compounds = [compounds]

    compounds = [get_compound(c) if isinstance(c, str) else c for c in compounds]

    thicknesses_um = np.linspace(*thickness_range_um)
    atts = np.zeros((len(compounds), len(thicknesses_um)), dtype=np.float32)

    for ii, cmp in enumerate(compounds):
        cmp_cs = get_compound_cross_section(cmp, mean_energy_keV)
        atts[ii] = np.exp(-thicknesses_um * CONVERT_UM_TO_CM * cmp["density"] * cmp_cs)

    fig, axs = plt.subplots(1, 1, figsize=(8, 4))
    for ii, cmp in enumerate(compounds):
        axs.plot(thicknesses_um, atts[ii], label=cmp["name"])
    axs.legend(fontsize=13)
    axs.grid()
    axs.tick_params(labelsize=14)
    # Raw string: "\m" is not a valid escape sequence (SyntaxWarning on recent Python versions).
    axs.set_xlabel(r"Thickness [$\mu m$]", fontsize=14)
    axs.set_xlim(thickness_range_um[0], thickness_range_um[1])
    axs.set_ylabel("Transmittance", fontsize=14)
    axs.set_ylim(0.0, 1.0)
    axs.set_title(f"Transmittance curve at {mean_energy_keV:.2f} keV", fontsize=14)
    fig.tight_layout()
    # Show the figure without blocking (`plt.plot(block=False)` silently drew nothing).
    plt.show(block=False)