Source code for corrct.models

#!/usr/bin/env python3
"""
Define all the models used throughout the code.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from abc import ABC
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import replace as dc_replace
from typing import Any

import numpy as np
import scipy.spatial.transform as spt
from numpy.typing import ArrayLike, NDArray

ROT_DIRS_VALID = ("clockwise", "counter-clockwise")


class Geometry(ABC):
    """Common base class of the geometry descriptions."""

    def __str__(self) -> str:
        """
        Build a human readable representation of the object.

        Returns
        -------
        str
            The class name, followed by one "name = value" line per instance attribute.
        """
        # One line per attribute, in the order in which the attributes were stored on the instance.
        attribute_lines = [f" {name} = {value},\n" for name, value in self.__dict__.items()]
        return "".join([f"{self.__class__.__name__}(\n", *attribute_lines, ")"])
@dataclass
class ProjectionGeometry(Geometry):
    """Store the projection geometry.

    The vector fields are converted to 2D arrays on construction. Their last axis holds the XYZ
    components, while the first axis is treated as the angular (projection) direction by the
    slicing and rotation methods.
    """

    # Geometry type. It is lower-cased on construction. The `ndim` property accepts:
    # 'parallel2d' | 'parallel3d' | 'cone' | 'fanflat'.
    geom_type: str
    # Source position(s), in XYZ.
    src_pos_xyz: NDArray
    # Detector position(s), in XYZ.
    det_pos_xyz: NDArray
    # Detector U (horizontal) direction(s), in XYZ.
    det_u_xyz: NDArray
    # Detector V (vertical) direction(s), in XYZ.
    det_v_xyz: NDArray
    # Rotation axis direction, in XYZ.
    rot_dir_xyz: NDArray
    # Ratio used by `get_field_scaled` to scale the fields.
    pix2vox_ratio: float = 1
    # Detector shape in [V]U coordinates, if known.
    det_shape_vu: NDArray | None = None

    def __post_init__(self) -> None:
        """Normalize the geometry type to lower case, and the vector fields to 2D arrays."""
        self.geom_type = self.geom_type.lower()
        self.src_pos_xyz = np.array(self.src_pos_xyz, ndmin=2)
        self.det_pos_xyz = np.array(self.det_pos_xyz, ndmin=2)
        self.det_u_xyz = np.array(self.det_u_xyz, ndmin=2)
        self.det_v_xyz = np.array(self.det_v_xyz, ndmin=2)
        self.rot_dir_xyz = np.array(self.rot_dir_xyz, ndmin=2)

    def __getitem__(self, indx: Any):
        """
        Slice projection geometry along the angular direction.

        Parameters
        ----------
        indx : Any
            Indices of the slicing.

        Returns
        -------
        ProjectionGeometry
            A new geometry, with the source and detector vectors restricted to the requested indices.
        """

        def slice_array(vecs_arr: NDArray, indx: Any):
            # Arrays with a single row are passed through untouched: there is nothing to select.
            if len(vecs_arr.shape) > 1 and vecs_arr.shape[0] > 1:
                return vecs_arr[indx, :]
            else:
                return vecs_arr

        return dc_replace(
            self,
            src_pos_xyz=slice_array(self.src_pos_xyz, indx),
            det_pos_xyz=slice_array(self.det_pos_xyz, indx),
            det_u_xyz=slice_array(self.det_u_xyz, indx),
            det_v_xyz=slice_array(self.det_v_xyz, indx),
        )

    def copy(self) -> "ProjectionGeometry":
        """Deepcopy an existing geometry.

        Returns
        -------
        ProjectionGeometry
            The new instance of ProjectionGeometry
        """
        return deepcopy(self)

    @staticmethod
    def get_default_parallel(
        *,
        geom_type: str = "3d",
        rot_axis_shift_pix: ArrayLike | None = None,
        rot_axis_dir: str | ArrayLike = "clockwise",
    ) -> "ProjectionGeometry":
        """
        Generate the default geometry for parallel beam.

        Parameters
        ----------
        geom_type : str, optional
            The geometry type. The default is "3d".
        rot_axis_shift_pix : ArrayLike | None, optional
            Rotation axis shift in pixels. The default is None.
        rot_axis_dir : str | ArrayLike, optional
            Rotation axis direction. It can be either a string or a direction. The default is "clockwise".

        Returns
        -------
        ProjectionGeometry
            The default parallel-beam geometry.
        """
        return get_prj_geom_parallel(geom_type=geom_type, rot_axis_shift_pix=rot_axis_shift_pix, rot_axis_dir=rot_axis_dir)

    @property
    def ndim(self) -> int:
        """Return the number of dimensions of the geometry.

        Returns
        -------
        int
            The number of dimensions.

        Raises
        ------
        ValueError
            When the geometry type is not recognized.
        """
        if "parallel" in self.geom_type:
            # Parallel geometries carry their dimensionality in the name: "parallel2d" / "parallel3d".
            return int(self.geom_type[-2])
        elif self.geom_type.lower() == "cone":
            return 3
        elif self.geom_type.lower() == "fanflat":
            return 2
        else:
            raise ValueError(
                f"Geometry ({self.geom_type}) needs to be one of: 'parallel2d' | 'parallel3d' | 'cone' | 'fanflat'."
            )

    def get_3d(self) -> "ProjectionGeometry":
        """Return the 3D version of the geometry.

        Returns
        -------
        ProjectionGeometry
            The new geometry.
        """
        if self.ndim == 2:
            if self.det_shape_vu is not None:
                # The missing vertical size is set to 1.
                new_det_shape_vu = np.ones(2, dtype=int)
                new_det_shape_vu[-len(self.det_shape_vu) :] = self.det_shape_vu
            else:
                new_det_shape_vu = None
            return dc_replace(self, geom_type=self.geom_type.replace("2d", "3d"), det_shape_vu=new_det_shape_vu)
        else:
            return dc_replace(self)

    def set_detector_shape_vu(self, vu: int | Sequence[int] | NDArray) -> None:
        """Set the detector VU shape.

        Parameters
        ----------
        vu : int | Sequence[int] | NDArray
            The VU shape of the projection data.
        """
        self.det_shape_vu = np.array(vu, ndmin=1)

    def set_detector_shifts_vu(
        self,
        det_pos_vu: ArrayLike | NDArray | None = None,
        cor_pos_u: float | None = None,
        det_dist_y: ArrayLike = 0.0,
    ) -> None:
        """
        Set the detector position in XZ, from VU (vertical, horizontal) coordinates.

        Parameters
        ----------
        det_pos_vu : ArrayLike | NDArray | None
            Detector vertical and horizontal positions. Vertical is optional.
        cor_pos_u : float | None
            Center of rotation position along U.
        det_dist_y : ArrayLike, optional
            Detector distance from origin along Y. The default is 0.0.

        Raises
        ------
        ValueError
            When the number of detector distances or positions is incompatible with the other inputs,
            or with the current number of angles.
        """
        det_pos_vu = np.array(det_pos_vu if det_pos_vu is not None else 0.0, ndmin=2, dtype=np.float64)
        if cor_pos_u is not None:
            det_pos_vu[-1, ...] = det_pos_vu[-1, ...] + cor_pos_u

        det_dist_y = np.array(det_dist_y, ndmin=1, dtype=np.float64)
        if det_dist_y.size > 1 and (
            det_dist_y.ndim > 1 or (det_pos_vu.shape[1] > 1 and det_dist_y.size != det_pos_vu.shape[1])
        ):
            raise ValueError(
                f"Detector distance along Y (shape: {det_dist_y.shape}) should either be a scalar or a 1D array of the "
                f"same length as the detector positions (shape: {det_pos_vu.shape}), if detector positions are more than 1."
            )

        if self.det_pos_xyz.shape[0] > 1 and det_pos_vu.shape[-1] > 1 and self.det_pos_xyz.shape[0] != det_pos_vu.shape[-1]:
            raise ValueError(
                f"Current number of angles ({self.det_pos_xyz.shape[-2]}) and new number of "
                f"angles ({det_pos_vu.shape[-1]}) differ!"
            )

        # The position is built out-of-place, so that a single shift can be combined with per-angle
        # detector directions or per-angle Y distances (an in-place update of a one-row array fails).
        det_pos_xyz = self.det_u_xyz * det_pos_vu[-1, :].reshape([-1, 1])
        if self.ndim == 3 and det_pos_vu.shape[0] == 2:
            det_pos_xyz = det_pos_xyz + self.det_v_xyz * det_pos_vu[-2, :].reshape([-1, 1])

        num_positions = max(det_pos_xyz.shape[0], det_dist_y.size)
        det_pos_xyz = np.array(np.broadcast_to(det_pos_xyz, (num_positions, 3)), dtype=np.float64)
        det_pos_xyz[:, 1] += det_dist_y.reshape(-1)

        # The attribute is only replaced once the computation succeeded.
        self.det_pos_xyz = det_pos_xyz

    def set_source_shifts_vu(self, src_pos_vu: ArrayLike | NDArray | None = None) -> None:
        """
        Set the source position in XZ, from VU (vertical, horizontal) coordinates.

        The current Y coordinate of the source is preserved.

        Parameters
        ----------
        src_pos_vu : ArrayLike | NDArray | None
            Source vertical and horizontal positions. Vertical is optional.

        Raises
        ------
        ValueError
            When the number of positions is incompatible with the current number of angles.
        """
        if src_pos_vu is None:
            return

        src_pos_vu = np.array(src_pos_vu, ndmin=2)

        if self.src_pos_xyz.shape[0] > 1 and src_pos_vu.shape[-1] > 1 and self.src_pos_xyz.shape[0] != src_pos_vu.shape[-1]:
            raise ValueError(
                f"Current number of angles ({self.src_pos_xyz.shape[-2]}) and new number of angles ({src_pos_vu.shape[-1]}) differ!"
            )

        src_pos_y = self.src_pos_xyz[:, 1]

        # The new array has room for the larger of the two multiplicities, so that a single new
        # position can be applied to a geometry that already has per-angle source positions.
        num_positions = max(src_pos_vu.shape[-1], src_pos_y.size)
        src_pos_xyz = np.zeros((num_positions, 3))
        src_pos_xyz[:, 0] = src_pos_vu[-1, :]
        src_pos_xyz[:, 1] = src_pos_y
        if self.ndim == 3 and src_pos_vu.shape[0] == 2:
            src_pos_xyz[:, 2] = src_pos_vu[-2, :]

        self.src_pos_xyz = src_pos_xyz

    def set_detector_tilt(
        self,
        angles_t_rad: ArrayLike | NDArray,
        tilt_axis: Sequence[float] | NDArray = (0, 1, 0),
        tilt_source: bool = False,
    ) -> None:
        """
        Rotate the detector by the given angle(s) and axis(axes).

        Parameters
        ----------
        angles_t_rad : ArrayLike | NDArray
            Rotation angle(s) in radians.
        tilt_axis : Sequence[float] | NDArray, optional
            The tilt axis or axes. The default is (0, 1, 0)
        tilt_source : bool, optional
            Whether to also tilt the source. The default is False.

        Raises
        ------
        ValueError
            When the tilt axes are not three-dimensional along the last dimension, or when multiple
            axes and multiple angles are passed with different multiplicity.

        Notes
        -----
        When applying multiple axes, they will be applied in order. This means that the application is not going to
        be independent.
        """
        angles = np.array(angles_t_rad, ndmin=1)[:, None]
        tilt_axis = np.array(tilt_axis, ndmin=1)

        if tilt_axis.shape[-1] != 3:
            raise ValueError(
                f"Tilt axis/axes should be three-dimensional, along the last dimension. Current shape: {tilt_axis.shape}"
            )

        if tilt_axis.ndim == 1:
            tilt_axis = tilt_axis[None, :]
        elif tilt_axis.ndim > 2:
            raise ValueError(
                f"Tilt axis/axes should be three-dimensional, along the last dimension. Current shape: {tilt_axis.shape}"
            )
        elif angles.size > 1 and tilt_axis.shape[0] != angles.shape[0]:
            raise ValueError(
                "Tilt axes and tilt angles multiplicity should match. "
                f"Current shapes: {tilt_axis.shape = }, {angles.shape = }"
            )

        # NOTE(review): zip stops at the shortest input, so with a single axis and multiple angles
        # only the first angle is applied - confirm whether this is intended.
        for angle, axis in zip(angles, tilt_axis):
            rotations = spt.Rotation.from_rotvec(angle * axis)  # type: ignore

            if tilt_source:
                self.src_pos_xyz = rotations.apply(self.src_pos_xyz)

            self.det_u_xyz = rotations.apply(self.det_u_xyz)
            self.det_v_xyz = rotations.apply(self.det_v_xyz)
            self.det_pos_xyz = rotations.apply(self.det_pos_xyz)

    def rotate(self, angles_w_rad: ArrayLike, patch_astra_2d: bool = False) -> "ProjectionGeometry":
        """
        Rotate the geometry by the given angle(s).

        Parameters
        ----------
        angles_w_rad : ArrayLike
            Rotation angle(s) in radians.
        patch_astra_2d : bool, optional
            Whether to flip the sign of the angles when the geometry is 2D. The default is False.

        Returns
        -------
        ProjectionGeometry
            The rotated geometry.
        """
        angles = np.array(angles_w_rad, ndmin=1)[:, None]

        # Dealing with ASTRA's incoherent 2D and 3D coordinate systems.
        if patch_astra_2d and self.ndim == 2:
            angles = -angles

        rotations = spt.Rotation.from_rotvec(angles * self.rot_dir_xyz)  # type: ignore

        return dc_replace(
            self,
            src_pos_xyz=rotations.apply(self.src_pos_xyz),
            det_pos_xyz=rotations.apply(self.det_pos_xyz),
            det_u_xyz=rotations.apply(self.det_u_xyz),
            det_v_xyz=rotations.apply(self.det_v_xyz),
        )

    def get_field_scaled(self, field_name: str) -> NDArray:
        """
        Return the content of a field, scaled by the pix2vox ratio.

        Parameters
        ----------
        field_name : str
            Name of the field to access.

        Returns
        -------
        NDArray
            The scaled field. For 2D geometries, the last (Z) component is dropped.
        """
        field_value = getattr(self, field_name) / self.pix2vox_ratio
        # The dimensionality comes from `ndim`, which also knows about "fanflat" (its name does not
        # carry the number of dimensions, so it cannot be parsed from the geometry type string).
        if self.ndim == 2:
            return field_value[:, :-1]
        else:
            return field_value

    def project_displacement_to_detector(self, disp_zyx: ArrayLike) -> NDArray:
        """Project a given displacement vector in the volume coordinates, over the detector.

        Parameters
        ----------
        disp_zyx : ArrayLike
            The displacement vector in volume coordinates.

        Returns
        -------
        NDArray
            The projection on u (and if applicable v) coordinates.

        Raises
        ------
        ValueError
            When projection geometry and vector dimensions don't match.
        """
        disp_zyx = np.array(disp_zyx, ndmin=1)

        disp_dims = len(disp_zyx)
        if self.ndim != disp_dims:
            raise ValueError(f"Geometry is {self.ndim}d, while passed displacement is {disp_dims}d.")

        # The detector vectors are stored in XYZ order.
        disp_xyz = np.flip(disp_zyx)
        if self.ndim == 2:
            return self.det_u_xyz[..., : self.ndim].dot(disp_xyz)
        else:
            return np.stack(
                [self.det_v_xyz[..., : self.ndim].dot(disp_xyz), self.det_u_xyz[..., : self.ndim].dot(disp_xyz)], axis=0
            )

    def get_pre_weights(self, det_shape_vu: Sequence[int] | NDArray | None = None) -> NDArray | None:
        """Compute the pre-weights of the projector geometry (notably for cone-beam geometries).

        The weights are the ratio between the source-to-detector-center distance and the
        source-to-pixel distance.

        Parameters
        ----------
        det_shape_vu : Sequence[int] | NDArray | None, optional
            Shape of the detector in [V]U coordinates, by default None

        Returns
        -------
        NDArray | None
            The computed detector weights, or None when the geometry is not cone-beam or the detector
            shape is unknown.
        """
        if self.geom_type != "cone":
            return None

        if det_shape_vu is None:
            if self.det_shape_vu is None:
                print("WARNING: pre-weights cannot be computed because detector shape is None.")
                return None
            det_shape_vu = self.det_shape_vu

        det_shape_vu = np.array(det_shape_vu, ndmin=1, dtype=int)
        if self.det_shape_vu is not None and np.any(det_shape_vu != self.det_shape_vu):
            print("WARNING: overriding the detector shape in the computation of the pre-weights.")

        # Vector going from the source to the detector center: end point minus start point.
        src2det_xyz = self.det_pos_xyz - self.src_pos_xyz

        pixel_coords_vu = [np.linspace(-s / 2, s / 2, int(s)) for s in det_shape_vu]
        pixel_coords_vu = np.meshgrid(*pixel_coords_vu, indexing="ij")
        # Two trailing axes are added, to broadcast against the (angles, XYZ) axes of the detector vectors.
        pixel_coords_vu = [coords[..., None, None] for coords in pixel_coords_vu]

        pixel_coords_xyz = pixel_coords_vu[-1] * self.det_u_xyz
        if len(pixel_coords_vu) > 1:
            pixel_coords_xyz = pixel_coords_xyz + pixel_coords_vu[-2] * self.det_v_xyz

        src2pixel_dist = np.linalg.norm(pixel_coords_xyz + src2det_xyz, axis=-1)
        src2det_dist = np.linalg.norm(src2det_xyz, axis=-1)

        # Null distances are replaced with 1, to avoid divisions by zero.
        pre_weights = src2det_dist / (src2pixel_dist + (src2pixel_dist == 0))
        return pre_weights.swapaxes(-2, -1)
@dataclass
class VolumeGeometry(Geometry):
    """Store the volume geometry."""

    # Shape of the volume, in XYZ order. It is converted to an array on construction.
    _vol_shape_xyz: NDArray
    # Size of the voxels.
    vox_size: float = 1.0

    def __post_init__(self):
        """Convert the volume shape to an array with at least one dimension."""
        self._vol_shape_xyz = np.array(self._vol_shape_xyz, ndmin=1)

    def is_square(self) -> bool:
        """Tell whether the volume has the same size along X and Y.

        Returns
        -------
        bool
            True if the volume is square in XY.
        """
        shape_xyz = self._vol_shape_xyz
        return shape_xyz[0] == shape_xyz[1]

    @property
    def shape_xyz(self) -> NDArray:
        """
        Give access to the volume shape in XYZ order.

        Returns
        -------
        NDArray
            Shape of the volume (XYZ).
        """
        return self._vol_shape_xyz

    @property
    def shape_zxy(self) -> NDArray:
        """
        Give access to the volume shape in ZXY order.

        The swap between X and Y is imposed by the astra-toolbox.

        Returns
        -------
        NDArray
            Shape of the volume (ZXY).
        """
        shape_xyz = self._vol_shape_xyz
        # Everything after XY comes first (in reversed order), followed by X and Y.
        leading_dims = np.flip(shape_xyz[2:])
        return np.array([*leading_dims, shape_xyz[0], shape_xyz[1]], dtype=int)

    @property
    def mask_shape(self) -> NDArray:
        """Give the XY shape of the volume, as used for circular masks.

        Returns
        -------
        NDArray
            Shape of the XY volume.
        """
        return self.shape_xyz[:2]

    @property
    def extent(self) -> Sequence[float]:
        """
        Compute the extent of the volume.

        Returns
        -------
        Sequence[float]
            The extent of the volume [-x, +x, -y, +y, [-z, +z]].
        """
        half_sizes = self._vol_shape_xyz * self.vox_size / 2
        bounds = []
        for half_size in half_sizes:
            bounds.extend([half_size * -1, half_size * +1])
        return bounds

    def is_3D(self) -> bool:
        """
        Tell whether this is a 3D geometry.

        Returns
        -------
        bool
            Whether this is a 3D geometry or not.
        """
        if len(self._vol_shape_xyz) != 3:
            return False
        return self._vol_shape_xyz[-1] > 1

    def get_3d(self) -> "VolumeGeometry":
        """Create the 3D version of the geometry.

        Returns
        -------
        VolumeGeometry
            The new geometry.
        """
        if len(self._vol_shape_xyz) != 2:
            return dc_replace(self)
        # A 2D volume becomes a 3D volume with one single slice along Z.
        shape_xyz_3d = np.concatenate((self._vol_shape_xyz, [1]))
        return dc_replace(self, _vol_shape_xyz=shape_xyz_3d)

    @staticmethod
    def get_default_from_data(data: NDArray, data_format: str = "dvwu") -> "VolumeGeometry":
        """
        Create a default volume geometry from the shape of the data.

        Parameters
        ----------
        data : NDArray
            The data.
        data_format : str, optional
            The ordering and meaning of the dimensions in the data. The default is "dvwu".

        Returns
        -------
        VolumeGeometry
            The default volume geometry.
        """
        return get_vol_geom_from_data(data=data, data_format=data_format)

    @staticmethod
    def get_default_from_volume(volume: NDArray) -> "VolumeGeometry":
        """
        Create a default volume geometry from the given volume.

        Parameters
        ----------
        volume : NDArray
            The volume.

        Returns
        -------
        VolumeGeometry
            The default volume geometry.
        """
        return get_vol_geom_from_volume(volume=volume)
def combine_shifts_vu(shifts_v: NDArray, shifts_u: NDArray) -> NDArray:
    """Combine vertical and horizontal shifts.

    Parameters
    ----------
    shifts_v : NDArray
        The vertical shifts
    shifts_u : NDArray
        The horizontal shifts

    Returns
    -------
    NDArray
        The combined shifts, with shape (2, N): vertical shifts first, horizontal shifts second.

    Raises
    ------
    ValueError
        When one of the inputs has more than one non-singleton dimension, or when the number of
        vertical and horizontal shifts differ.
    """
    shifts_v = np.asarray(shifts_v)
    shifts_u = np.asarray(shifts_u)

    if np.sum(np.array(shifts_v.shape) > 1) > 1:
        raise ValueError(f"Expected 1-dimensional array for vertical shifts, but an array {shifts_v.shape = } was passed")
    if np.sum(np.array(shifts_u.shape) > 1) > 1:
        raise ValueError(f"Expected 1-dimensional array for horizontal shifts, but an array {shifts_u.shape = } was passed")
    if shifts_v.size != shifts_u.size:
        raise ValueError(f"Number of vertical shifts ({shifts_v.size}) and horizontal shifts ({shifts_u.size}) should match")

    # The inputs are flattened instead of squeezed: squeezing a single shift produces a 0-d array,
    # which cannot be stacked along axis -2.
    return np.stack([shifts_v.reshape(-1), shifts_u.reshape(-1)], axis=-2)
[docs] def get_rot_axis_dir(rot_axis_dir: str | ArrayLike | NDArray = "clockwise") -> NDArray: """Process the requested rotation axis direction and return a meaningful value. Parameters ---------- rot_axis_dir : str | ArrayLike | NDArray, optional The requested direction, by default "clockwise" Returns ------- NDArray The vector corresponding to the rotation direction. Raises ------ ValueError In case of malformed direction. """ if isinstance(rot_axis_dir, str): if rot_axis_dir.lower() not in ROT_DIRS_VALID: raise ValueError(f"Rotation axis direction {rot_axis_dir} not allowed. It should be one of: {ROT_DIRS_VALID}") if rot_axis_dir.lower() == "clockwise": return np.array([0.0, 0.0, -1.0]) else: return np.array([0.0, 0.0, 1.0]) else: return np.array(rot_axis_dir, ndmin=1)
[docs] def _get_data_dims(data_shape: Sequence[int] | NDArray, data_format: str = "dvwu") -> dict[str, int | None]: dims: dict[str, int | None] = dict(u=None, v=None, w=None, d=None) for ii in range(-len(data_shape), 0): dims[data_format[ii]] = data_shape[ii] return dims
def get_prj_geom_parallel(
    *,
    geom_type: str = "3d",
    rot_axis_shift_pix: ArrayLike | NDArray | None = None,
    rot_axis_dir: str | ArrayLike | NDArray = "clockwise",
    data_shape: Sequence[int] | NDArray | None = None,
    data_format: str = "dvwu",
) -> ProjectionGeometry:
    """
    Generate the default geometry for parallel beam.

    Parameters
    ----------
    geom_type : str, optional
        The geometry type: "2d" or "3d". The "parallel" prefix is also accepted (e.g. "parallel3d").
        The default is "3d".
    rot_axis_shift_pix : ArrayLike | NDArray | None, optional
        Rotation axis shift in pixels. The default is None.
    rot_axis_dir : str | ArrayLike | NDArray, optional
        Rotation axis direction. It can be either a string or a direction. The default is "clockwise".
    data_shape : Sequence[int] | NDArray | None, optional
        Shape of the projection data, used to set the detector shape. The default is None.
    data_format : str, optional
        The ordering and meaning of the dimensions in the data. The default is "dvwu".

    Returns
    -------
    ProjectionGeometry
        The default parallel-beam geometry.

    Raises
    ------
    ValueError
        When the detector dimensions cannot be determined from the data shape and format.
    """
    # The prefix is added below: stripping it here avoids producing "parallelparallel3d".
    geom_type = geom_type.lower().removeprefix("parallel")

    prj_geom = ProjectionGeometry(
        geom_type="parallel" + geom_type,
        src_pos_xyz=np.array([0.0, -1.0, 0.0]),
        det_pos_xyz=np.zeros(3),
        det_u_xyz=np.array([1.0, 0.0, 0.0]),
        det_v_xyz=np.array([0.0, 0.0, 1.0]),
        rot_dir_xyz=get_rot_axis_dir(rot_axis_dir),
    )

    if rot_axis_shift_pix is not None:
        rot_axis_shift_pix = np.array(rot_axis_shift_pix)
        if rot_axis_shift_pix.size == 1:
            # `item` extracts the scalar regardless of the number of dimensions of the array, while
            # `float` on arrays with ndim > 0 is deprecated in recent NumPy versions.
            prj_geom.set_detector_shifts_vu(cor_pos_u=float(rot_axis_shift_pix.item()))
        else:
            prj_geom.set_detector_shifts_vu(det_pos_vu=rot_axis_shift_pix)

    if data_shape is not None:
        data_dims = _get_data_dims(data_shape, data_format)
        if data_dims["u"] is None:
            raise ValueError(
                "Could not determine data dimensions. Coordinate U cannot be undetermined."
                f" Data shape: {data_shape}, data format: {data_format}"
            )

        if geom_type == "3d":
            if data_dims["v"] is None:
                raise ValueError(
                    "Could not determine data dimensions. Coordinate V cannot be undetermined in a 3D geometry."
                    f" Data shape: {data_shape}, data format: {data_format}"
                )
            prj_geom.set_detector_shape_vu([data_dims["v"], data_dims["u"]])
        else:
            prj_geom.set_detector_shape_vu([data_dims["u"]])

    return prj_geom
def get_prj_geom_cone(
    *,
    src_to_sam_dist: float,
    rot_axis_shift_pix: ArrayLike | NDArray | None = None,
    rot_axis_dir: str | ArrayLike | NDArray = "clockwise",
    data_shape: Sequence[int] | NDArray | None = None,
    data_format: str = "dvwu",
) -> ProjectionGeometry:
    """
    Generate the default geometry for cone beam.

    Parameters
    ----------
    src_to_sam_dist : float
        Distance between the source and the origin. The source is placed at this distance along -Y.
    rot_axis_shift_pix : ArrayLike | NDArray | None, optional
        Rotation axis shift in pixels. The default is None.
    rot_axis_dir : str | ArrayLike | NDArray, optional
        Rotation axis direction. It can be either a string or a direction. The default is "clockwise".
    data_shape : Sequence[int] | NDArray | None, optional
        Shape of the projection data, used to set the detector shape. The default is None.
    data_format : str, optional
        The ordering and meaning of the dimensions in the data. The default is "dvwu".

    Returns
    -------
    ProjectionGeometry
        The default cone-beam geometry.

    Raises
    ------
    ValueError
        When the detector dimensions cannot be determined from the data shape and format.
    """
    prj_geom = ProjectionGeometry(
        geom_type="cone",
        src_pos_xyz=np.array([0.0, -src_to_sam_dist, 0.0]),
        det_pos_xyz=np.zeros(3),
        det_u_xyz=np.array([1.0, 0.0, 0.0]),
        det_v_xyz=np.array([0.0, 0.0, 1.0]),
        rot_dir_xyz=get_rot_axis_dir(rot_axis_dir),
    )

    if rot_axis_shift_pix is not None:
        rot_axis_shift_pix = np.array(rot_axis_shift_pix)
        if rot_axis_shift_pix.size == 1:
            # `item` extracts the scalar regardless of the number of dimensions of the array, while
            # `float` on arrays with ndim > 0 is deprecated in recent NumPy versions.
            prj_geom.set_detector_shifts_vu(cor_pos_u=float(rot_axis_shift_pix.item()))
        else:
            prj_geom.set_detector_shifts_vu(det_pos_vu=rot_axis_shift_pix)

    if data_shape is not None:
        data_dims = _get_data_dims(data_shape, data_format)
        if data_dims["v"] is None or data_dims["u"] is None:
            raise ValueError(
                "Could not determine data dimensions. Coordinates UV cannot be undetermined in a cone-beam geometry."
                f" Data shape: {data_shape}, data format: {data_format}"
            )
        prj_geom.set_detector_shape_vu([data_dims["v"], data_dims["u"]])

    return prj_geom
def get_vol_geom_from_data(
    data: NDArray, padding_u: int | Sequence[int] | NDArray = 0, data_format: str = "dvwu", super_sampling: int = 1
) -> VolumeGeometry:
    """
    Create a default volume geometry from the shape of the data.

    The volume is square in XY, with sides equal to the size of U without padding. The size along
    Z is the size of V, when V is present in the data.

    Parameters
    ----------
    data : NDArray
        The data.
    padding_u : int | Sequence[int], optional
        The amount of padding along the U direction. The default is 0.
    data_format : str, optional
        The ordering and meaning of the dimensions in the data. The default is "dvwu".
    super_sampling: int, optional
        The super-sampling size of the voxels. The default is 1.

    Returns
    -------
    VolumeGeometry
        The default volume geometry.

    Raises
    ------
    ValueError
        When the U dimension cannot be determined, or when the padding is a sequence whose length
        differs from 2.
    """
    dims = _get_data_dims(data.shape, data_format)
    size_u = dims["u"]
    if size_u is None:
        raise ValueError(
            "Could not determine data dimensions. Coordinate U cannot be undetermined."
            f" Data shape: {data.shape}, data format: {data_format}"
        )

    # A scalar padding is applied to both sides, while a pair gives the two sides explicitly.
    if not isinstance(padding_u, (Sequence, np.ndarray)):
        total_padding = padding_u * 2
    elif len(padding_u) == 2:
        total_padding = padding_u[0] + padding_u[1]
    else:
        raise ValueError(
            f"Padding along U can only either be an integer or a Sequence/NDArray of 2 values. {padding_u} passed instead."
        )
    size_u = size_u - total_padding

    size_v = dims["v"]
    dims_xyz = [size_u, size_u] if size_v is None else [size_u, size_u, size_v]

    return VolumeGeometry(np.array(dims_xyz) * super_sampling, vox_size=1 / super_sampling)
[docs] def get_vol_geom_from_volume(volume: NDArray) -> VolumeGeometry: """ Generate a default volume geometry from the given volume. Parameters ---------- volume : NDArray The volume. Returns ------- VolumeGeometry The default volume geometry. """ vol_shape_zxy = volume.shape if len(vol_shape_zxy) < 2: raise ValueError(f"The volume should be at least 2-dimensional, but the following shape was passed: {vol_shape_zxy}") return VolumeGeometry(np.array([vol_shape_zxy[-2], vol_shape_zxy[-1], *np.flip(vol_shape_zxy[:-2])]))