Source code for corrct.models

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Define all the models used throughout the code.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

import numpy as np
from numpy.typing import ArrayLike, NDArray
from typing import Optional, Sequence, Union, Any

import scipy.spatial.transform as spt

from dataclasses import dataclass, replace as dc_replace

from abc import ABC


# Accepted (case-insensitive) string values for the rotation axis direction, see `get_rot_axis_dir`:
# "clockwise" maps to the -Z axis, while "counter-clockwise" maps to the +Z axis.
ROT_DIRS_VALID = ("clockwise", "counter-clockwise")


class Geometry(ABC):
    """Base geometry class, providing a human readable textual representation."""

    def __str__(self) -> str:
        """
        Return a human readable representation of the object.

        Every instance attribute is listed on its own line, as `name = value,`.

        Returns
        -------
        str
            The human readable representation of the object.
        """
        chunks = [f"{self.__class__.__name__}(\n"]
        chunks.extend(f" {name} = {value},\n" for name, value in self.__dict__.items())
        chunks.append(")")
        return "".join(chunks)
@dataclass
class ProjectionGeometry(Geometry):
    """Store the projection geometry.

    The vector fields (`src_pos_xyz`, `det_pos_xyz`, `det_u_xyz`, `det_v_xyz`, `rot_dir_xyz`) are
    converted to arrays with at least two dimensions. Their rows are indexed along the angular
    direction, and a single row is treated as shared by all the projections.
    """

    geom_type: str  # One of: "parallel2d", "parallel3d", "cone", "fanflat" (see `ndim`)
    src_pos_xyz: NDArray  # Source position(s)
    det_pos_xyz: NDArray  # Detector position(s)
    det_u_xyz: NDArray  # Detector horizontal (U) direction(s)
    det_v_xyz: NDArray  # Detector vertical (V) direction(s)
    rot_dir_xyz: NDArray  # Rotation axis direction(s)
    pix2vox_ratio: float = 1  # Divisor applied to the fields by `get_field_scaled`
    det_shape_vu: Optional[NDArray] = None  # Detector shape in [V]U order, when known

    def __post_init__(self) -> None:
        """Normalize the geometry type, and convert the vector fields to (at least) 2D arrays."""
        self.geom_type = self.geom_type.lower()
        self.src_pos_xyz = np.array(self.src_pos_xyz, ndmin=2)
        self.det_pos_xyz = np.array(self.det_pos_xyz, ndmin=2)
        self.det_u_xyz = np.array(self.det_u_xyz, ndmin=2)
        self.det_v_xyz = np.array(self.det_v_xyz, ndmin=2)
        self.rot_dir_xyz = np.array(self.rot_dir_xyz, ndmin=2)

    def __getitem__(self, indx: Any):
        """
        Slice projection geometry along the angular direction.

        Vector fields with a single row are shared by all the projections, so they are left
        untouched; the other vector fields are sliced along their first axis.

        Parameters
        ----------
        indx : Any
            Indices of the slicing.

        Returns
        -------
        ProjectionGeometry
            The sliced geometry.
        """

        def slice_array(vecs_arr: NDArray, indx: Any) -> NDArray:
            # A single row is broadcast to all the angles: indexing it would fail (or drop it)
            if vecs_arr.ndim > 1 and vecs_arr.shape[0] > 1:
                return vecs_arr[indx, :]
            else:
                return vecs_arr

        return dc_replace(
            self,
            src_pos_xyz=slice_array(self.src_pos_xyz, indx),
            det_pos_xyz=slice_array(self.det_pos_xyz, indx),
            det_u_xyz=slice_array(self.det_u_xyz, indx),
            det_v_xyz=slice_array(self.det_v_xyz, indx),
            rot_dir_xyz=slice_array(self.rot_dir_xyz, indx),
        )

    @staticmethod
    def get_default_parallel(
        *,
        geom_type: str = "3d",
        rot_axis_shift_pix: Optional[ArrayLike] = None,
        rot_axis_dir: Union[str, ArrayLike] = "clockwise",
    ) -> "ProjectionGeometry":
        """
        Generate the default geometry for parallel beam.

        Parameters
        ----------
        geom_type : str, optional
            The geometry dimensionality suffix, appended to "parallel". The default is "3d" ("parallel3d").
        rot_axis_shift_pix : Optional[ArrayLike], optional
            Rotation axis shift in pixels. The default is None.
        rot_axis_dir : Union[str, ArrayLike], optional
            Rotation axis direction. It can be either a string or a direction. The default is "clockwise".

        Returns
        -------
        ProjectionGeometry
            The default parallel-beam geometry.
        """
        return get_prj_geom_parallel(geom_type=geom_type, rot_axis_shift_pix=rot_axis_shift_pix, rot_axis_dir=rot_axis_dir)

    @property
    def ndim(self) -> int:
        """Return the number of dimensions of the geometry.

        Returns
        -------
        int
            The number of dimensions.

        Raises
        ------
        ValueError
            When the geometry type is not recognized.
        """
        if "parallel" in self.geom_type:
            # "parallel2d" / "parallel3d": the dimensionality is the second-to-last character
            return int(self.geom_type[-2])
        elif self.geom_type.lower() == "cone":
            return 3
        elif self.geom_type.lower() == "fanflat":
            return 2
        else:
            raise ValueError(
                f"Geometry ({self.geom_type}) needs to be one of: 'parallel2d' | 'parallel3d' | 'cone' | 'fanflat'."
            )

    def get_3d(self) -> "ProjectionGeometry":
        """Return the 3D version of the geometry.

        "parallel2d" becomes "parallel3d", and "fanflat" becomes "cone". A detector shape with only
        the U size gets a V size of 1.

        Returns
        -------
        ProjectionGeometry
            The new geometry.
        """
        if self.ndim != 2:
            return dc_replace(self)

        if self.det_shape_vu is not None:
            old_det_shape_vu = np.array(self.det_shape_vu, ndmin=1)
            new_det_shape_vu = np.ones(2, dtype=int)
            new_det_shape_vu[-len(old_det_shape_vu) :] = old_det_shape_vu
        else:
            new_det_shape_vu = None

        # "fanflat" is the 2D counterpart of "cone", so a simple "2d" -> "3d" replacement is not enough
        if self.geom_type == "fanflat":
            new_geom_type = "cone"
        else:
            new_geom_type = self.geom_type.replace("2d", "3d")

        return dc_replace(self, geom_type=new_geom_type, det_shape_vu=new_det_shape_vu)

    def set_detector_shifts_vu(
        self,
        det_pos_vu: Union[ArrayLike, NDArray, None] = None,
        cor_pos_u: Union[float, None] = None,
        det_dist_y: ArrayLike = 0.0,
    ) -> None:
        """
        Set the detector position from VU (vertical, horizontal) shifts along the detector directions.

        The new position is `u * det_u_xyz [+ v * det_v_xyz] + det_dist_y * Y`. Each term may have a
        single entry (shared by all the angles) or one entry per angle.

        Parameters
        ----------
        det_pos_vu : ArrayLike | NDArray | None
            Detector vertical and horizontal positions. Vertical is optional.
        cor_pos_u : float | None
            Center of rotation position along U.
        det_dist_y : ArrayLike, optional
            Detector distance from origin along Y. The default is 0.0.

        Raises
        ------
        ValueError
            When the number of shifts or distances is inconsistent with the number of angles.
        """
        det_pos_vu = np.array(det_pos_vu if det_pos_vu is not None else 0.0, ndmin=2, dtype=np.float64)
        if cor_pos_u is not None:
            det_pos_vu[-1, ...] = det_pos_vu[-1, ...] + cor_pos_u

        det_dist_y = np.array(det_dist_y, ndmin=1, dtype=np.float64)
        if det_dist_y.size > 1 and (
            det_dist_y.ndim > 1 or (det_pos_vu.shape[1] > 1 and det_dist_y.size != det_pos_vu.shape[1])
        ):
            raise ValueError(
                f"Detector distance along Y (shape: {det_dist_y.shape}) should either be a scalar or a 1D array of the "
                f"same length as the detector positions (shape: {det_pos_vu.shape}), if detector positions are more than 1."
            )

        if self.det_pos_xyz.shape[0] > 1 and det_pos_vu.shape[-1] > 1 and self.det_pos_xyz.shape[0] != det_pos_vu.shape[-1]:
            raise ValueError(
                f"Current number of angles ({self.det_pos_xyz.shape[-2]}) and new number of "
                f"angles ({det_pos_vu.shape[-1]}) differ!"
            )

        # Build the result by broadcasting, instead of pre-allocating with the number of shifts:
        # the detector directions or the Y distances may have more rows than the shifts.
        det_pos_xyz = self.det_u_xyz * det_pos_vu[-1, :].reshape([-1, 1])
        if self.ndim == 3 and det_pos_vu.shape[0] == 2:
            det_pos_xyz = det_pos_xyz + self.det_v_xyz * det_pos_vu[-2, :].reshape([-1, 1])
        det_pos_xyz = det_pos_xyz + np.array([0.0, 1.0, 0.0]) * det_dist_y.reshape([-1, 1])

        self.det_pos_xyz = det_pos_xyz

    def set_source_shifts_vu(self, src_pos_vu: Union[ArrayLike, NDArray, None] = None) -> None:
        """
        Set the source position in XZ, from VU (vertical, horizontal) coordinates.

        The current Y coordinate(s) of the source are preserved.

        Parameters
        ----------
        src_pos_vu : ArrayLike | NDArray | None
            Source vertical and horizontal positions. Vertical is optional.

        Raises
        ------
        ValueError
            When the number of shifts differs from the current number of angles.
        """
        if src_pos_vu is None:
            return

        src_pos_vu = np.array(src_pos_vu, ndmin=2)
        if self.src_pos_xyz.shape[0] > 1 and src_pos_vu.shape[-1] > 1 and self.src_pos_xyz.shape[0] != src_pos_vu.shape[-1]:
            raise ValueError(
                f"Current number of angles ({self.src_pos_xyz.shape[-2]}) and new number of angles ({src_pos_vu.shape[-1]}) differ!"
            )

        src_pos_y = self.src_pos_xyz[:, 1].copy()
        # Either the new shifts or the current Y positions may be shared by all the angles (single entry)
        num_pos = src_pos_y.size if src_pos_y.size > 1 else src_pos_vu.shape[-1]

        self.src_pos_xyz = np.zeros((num_pos, 3))
        self.src_pos_xyz[:, 0] = src_pos_vu[-1, :]
        self.src_pos_xyz[:, 1] = src_pos_y
        if self.ndim == 3 and src_pos_vu.shape[0] == 2:
            self.src_pos_xyz[:, 2] = src_pos_vu[-2, :]

    def set_detector_tilt(
        self, angles_t_rad: ArrayLike, tilt_axis: Union[Sequence[float], NDArray] = (0, 1, 0), tilt_source: bool = False
    ) -> None:
        """
        Rotate the detector by the given angle(s) and axis(axes).

        A single angle, or a single axis, is broadcast against the others.

        Parameters
        ----------
        angles_t_rad : ArrayLike
            Rotation angle(s) in radians.
        tilt_axis : Sequence[float] | NDArray, optional
            The tilt axis or axes. The default is (0, 1, 0)
        tilt_source : bool, optional
            Whether to also tilt the source. The default is False.

        Raises
        ------
        ValueError
            When the axes are not 3-dimensional, or when the multiplicities of angles and axes differ.
        """
        angles = np.array(angles_t_rad, ndmin=1)[:, None]
        tilt_axis = np.array(tilt_axis, ndmin=1)
        if tilt_axis.shape[-1] != 3 or tilt_axis.ndim > 2:
            raise ValueError(
                f"Tilt axis/axes should be three-dimensional, along the last dimension. Current shape: {tilt_axis.shape}"
            )
        if tilt_axis.ndim == 1:
            tilt_axis = tilt_axis[None, :]

        # A single axis (shape (1, 3)) is compatible with any number of angles
        if angles.size > 1 and tilt_axis.shape[0] > 1 and tilt_axis.shape[0] != angles.shape[0]:
            raise ValueError(
                "Tilt axes and tilt angles multiplicity should match. "
                f"Current shapes: {tilt_axis.shape = }, {angles.shape = }"
            )

        rotations = spt.Rotation.from_rotvec(angles * tilt_axis)  # type: ignore

        if tilt_source:
            self.src_pos_xyz = rotations.apply(self.src_pos_xyz)
        self.det_u_xyz = rotations.apply(self.det_u_xyz)
        self.det_v_xyz = rotations.apply(self.det_v_xyz)
        self.det_pos_xyz = rotations.apply(self.det_pos_xyz)

    def rotate(self, angles_w_rad: ArrayLike, patch_astra_2d: bool = False) -> "ProjectionGeometry":
        """
        Rotate the geometry by the given angle(s).

        Parameters
        ----------
        angles_w_rad : ArrayLike
            Rotation angle(s) in radians.
        patch_astra_2d : bool, optional
            Whether to invert the angles for 2D geometries. The default is False.

        Returns
        -------
        ProjectionGeometry
            The rotated geometry.
        """
        angles = np.array(angles_w_rad, ndmin=1)[:, None]

        # Dealing with ASTRA's incoherent 2D and 3D coordinate systems.
        if patch_astra_2d and self.ndim == 2:
            angles = -angles

        rotations = spt.Rotation.from_rotvec(angles * self.rot_dir_xyz)  # type: ignore

        return dc_replace(
            self,
            src_pos_xyz=rotations.apply(self.src_pos_xyz),
            det_pos_xyz=rotations.apply(self.det_pos_xyz),
            det_u_xyz=rotations.apply(self.det_u_xyz),
            det_v_xyz=rotations.apply(self.det_v_xyz),
        )

    def get_field_scaled(self, field_name: str) -> NDArray:
        """
        Return the content of a field, scaled by the pix2vox ratio.

        For 2D geometries, the last (Z) coordinate is dropped.

        Parameters
        ----------
        field_name : str
            Name of the field to access.

        Returns
        -------
        NDArray
            The scaled field.
        """
        field_value = getattr(self, field_name) / self.pix2vox_ratio
        # Use `ndim`, which also handles "fanflat" (whose second-to-last character is not a digit)
        if self.ndim == 2:
            return field_value[:, :-1]
        else:
            return field_value

    def project_displacement_to_detector(self, disp_zyx: ArrayLike) -> NDArray:
        """Project a given displacement vector in the volume coordinates, over the detector.

        Parameters
        ----------
        disp_zyx : ArrayLike
            The displacement vector in volume coordinates.

        Returns
        -------
        NDArray
            The projection on u (and if applicable v) coordinates.

        Raises
        ------
        ValueError
            When projection geometry and vector dimensions don't match.
        """
        disp_zyx = np.array(disp_zyx, ndmin=1)

        disp_dims = len(disp_zyx)
        if self.ndim != disp_dims:
            raise ValueError(f"Geometry is {self.ndim}d, while passed displacement is {disp_dims}d.")

        disp_xyz = np.flip(disp_zyx)

        if self.ndim == 2:
            return self.det_u_xyz[..., : self.ndim].dot(disp_xyz)
        else:
            return np.stack(
                [self.det_v_xyz[..., : self.ndim].dot(disp_xyz), self.det_u_xyz[..., : self.ndim].dot(disp_xyz)], axis=0
            )

    def get_pre_weights(self, det_shape_vu: Union[Sequence[int], NDArray, None] = None) -> Union[NDArray, None]:
        """Compute the pre-weights of the projector geometry (notably for cone-beam geometries).

        For each detector pixel and angle, the weight is the ratio between the source-to-detector
        distance and the source-to-pixel distance. Only "cone" geometries have pre-weights.

        Parameters
        ----------
        det_shape_vu : Sequence[int] | NDArray | None, optional
            Shape of the detector in [V]U coordinates, by default None

        Returns
        -------
        NDArray | None
            The computed detector weights, or None when they are not applicable or cannot be computed.
        """
        if self.geom_type != "cone":
            return None

        if det_shape_vu is None:
            if self.det_shape_vu is None:
                print("WARNING: pre-weights cannot be computed because detector shape is None.")
                return None
            det_shape_vu = self.det_shape_vu

        det_shape_vu = np.array(det_shape_vu, dtype=int)
        if self.det_shape_vu is not None and not np.array_equal(det_shape_vu, self.det_shape_vu):
            print("WARNING: overriding the detector shape in the computation of the pre-weights.")

        # Vector(s) going from the source to the detector position
        src2det_xyz = self.det_pos_xyz - self.src_pos_xyz

        # Pixel center coordinates, with unit spacing along det_u / det_v, centered on the detector position
        pixel_coords_vu = [np.arange(s) - (s - 1) / 2 for s in det_shape_vu]
        pixel_coords_vu = np.meshgrid(*pixel_coords_vu, indexing="ij")
        # Trailing singleton axes broadcast against the (angles, xyz) vector arrays
        pixel_coords_vu = [coords[..., None, None] for coords in pixel_coords_vu]

        pixel_coords_xyz = pixel_coords_vu[-1] * self.det_u_xyz
        if len(pixel_coords_vu) > 1:
            pixel_coords_xyz = pixel_coords_xyz + pixel_coords_vu[-2] * self.det_v_xyz

        src2pixel_dist = np.linalg.norm(pixel_coords_xyz + src2det_xyz, axis=-1)
        src2det_dist = np.linalg.norm(src2det_xyz, axis=-1)

        # Avoid divisions by zero in degenerate geometries
        pre_weights = src2det_dist / (src2pixel_dist + (src2pixel_dist == 0))

        return pre_weights.swapaxes(-2, -1)
