"""
Incident beam and emitted radiation attenuation support.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
import concurrent.futures as cf
import multiprocessing as mp
from collections.abc import Callable, Sequence
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes._axes import Axes
from numpy.typing import ArrayLike, DTypeLike, NDArray
from tqdm.auto import tqdm
from corrct import _projector_backends as prj_backends
from corrct import models
from corrct.physics.xraylib_helper import get_compound, get_compound_cross_section
from corrct.physics.xrf import DetectorXRF
from corrct.processing.pre import bin_imgs
# Multiplicative factor converting lengths in micrometers to centimeters.
CONVERT_UM_TO_CM = 1e-4

# Type aliases for NumPy arrays holding floating point / integer values.
NDArrayFloat = NDArray[np.floating]
NDArrayInt = NDArray[np.integer]

# Size of the thread pools used by this module: it grows with the base-2
# logarithm of the CPU count instead of linearly.
num_threads = round(np.log2(mp.cpu_count() + 1))
[docs]
class AttenuationVolume:
    """Attenuation volume computation class.

    Starting from the local attenuation volumes of the incident beam and / or
    of the emitted radiation, it computes one attenuation map per rotation
    angle and per detector. The maps are stored in `maps`, whose shape is
    ``[num_rot_angles, num_detectors, *vol_shape_zyx]``.
    """

    incident_local: NDArrayFloat | None
    emitted_local: NDArrayFloat | None

    angles_rot_rad: NDArrayFloat
    detectors: Sequence[DetectorXRF]
    emitted_sub_sampling: int

    dtype: DTypeLike

    vol_shape_zyx: NDArray
    maps: NDArray

    def __init__(
        self,
        incident_local: NDArrayFloat | None,
        emitted_local: NDArrayFloat | None,
        angles_rot_rad: NDArrayFloat | Sequence[float],
        angles_det_rad: NDArrayFloat | Sequence[float | DetectorXRF] | float | DetectorXRF = np.pi / 2,
        emitted_sub_sampling: int = 1,
        dtype: DTypeLike = np.float32,
    ):
        """
        Initialize the AttenuationVolume class.

        Parameters
        ----------
        incident_local : NDArrayFloat | None
            The local attenuation volume for the incident beam. It can be None.
        emitted_local : NDArrayFloat | None
            The local attenuation volume for the emitted radiation. It can be None.
        angles_rot_rad : NDArrayFloat | Sequence[float]
            The rotation angles, in radians.
        angles_det_rad : NDArrayFloat | Sequence[float | DetectorXRF] | float | DetectorXRF, optional
            The detector(s), given either as `DetectorXRF` instances or as
            angles in radians. Plain angles (Python or NumPy scalars) are turned
            into detectors with zero surface. The default is pi / 2.
        emitted_sub_sampling : int, optional
            Number of sub-angles used to sample the angular range of detectors
            with a non-zero surface. The default is 1.
        dtype : DTypeLike, optional
            Data type of the computed maps. The default is np.float32.

        Raises
        ------
        ValueError
            In case no volumes were passed, or if they differed in shape.
            Also when a detector description is of an unsupported type, or when
            the volumes are neither 2D nor 3D.
        """
        self.incident_local = incident_local
        self.emitted_local = emitted_local

        self.angles_rot_rad = np.array(angles_rot_rad, ndmin=1)

        # NumPy scalars (e.g. np.float32, np.int64) are not all subclasses of
        # the Python builtins, so they need to be listed explicitly.
        angle_types = (int, float, np.integer, np.floating)

        if isinstance(angles_det_rad, (*angle_types, DetectorXRF)):
            angles_det_rad = [angles_det_rad]

        self.detectors = []
        for a in angles_det_rad:
            if isinstance(a, DetectorXRF):
                self.detectors.append(a)
            elif isinstance(a, angle_types):
                self.detectors.append(DetectorXRF(surface_mm2=0, distance_mm=1, angle_rad=float(a)))
            else:
                raise ValueError(
                    f"Input parameter {angles_det_rad = } should be one of: float | NDArray | Sequence[float | DetectorXRF]"
                )

        self.dtype = dtype

        if self.incident_local is not None:
            self.vol_shape_zyx = np.array(self.incident_local.shape)

            # Tuples are compared (instead of arrays) to avoid NumPy broadcasting,
            # which could hide a mismatch in the number of dimensions.
            if self.emitted_local is not None and tuple(self.incident_local.shape) != tuple(self.emitted_local.shape):
                raise ValueError(
                    f"Incident volume shape ({self.incident_local.shape}) does not"
                    + f" match the emitted volume shape ({self.emitted_local.shape})"
                )
        elif self.emitted_local is not None:
            self.vol_shape_zyx = np.array(self.emitted_local.shape)
        else:
            raise ValueError("No attenuation volumes were given.")

        self.emitted_sub_sampling = emitted_sub_sampling

        self.vol_shape_zyx = np.array(self.vol_shape_zyx, ndmin=1)

        num_dims = len(self.vol_shape_zyx)
        if num_dims not in [2, 3]:
            raise ValueError(f"Maps can only be 2D or 3D Arrays. A {num_dims}-dimensional was passed ({self.vol_shape_zyx}).")

    def _get_detector_angles(self) -> NDArray:
        """Return the angles (`angle_rad`) of all the detectors, as a 1D array."""
        return np.array([det.angle_rad for det in self.detectors], ndmin=1)

    def _get_sub_angles(self, detector: DetectorXRF) -> NDArray:
        """Return the mid-points of `emitted_sub_sampling` equal intervals spanning the detector angular range."""
        interval_edges = np.linspace(*detector.angle_range_rad, self.emitted_sub_sampling + 1)
        # Averaging consecutive edges gives the center of each interval
        return np.convolve(interval_edges, np.ones(2), "valid") / 2

    def _compute_attenuation_angle_in(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the incident beam attenuation map for one rotation angle, with a leading singleton axis."""
        return prj_backends.compute_attenuation(local_att, angle_rad, invert=False)[None, ...]

    def _compute_attenuation_angle_out(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the emitted radiation attenuation maps of all the detectors for one rotation angle.

        For detectors with a non-zero surface, the result is the average of the
        attenuation maps computed over the sub-angles of the detector.
        """
        angle_det = angle_rad + self._get_detector_angles()
        atts = np.zeros(self.maps.shape[1:], dtype=self.dtype)
        for ii, a in enumerate(angle_det):
            if self.detectors[ii].surface_mm2 > 0.0:
                sub_angles = self._get_sub_angles(self.detectors[ii])
            else:
                sub_angles = np.zeros(1)

            for a_s in sub_angles:
                atts[ii, ...] += prj_backends.compute_attenuation(local_att, a + a_s, invert=True)
            atts[ii, ...] /= len(sub_angles)
        return atts

    def compute_maps(self, use_multithreading: bool = True, verbose: bool = True) -> None:
        """
        Compute the correction maps for each angle.

        The maps are initialized to one, and then multiplied by the incident
        and / or emitted attenuation, depending on which volumes are available.

        Parameters
        ----------
        use_multithreading : bool, optional
            Use multi-threading for computing the attenuation maps. The default is True.
        verbose : bool, optional
            Show verbose output. The default is True.
        """
        num_rot_angles = len(self.angles_rot_rad)
        self.maps = np.ones([num_rot_angles, len(self.detectors), *self.vol_shape_zyx], dtype=self.dtype)

        def process_angles(
            func: Callable[[NDArray, float], NDArray], att_vol: NDArrayFloat, angles: NDArrayFloat, description: str
        ) -> None:
            if use_multithreading:
                with cf.ThreadPoolExecutor(max_workers=num_threads) as executor:
                    futures_to_angle = {executor.submit(func, att_vol, a): (ii, a) for ii, a in enumerate(angles)}
                    try:
                        for f in tqdm(
                            cf.as_completed(futures_to_angle),
                            desc=description,
                            disable=(not verbose),
                            total=num_rot_angles,
                        ):
                            ii, a = futures_to_angle[f]
                            try:
                                self.maps[ii, ...] *= f.result()
                            except ValueError as exc:
                                raise RuntimeError(f"Angle {a} (#{ii}) generated an exception") from exc
                    except BaseException:
                        # Also covers KeyboardInterrupt: the pending work is
                        # cancelled, and the exception is always re-raised.
                        print("Shutting down..", end="", flush=True)
                        executor.shutdown(cancel_futures=True)
                        print("\b\b: Done.")
                        raise
            else:
                for ii, a in enumerate(tqdm(angles, desc=description, disable=(not verbose))):
                    self.maps[ii, ...] *= func(att_vol, a)

        if self.incident_local is not None:
            description = "Computing attenuation maps for incident beam"
            process_angles(
                self._compute_attenuation_angle_in, self.incident_local, angles=self.angles_rot_rad, description=description
            )

        if self.emitted_local is not None:
            if self.emitted_sub_sampling > 1:
                for ii, det in enumerate(self.detectors):
                    sub_angles = self._get_sub_angles(det)
                    print(
                        f"Detector #{ii}: Super-sampling outgoing sub-angles: {np.round(np.rad2deg(sub_angles), decimals=3)} deg"
                    )

            description = "Computing attenuation maps for emitted photons"
            process_angles(
                self._compute_attenuation_angle_out, self.emitted_local, angles=self.angles_rot_rad, description=description
            )

    def plot_map(
        self,
        ax: Axes,
        rot_ind: int,
        det_ind: int = 0,
        slice_ind: int | None = None,
        axes: Sequence[int] | NDArrayInt = (-2, -1),
    ) -> Sequence[float]:
        """
        Plot the requested attenuation map.

        Parameters
        ----------
        ax : matplotlib axes
            The axes where to plot.
        rot_ind : int
            Rotation angle index.
        det_ind : int, optional
            Detector angle index. The default is 0.
        slice_ind : int | None, optional
            Volume slice index (for 3D volumes). The default is None.
        axes : Sequence[int] | NDArray, optional
            Axes of the slice. The default is (-2, -1).

        Returns
        -------
        Sequence[float]
            The extent of the axes plot (min-max coords).

        Raises
        ------
        ValueError
            In case a slice index is not passed for a 3D volume.
        """
        att_map = np.squeeze(self.get_maps(rot_ind=rot_ind, det_ind=det_ind))
        # The volume axis (among -3, -2, -1) that is not part of the slice
        other_dim = np.squeeze(np.delete(np.arange(-3, 0), axes))
        if len(att_map.shape) == 3:
            if slice_ind is None:
                raise ValueError("Slice index is needed for 3D volumes. None was passed.")

            att_map = np.take(att_map, slice_ind, axis=int(other_dim))

        slice_shape = self.vol_shape_zyx[list(axes)]
        coords = [(-(s - 1) / 2, (s - 1) / 2) for s in slice_shape]
        # NOTE(review): the extent follows the order of `axes`, while imshow
        # expects (left, right, bottom, top) - confirm for non-square slices.
        extent = tuple(np.concatenate(coords))

        ax.imshow(att_map, extent=extent)

        # The beam direction arrows are only drawn when the slice does not include axis -3
        if other_dim == -3:
            arrow_length = np.linalg.norm(slice_shape) / np.pi
            arrow_args = dict(
                width=arrow_length / 25,
                head_width=arrow_length / 8,
                head_length=arrow_length / 6,
                length_includes_head=True,
            )

            prj_geom = models.ProjectionGeometry.get_default_parallel()
            beam_i_geom = prj_geom.rotate(-self.angles_rot_rad[rot_ind])
            beam_e_geom = prj_geom.rotate(-(self.angles_rot_rad[rot_ind] + self.detectors[det_ind].angle_rad))

            # Incident beam: red arrow, ending in the origin
            beam_i_dir = beam_i_geom.src_pos_xyz[0, :2] * arrow_length
            beam_i_orig = -beam_i_dir

            # Emitted radiation: black arrow, starting from the origin
            beam_e_dir = beam_e_geom.src_pos_xyz[0, :2] * arrow_length
            beam_e_orig = np.array([0, 0])

            ax.arrow(*beam_i_orig, *beam_i_dir, **arrow_args, fc="r", ec="r")
            ax.arrow(*beam_e_orig, *beam_e_dir, **arrow_args, fc="k", ec="k")

        return extent

    def get_maps(
        self,
        roi: ArrayLike | None = None,
        rot_ind: int | slice | Sequence[int] | NDArrayInt | None = None,
        det_ind: int | slice | Sequence[int] | NDArrayInt | None = None,
        binning: int = 1,
    ) -> NDArray:
        """
        Return the attenuation maps.

        Scalar indices are converted to slices of length one, so that the
        rotation and detector axes are always present in the returned array.

        Parameters
        ----------
        roi : ArrayLike | None, optional
            The region-of-interest to select. The default is None.
        rot_ind : int | slice | Sequence[int] | NDArrayInt | None, optional
            The rotation index or indices to select. The default is None.
        det_ind : int | slice | Sequence[int] | NDArrayInt | None, optional
            The detector index or indices to select. The default is None.
        binning : int, optional
            The binning factor to apply to the maps. The default is 1.

        Returns
        -------
        NDArray
            The attenuation maps.

        Raises
        ------
        NotImplementedError
            When a region of interest is requested.
        """
        maps = self.maps

        if rot_ind is not None:
            # NumPy integers are scalar indices too: without the conversion they
            # would drop the rotation axis, and shift the detector axis.
            if isinstance(rot_ind, (int, np.integer)):
                rot_ind = slice(rot_ind, rot_ind + 1, 1)
            maps = maps[rot_ind, ...]

        if det_ind is not None:
            if isinstance(det_ind, (int, np.integer)):
                det_ind = slice(det_ind, det_ind + 1, 1)
            maps = maps[:, det_ind, ...]

        if roi is not None:
            raise NotImplementedError("Extracting a region of interest is not supported, yet.")

        if binning > 1:
            maps = bin_imgs(maps, binning=binning, axes=tuple(range(-len(self.vol_shape_zyx), 0)), auto_crop=True)

        return maps

    def get_projector_args(
        self,
        roi: ArrayLike | None = None,
        rot_ind: int | slice | Sequence[int] | NDArrayInt | None = None,
        det_ind: int | slice | Sequence[int] | NDArrayInt | None = None,
        binning: int = 1,
    ) -> dict[str, NDArray]:
        """
        Return the projector arguments.

        Parameters
        ----------
        roi : ArrayLike | None, optional
            The region-of-interest to select. The default is None.
        rot_ind : int | slice | Sequence[int] | NDArrayInt | None, optional
            The rotation index or indices to select. The default is None.
        det_ind : int | slice | Sequence[int] | NDArrayInt | None, optional
            The detector index or indices to select. The default is None.
        binning : int, optional
            The binning factor to apply to the maps. The default is 1.

        Returns
        -------
        dict[str, NDArray]
            A dictionary containing the attenuation maps (`att_maps`) and the
            detector angle(s) (`angles_detectors_rad`). A scalar `det_ind`
            results in a scalar detector angle.
        """
        if det_ind is None:
            det_angles = self._get_detector_angles()
        elif isinstance(det_ind, (int, np.integer)):
            det_angles = self.detectors[det_ind].angle_rad
        else:
            # Slices and index sequences cannot index the list of detectors
            # directly: the array of angles is indexed instead.
            det_angles = self._get_detector_angles()[det_ind]

        return dict(
            att_maps=self.get_maps(roi=roi, rot_ind=rot_ind, det_ind=det_ind, binning=binning), angles_detectors_rad=det_angles
        )
[docs]
def get_linear_attenuation_coefficient(
    compound: str | dict, energy_keV: float, pixel_size_um: float, density: float | None = None
) -> float:
    """Compute the linear attenuation coefficient for given compound, energy, and pixel size.

    Parameters
    ----------
    compound : str | dict
        The compound for which we compute the linear attenuation coefficient.
        Strings are converted to a compound description through `get_compound`.
        A dictionary passed by the caller is never modified.
    energy_keV : float
        The energy of the photons
    pixel_size_um : float
        The pixel size in microns
    density : float | None, optional
        The density of the compound (if different from the default value), by default None

    Returns
    -------
    float
        The linear attenuation coefficient
    """
    if isinstance(compound, str):
        compound = get_compound(compound)

    if density is not None:
        # The override is applied to a shallow copy, to leave the caller's
        # compound description untouched.
        compound = {**compound, "density": density}

    cmp_cs = get_compound_cross_section(compound, energy_keV)
    return pixel_size_um * CONVERT_UM_TO_CM * compound["density"] * cmp_cs
[docs]
def plot_emission_line_attenuation(
compound: str | dict,
thickness_um: float,
mean_energy_keV: float,
fwhm_keV: float,
line_shape: str = "lorentzian",
num_points: int = 201,
plot_lines_mean: bool = True,
) -> None:
"""Plot spectral attenuation of a given line.
Parameters
----------
compound : str | dict
Compound to consider
thickness_um : float
Thickness of the compound (in microns)
mean_energy_keV : float
Average energy of the line
fwhm_keV : float
Full-width half-maximum of the line
line_shape : str, optional
Shape of the line, by default "lorentzian".
Options are: "gaussian" | "lorentzian" | "sech**2".
num_points : int, optional
number of discretization points, by default 201
plot_lines_mean : bool, optional
Whether to plot the line mean and the effective line mean, by default True
Raises
------
ValueError
When an unsupported line is chosen.
"""
xc = np.linspace(-0.5, 0.5, num_points)
if line_shape.lower() == "gaussian":
xc *= fwhm_keV * 3
yg = np.exp(-4 * np.log(2) * (xc**2) / (fwhm_keV**2))
elif line_shape.lower() == "lorentzian":
xc *= fwhm_keV * 13
hwhm_keV = fwhm_keV / 2
yg = hwhm_keV / (xc**2 + hwhm_keV**2)
elif line_shape.lower() == "sech**2":
# doi: 10.1364/ol.20.001160
xc *= fwhm_keV * 4
tau = fwhm_keV / (2 * np.arccosh(np.sqrt(2)))
yg = 1 / np.cosh(xc / tau) ** 2
else:
raise ValueError(f"Unknown beam shape: {line_shape.lower()}")
nrgs_keV = xc + mean_energy_keV
if isinstance(compound, str):
compound = get_compound(compound)
atts = np.empty_like(yg)
for ii, nrg in enumerate(nrgs_keV):
cmp_cs = get_compound_cross_section(compound, nrg)
atts[ii] = np.exp(-thickness_um * CONVERT_UM_TO_CM * compound["density"] * cmp_cs)
yg = yg / np.max(yg)
mean_effective_energy_keV = float(nrgs_keV.dot(yg * atts) / yg.dot(atts))
fig, axs_line = plt.subplots(1, 1)
pl_line = axs_line.plot(nrgs_keV, yg, label="$I_0$", color="C0")
if plot_lines_mean:
axs_line.axvline(mean_energy_keV, linestyle="--", color="C0")
axs_line.tick_params(axis="y", labelcolor="C0", labelsize=14)
axs_line.tick_params(axis="x", labelsize=14)
axs_line.set_ylabel("Intensity", color="C0", fontsize=14)
axs_atts = axs_line.twinx()
pl_atts = axs_atts.plot(nrgs_keV, atts, label="exp(-$\\mu (E) x)$", color="C1")
pl_line_att = axs_atts.plot(nrgs_keV, yg * atts, label=r"$I_{measured}$", color="C2")
if plot_lines_mean:
axs_atts.axvline(mean_effective_energy_keV, linestyle="--", color="C2")
axs_atts.tick_params(axis="y", labelcolor="C1", labelsize=14)
axs_atts.set_ylabel("Transmittance", color="C1", fontsize=14)
all_pls = pl_line + pl_atts + pl_line_att
axs_atts.legend(all_pls, [str(pl.get_label()) for pl in all_pls], fontsize=13)
axs_line.grid()
fig.tight_layout()
I_lin = np.sum(yg * atts[num_points // 2])
I_meas = yg.dot(atts)
print(f"Expected intensity: {I_lin}, measured: {I_meas} ({I_meas / I_lin:%})")
print(f"Expected mean energy: {nrgs_keV.dot(yg / np.sum(yg)):.5}, effective: {mean_effective_energy_keV:.5}")
[docs]
def plot_transmittance_decay(
    compounds: str | dict | Sequence[str | dict],
    mean_energy_keV: float,
    thickness_range_um: tuple[float, float, int] = (0.0, 10.0, 101),
) -> None:
    """Plot transmittance decay curve(s) for the given compound(s) at a given energy and thickness range.

    The figure is shown without blocking the execution.

    Parameters
    ----------
    compounds : str | dict | Sequence[str | dict]
        The compound(s) description. Strings are converted to compound
        descriptions through `get_compound`.
    mean_energy_keV : float
        The mean photon energy
    thickness_range_um : tuple[float, float, int], optional
        The thickness range as (start, end, num_points), by default (0.0, 10.0, 101)
    """
    if isinstance(compounds, (str, dict)):
        compounds = [compounds]
    compounds = [get_compound(c) if isinstance(c, str) else c for c in compounds]

    thicknesses_um = np.linspace(*thickness_range_um)
    atts = np.zeros((len(compounds), len(thicknesses_um)), dtype=np.float32)

    fig, axs = plt.subplots(1, 1, figsize=(8, 4))
    for ii, cmp in enumerate(compounds):
        cmp_cs = get_compound_cross_section(cmp, mean_energy_keV)
        atts[ii] = np.exp(-thicknesses_um * CONVERT_UM_TO_CM * cmp["density"] * cmp_cs)
        axs.plot(thicknesses_um, atts[ii], label=cmp["name"])

    axs.legend(fontsize=13)
    axs.grid()
    axs.tick_params(labelsize=14)
    axs.set_xlabel(r"Thickness [$\mu m$]", fontsize=14)
    axs.set_xlim(thickness_range_um[0], thickness_range_um[1])
    axs.set_ylabel("Transmittance", fontsize=14)
    axs.set_ylim(0.0, 1.0)
    axs.set_title(f"Transmittance curve at {mean_energy_keV:.2f} keV", fontsize=14)
    fig.tight_layout()

    # `block` is an argument of `plt.show`: passing it to `plt.plot` draws
    # nothing, and never displays the figure.
    plt.show(block=False)