# -*- coding: utf-8 -*-
"""
Incident beam and emitted radiation attenuation support.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from matplotlib.axes._axes import Axes
import numpy as np
from . import _projector_backends as prj_backends
from . import models
import concurrent.futures as cf
import multiprocessing as mp
from tqdm import tqdm
from typing import Dict, Optional, Sequence, Union, List
from numpy.typing import ArrayLike, DTypeLike, NDArray
# Number of worker threads used when computing the attenuation maps in parallel.
# It scales with the base-2 logarithm of the CPU count rather than linearly.
num_threads = round(np.log2(mp.cpu_count() + 1))
# Shorthand type aliases for numpy arrays of floating point and integer values.
NDArrayFloat = NDArray[np.floating]
NDArrayInt = NDArray[np.integer]
class AttenuationVolume:
    """Attenuation volume computation class.

    It stores the local attenuation volumes of the incident beam and / or of
    the emitted radiation, and computes one attenuation map per pair of
    rotation angle and detector angle. The maps are stored in `maps`, with
    shape `[num_rot_angles, num_det_angles, *vol_shape_zyx]`.
    """

    # Local attenuation volume seen by the incident beam (None if not used).
    incident_local: Union[NDArrayFloat, None]
    # Local attenuation volume seen by the emitted radiation (None if not used).
    emitted_local: Union[NDArrayFloat, None]
    # Rotation angles, in radians (always at least 1D).
    angles_rot_rad: NDArrayFloat
    # Detector angles with respect to the incident beam, in radians (always at least 1D).
    angles_det_rad: NDArrayFloat
    # Data type of the computed maps.
    dtype: DTypeLike
    # Shape of the attenuation volumes.
    vol_shape_zyx: NDArray
    # Computed attenuation maps. Only available after calling `compute_maps`.
    maps: NDArray

    def __init__(
        self,
        incident_local: Union[NDArrayFloat, None],
        emitted_local: Union[NDArrayFloat, None],
        angles_rot_rad: ArrayLike,
        angles_det_rad: Union[NDArrayFloat, ArrayLike, float] = np.pi / 2,
        dtype: DTypeLike = np.float32,
    ):
        """
        Initialize the AttenuationVolume class.

        Parameters
        ----------
        incident_local : NDArrayFloat | None
            Local attenuation volume for the incident beam.
        emitted_local : NDArrayFloat | None
            Local attenuation volume for the emitted radiation.
        angles_rot_rad : ArrayLike
            The rotation angles, in radians.
        angles_det_rad : NDArrayFloat | ArrayLike | float, optional
            The detector angles, in radians. The default is pi / 2.
        dtype : DTypeLike, optional
            Data type of the computed maps. The default is np.float32.

        Raises
        ------
        ValueError
            In case no volumes were passed, or if they differed in shape.
        """
        self.incident_local = incident_local
        self.emitted_local = emitted_local
        self.angles_rot_rad = np.array(angles_rot_rad, ndmin=1)
        self.angles_det_rad = np.array(angles_det_rad, ndmin=1)
        self.dtype = dtype

        if self.incident_local is not None:
            vol_shape = tuple(self.incident_local.shape)
            # Compare the shapes as tuples: an element-wise array comparison
            # would broadcast, and let e.g. (5, 5, 5) vs (5,) slip through.
            if self.emitted_local is not None and vol_shape != tuple(self.emitted_local.shape):
                raise ValueError(
                    f"Incident volume shape ({self.incident_local.shape}) does not"
                    + f" match the emitted volume shape ({self.emitted_local.shape})"
                )
        elif self.emitted_local is not None:
            vol_shape = tuple(self.emitted_local.shape)
        else:
            raise ValueError("No attenuation volumes were given.")

        self.vol_shape_zyx = np.array(vol_shape, ndmin=1)

        num_dims = len(self.vol_shape_zyx)
        if num_dims not in [2, 3]:
            raise ValueError(f"Maps can only be 2D or 3D Arrays. A {num_dims}-dimensional was passed ({self.vol_shape_zyx}).")

    def _compute_attenuation_angle_in(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the incident beam attenuation for one rotation angle.

        A leading singleton axis is added, so that the result broadcasts over
        the detector angles axis of the maps.
        """
        return prj_backends.compute_attenuation(local_att, angle_rad, invert=False)[None, ...]

    def _compute_attenuation_angle_out(self, local_att: NDArrayFloat, angle_rad: float) -> NDArray:
        """Compute the emitted radiation attenuation for one rotation angle.

        One map is computed per detector angle, at the angle given by the sum
        of the rotation angle and the detector angle.
        """
        angles_out = angle_rad + self.angles_det_rad
        # The buffer shape is derived from the detector angles and the volume
        # shape, so that this method does not depend on `self.maps` existing.
        atts = np.empty([len(angles_out), *self.vol_shape_zyx], dtype=self.dtype)
        for ii, a in enumerate(angles_out):
            atts[ii, ...] = prj_backends.compute_attenuation(local_att, a, invert=True)
        return atts

    def _accumulate_attenuation(
        self, compute_angle, local_att: NDArrayFloat, description: str, use_multithreading: bool, verbose: bool
    ) -> None:
        """Multiply `self.maps` by the attenuation computed for each rotation angle.

        `compute_angle` is called as `compute_angle(local_att, angle)` once per
        rotation angle, and its result is multiplied into the matching entry
        of `self.maps`.
        """
        progress_args = dict(desc=description, disable=(not verbose))
        if use_multithreading:
            with cf.ThreadPoolExecutor(max_workers=num_threads) as executor:
                futures = [executor.submit(compute_angle, local_att, a) for a in self.angles_rot_rad]
                # Results are collected in submission order, so that each one
                # lands on the map of its own rotation angle.
                for ii, future in enumerate(tqdm(futures, **progress_args)):
                    self.maps[ii, ...] *= future.result()
        else:
            for ii, a in enumerate(tqdm(self.angles_rot_rad, **progress_args)):
                self.maps[ii, ...] *= compute_angle(local_att, a)

    def compute_maps(self, use_multithreading: bool = True, verbose: bool = True) -> None:
        """
        Compute the correction maps for each angle.

        The maps are initialized to one, and then multiplied by the incident
        and emitted attenuations, for the volumes that were given.

        Parameters
        ----------
        use_multithreading : bool, optional
            Use multi-threading for computing the attenuation maps. The default is True.
        verbose : bool, optional
            Show verbose output. The default is True.
        """
        num_rot_angles = len(self.angles_rot_rad)
        self.maps = np.ones([num_rot_angles, len(self.angles_det_rad), *self.vol_shape_zyx], dtype=self.dtype)

        if self.incident_local is not None:
            self._accumulate_attenuation(
                self._compute_attenuation_angle_in,
                self.incident_local,
                "Computing attenuation maps for incident beam: ",
                use_multithreading,
                verbose,
            )

        if self.emitted_local is not None:
            self._accumulate_attenuation(
                self._compute_attenuation_angle_out,
                self.emitted_local,
                "Computing attenuation maps for emitted photons: ",
                use_multithreading,
                verbose,
            )

    def plot_map(
        self,
        ax: Axes,
        rot_ind: int,
        det_ind: int = 0,
        slice_ind: Optional[int] = None,
        axes: Union[Sequence[int], NDArrayInt] = (-2, -1),
    ) -> List[float]:
        """
        Plot the requested attenuation map.

        Parameters
        ----------
        ax : matplotlib axes
            The axes where to plot.
        rot_ind : int
            Rotation angle index.
        det_ind : int, optional
            Detector angle index. The default is 0.
        slice_ind : Optional[int], optional
            Volume slice index (for 3D volumes). The default is None.
        axes : Sequence[int] | NDArray, optional
            Axes of the slice. The default is (-2, -1).

        Returns
        -------
        List[float]
            The extent of the axes plot (min-max coords).

        Raises
        ------
        ValueError
            In case a slice index is not passed for a 3D volume.
        """
        att_map = np.squeeze(self.get_maps(rot_ind=rot_ind, det_ind=det_ind))
        # The remaining dimension is the one along which 3D volumes are sliced.
        other_dim = np.squeeze(np.delete(np.arange(-3, 0), axes))
        if len(att_map.shape) == 3:
            if slice_ind is None:
                raise ValueError("Slice index is needed for 3D volumes. None was passed.")
            att_map = np.take(att_map, slice_ind, axis=int(other_dim))

        slice_shape = self.vol_shape_zyx[list(axes)]
        coords = [(-(s - 1) / 2, (s - 1) / 2) for s in slice_shape]
        # NOTE(review): imshow expects extent as (left, right, bottom, top), while
        # the ranges here follow the order of `axes` - confirm for non-square slices.
        extent = list(np.concatenate(coords))

        ax.imshow(att_map, extent=extent)
        # The beam directions are only drawn when the first volume axis is the excluded one.
        if other_dim == -3:
            arrow_length = np.linalg.norm(slice_shape) / np.pi
            arrow_args = dict(
                width=arrow_length / 25,
                head_width=arrow_length / 8,
                head_length=arrow_length / 6,
                length_includes_head=True,
            )

            prj_geom = models.ProjectionGeometry.get_default_parallel()
            beam_i_geom = prj_geom.rotate(-self.angles_rot_rad[rot_ind])
            beam_e_geom = prj_geom.rotate(-(self.angles_rot_rad[rot_ind] + self.angles_det_rad[det_ind]))

            # Incident beam: red arrow ending in the origin.
            beam_i_dir = beam_i_geom.src_pos_xyz[0, :2] * arrow_length
            beam_i_orig = -beam_i_dir
            # Emitted beam: black arrow starting from the origin.
            beam_e_dir = beam_e_geom.src_pos_xyz[0, :2] * arrow_length
            beam_e_orig = np.array([0, 0])

            ax.arrow(*beam_i_orig, *beam_i_dir, **arrow_args, fc="r", ec="r")
            ax.arrow(*beam_e_orig, *beam_e_dir, **arrow_args, fc="k", ec="k")

        return extent

    @staticmethod
    def _as_unit_slice(ind):
        """Turn a scalar integer index into a length-1 slice, and leave other indices untouched.

        Slicing (instead of integer indexing) keeps the indexed axis in the
        result. NumPy integers are handled like Python integers, and the stop
        value for index -1 is None, because `slice(-1, 0)` would be empty.
        """
        if isinstance(ind, (int, np.integer)):
            ind = int(ind)
            return slice(ind, ind + 1 or None, 1)
        return ind

    def get_maps(
        self,
        roi: Optional[ArrayLike] = None,
        rot_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
        det_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
    ) -> NDArray:
        """
        Return the attenuation maps.

        The rotation and detector axes are always retained in the output, even
        when a single index is requested.

        Parameters
        ----------
        roi : ArrayLike, optional
            The region-of-interest to select. The default is None.
        rot_ind : int, optional
            A specific rotation index, if only one is to be desired. The default is None.
        det_ind : int, optional
            A specific detector index, if only one is to be desired. The default is None.

        Returns
        -------
        NDArray
            The attenuation maps.

        Raises
        ------
        NotImplementedError
            In case a region of interest is requested.
        """
        if roi is not None:
            raise NotImplementedError("Extracting a region of interest is not supported, yet.")

        maps = self.maps
        if rot_ind is not None:
            maps = maps[self._as_unit_slice(rot_ind), ...]
        if det_ind is not None:
            maps = maps[:, self._as_unit_slice(det_ind), ...]
        return maps

    def get_projector_args(
        self,
        roi: Optional[ArrayLike] = None,
        rot_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
        det_ind: Union[int, slice, Sequence[int], NDArrayInt, None] = None,
    ) -> Dict[str, NDArray]:
        """
        Return the projector arguments.

        Parameters
        ----------
        roi : ArrayLike, optional
            The region-of-interest to select. The default is None.
        rot_ind : int, optional
            A specific rotation index, if only one is to be desired. The default is None.
        det_ind : int, optional
            A specific detector index, if only one is to be desired. The default is None.

        Returns
        -------
        Dict[str, NDArray]
            A dictionary containing the attenuation maps and the detector angle.
        """
        if det_ind is None:
            det_angles = self.angles_det_rad
        else:
            det_angles = self.angles_det_rad[det_ind]
        return dict(att_maps=self.get_maps(roi=roi, rot_ind=rot_ind, det_ind=det_ind), angles_detectors_rad=det_angles)