@dataclass
class VolumeGeometry(Geometry):
    """Store the volume geometry.

    The volume shape is stored in XYZ order, and `vox_size` is the size of the voxels.
    """

    _vol_shape_xyz: NDArray  # Volume shape, in XYZ order
    vox_size: float = 1.0  # Voxel size, used to compute the extent

    def __post_init__(self):
        """Initialize the input parameters."""
        self._vol_shape_xyz = np.array(self._vol_shape_xyz, ndmin=1)

    def is_square(self) -> bool:
        """Compute whether the volume is square in XY.

        Returns
        -------
        bool
            True if the volume is square in XY.
        """
        # Convert to a Python bool, as the comparison of numpy scalars returns a numpy bool
        return bool(self._vol_shape_xyz[0] == self._vol_shape_xyz[1])

    @property
    def shape_xyz(self) -> NDArray:
        """
        Return the volume shape (XYZ).

        Returns
        -------
        NDArray
            Shape of the volume (XYZ).
        """
        return self._vol_shape_xyz

    @property
    def shape_zxy(self) -> NDArray:
        """
        Return the volume shape (ZXY).

        The swap between X and Y is imposed by the astra-toolbox.

        Returns
        -------
        NDArray
            Shape of the volume (ZXY).
        """
        vol_shape_zyx = np.flip(self._vol_shape_xyz)
        return np.array([*vol_shape_zyx[:-2], vol_shape_zyx[-1], vol_shape_zyx[-2]], dtype=int)

    @property
    def mask_shape(self) -> NDArray:
        """Return the XY volume shape for circular masks.

        Returns
        -------
        NDArray
            Shape of the XY volume.
        """
        return self.shape_xyz[:2]

    @property
    def extent(self) -> Sequence[float]:
        """
        Return extent of the volume.

        Returns
        -------
        Sequence[float]
            The extent of the volume [-x, +x, -y, +y, [-z, +z]].
        """
        half_size_xyz = self._vol_shape_xyz * self.vox_size / 2
        return [hs * sign for hs in half_size_xyz for sign in [-1, +1]]

    def is_3D(self) -> bool:
        """
        Tell whether this is a 3D geometry.

        Returns
        -------
        bool
            Whether this is a 3D geometry or not.
        """
        return len(self._vol_shape_xyz) == 3 and self._vol_shape_xyz[-1] > 1

    def get_3d(self) -> "VolumeGeometry":
        """Return the 3D version of the geometry.

        A 2D shape is extended with a Z size of 1.

        Returns
        -------
        VolumeGeometry
            The new geometry.
        """
        if len(self._vol_shape_xyz) == 2:
            return dc_replace(self, _vol_shape_xyz=np.concatenate((self._vol_shape_xyz, [1])))
        else:
            return dc_replace(self)

    @staticmethod
    def get_default_from_data(data: NDArray, data_format: str = "dvwu") -> "VolumeGeometry":
        """
        Generate a default volume geometry from the data shape.

        Parameters
        ----------
        data : NDArray
            The data.
        data_format : str, optional
            The ordering and meaning of the dimensions in the data. The default is "dvwu".

        Returns
        -------
        VolumeGeometry
            The default volume geometry.
        """
        return get_vol_geom_from_data(data=data, data_format=data_format)

    @staticmethod
    def get_default_from_volume(volume: NDArray) -> "VolumeGeometry":
        """
        Generate a default volume geometry from the given volume.

        Parameters
        ----------
        volume : NDArray
            The volume.

        Returns
        -------
        VolumeGeometry
            The default volume geometry.
        """
        return get_vol_geom_from_volume(volume=volume)
def combine_shifts_vu(shifts_v: NDArray, shifts_u: NDArray) -> NDArray:
    """Combine vertical and horizontal shifts.

    Parameters
    ----------
    shifts_v : NDArray
        The vertical shifts (1-dimensional, singleton axes are allowed)
    shifts_u : NDArray
        The horizontal shifts (1-dimensional, singleton axes are allowed)

    Returns
    -------
    NDArray
        The combined shifts, with shape (2, N): vertical shifts first, then horizontal shifts

    Raises
    ------
    ValueError
        When the shifts are not 1-dimensional, or when their numbers differ.
    """
    shifts_v = np.asarray(shifts_v)
    shifts_u = np.asarray(shifts_u)

    if np.sum(np.array(shifts_v.shape) > 1) > 1:
        raise ValueError(f"Expected 1-dimensional array for vertical shifts, but an array {shifts_v.shape = } was passed")
    if np.sum(np.array(shifts_u.shape) > 1) > 1:
        raise ValueError(f"Expected 1-dimensional array for horizontal shifts, but an array {shifts_u.shape = } was passed")
    if shifts_v.size != shifts_u.size:
        raise ValueError(f"Number of vertical shifts ({shifts_v.size}) and horizontal shifts ({shifts_u.size}) should match")

    # Flatten instead of squeezing: squeezing size-1 inputs gives 0-d arrays, which cannot be stacked on axis -2
    return np.stack([shifts_v.ravel(), shifts_u.ravel()], axis=-2)
def get_rot_axis_dir(rot_axis_dir: Union[str, ArrayLike, NDArray] = "clockwise") -> NDArray:
    """Process the requested rotation axis direction and return a meaningful value.

    Strings are matched case-insensitively against the allowed directions: "clockwise" gives the
    -Z axis, "counter-clockwise" gives the +Z axis. Any other value is converted to an array.

    Parameters
    ----------
    rot_axis_dir : Union[str, ArrayLike, NDArray], optional
        The requested direction, by default "clockwise"

    Returns
    -------
    NDArray
        The vector corresponding to the rotation direction.

    Raises
    ------
    ValueError
        In case of malformed direction.
    """
    if not isinstance(rot_axis_dir, str):
        return np.array(rot_axis_dir, ndmin=1)

    direction = rot_axis_dir.lower()
    if direction not in ROT_DIRS_VALID:
        raise ValueError(f"Rotation axis direction {rot_axis_dir} not allowed. It should be one of: {ROT_DIRS_VALID}")

    z_sign = -1.0 if direction == "clockwise" else 1.0
    return np.array([0.0, 0.0, z_sign])
def get_prj_geom_parallel(
    *,
    geom_type: str = "3d",
    rot_axis_shift_pix: Union[ArrayLike, NDArray, None] = None,
    rot_axis_dir: Union[str, ArrayLike, NDArray] = "clockwise",
) -> ProjectionGeometry:
    """
    Generate the default geometry for parallel beam.

    The source sits at (0, -1, 0), the detector spans X (U) and Z (V), and its position is shifted
    along X by the requested rotation axis shift(s).

    Parameters
    ----------
    geom_type : str, optional
        The geometry dimensionality suffix, appended to "parallel". The default is "3d" ("parallel3d").
    rot_axis_shift_pix : ArrayLike | NDArray | None, optional
        Rotation axis shift in pixels. The default is None.
    rot_axis_dir : str | ArrayLike | NDArray, optional
        Rotation axis direction. It can be either a string or a direction. The default is "clockwise".

    Returns
    -------
    ProjectionGeometry
        The default parallel-beam geometry.
    """
    if rot_axis_shift_pix is not None:
        shifts_u = np.array(rot_axis_shift_pix, ndmin=1)
        # One detector position per shift, displaced along X only
        det_pos_xyz = np.zeros((len(shifts_u), 3))
        det_pos_xyz[:, 0] = shifts_u
    else:
        det_pos_xyz = np.zeros(3)

    return ProjectionGeometry(
        geom_type="parallel" + geom_type,
        src_pos_xyz=np.array([0.0, -1.0, 0.0]),
        det_pos_xyz=det_pos_xyz,
        det_u_xyz=np.array([1.0, 0.0, 0.0]),
        det_v_xyz=np.array([0.0, 0.0, 1.0]),
        rot_dir_xyz=get_rot_axis_dir(rot_axis_dir),
    )
def get_prj_geom_cone(
    *,
    src_to_sam_dist: float,
    rot_axis_shift_pix: Union[ArrayLike, NDArray, None] = None,
    rot_axis_dir: Union[str, ArrayLike, NDArray] = "clockwise",
) -> ProjectionGeometry:
    """
    Generate the default geometry for cone beam.

    The source sits at (0, -src_to_sam_dist, 0), and the detector spans X (U) and Z (V). The
    detector position is at the origin, shifted along X by the requested rotation axis shift(s).

    Parameters
    ----------
    src_to_sam_dist : float
        Distance between the source and the sample (placed at the origin).
    rot_axis_shift_pix : ArrayLike | NDArray | None, optional
        Rotation axis shift in pixels. The default is None.
    rot_axis_dir : str | ArrayLike | NDArray, optional
        Rotation axis direction. It can be either a string or a direction. The default is "clockwise".

    Returns
    -------
    ProjectionGeometry
        The default cone-beam geometry.
    """
    if rot_axis_shift_pix is None:
        det_pos_xyz = np.zeros(3)
    else:
        rot_axis_shift_pix = np.array(rot_axis_shift_pix, ndmin=1)
        # One detector position per shift, displaced along X only
        det_pos_xyz = np.concatenate([rot_axis_shift_pix[:, None], np.zeros((len(rot_axis_shift_pix), 2))], axis=-1)

    return ProjectionGeometry(
        geom_type="cone",
        src_pos_xyz=np.array([0.0, -src_to_sam_dist, 0.0]),
        det_pos_xyz=det_pos_xyz,
        det_u_xyz=np.array([1.0, 0.0, 0.0]),
        det_v_xyz=np.array([0.0, 0.0, 1.0]),
        rot_dir_xyz=get_rot_axis_dir(rot_axis_dir),
    )
def get_vol_geom_from_data(data: NDArray, data_format: str = "dvwu") -> VolumeGeometry:
    """
    Generate a default volume geometry from the data shape.

    The trailing axes of `data` are labeled by the trailing characters of `data_format`. The
    resulting volume has the size of the "u" axis along both X and Y, and the size of the "v" axis
    along Z (when a "v" axis is present).

    Parameters
    ----------
    data : NDArray
        The data.
    data_format : str, optional
        The ordering and meaning of the dimensions in the data. The default is "dvwu".

    Returns
    -------
    VolumeGeometry
        The default volume geometry.
    """
    num_dims = len(data.shape)
    axis_sizes = {label: [] for label in "uvwd"}
    for axis in range(-num_dims, 0):
        axis_sizes[data_format[axis]] = [data.shape[axis]]

    size_u = axis_sizes["u"]
    return VolumeGeometry([*size_u, *size_u, *axis_sizes["v"]])
def get_vol_geom_from_volume(volume: NDArray) -> VolumeGeometry:
    """
    Generate a default volume geometry from the given volume.

    The last two axes of the volume are taken as X and Y. Any leading axes are appended after
    them in reverse order.

    Parameters
    ----------
    volume : NDArray
        The volume.

    Returns
    -------
    VolumeGeometry
        The default volume geometry.

    Raises
    ------
    ValueError
        When the volume has fewer than 2 dimensions.
    """
    shape_zxy = volume.shape
    if len(shape_zxy) < 2:
        raise ValueError(f"The volume should be at least 2-dimensional, but the following shape was passed: {shape_zxy}")

    *leading_dims, size_x, size_y = shape_zxy
    return VolumeGeometry([size_x, size_y, *np.flip(leading_dims)])