# Source code for corrct.alignment.shifts

"""
Detector shifts finding classes.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Mapping
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from tqdm.auto import tqdm

from . import fitting

# Single-precision machine epsilon: used as a small lower bound (e.g. in `np.fmax`) to avoid divisions by zero.
eps = np.finfo(np.float32).eps


# Type alias for NumPy arrays with any floating-point dtype.
NDArrayFloat = NDArray[np.floating]


def _filter_shifts(shifts_vu: NDArrayFloat, max_shifts: NDArrayFloat) -> NDArrayFloat:
    """Reset to zero the shifts whose magnitude exceeds the allowed maximum.

    Parameters
    ----------
    shifts_vu : NDArrayFloat
        The shifts to check. This array is not modified.
    max_shifts : NDArrayFloat
        The maximum allowed magnitudes, compared element-wise (with broadcasting) against `shifts_vu`.

    Returns
    -------
    NDArrayFloat
        A copy of `shifts_vu`, where the out-of-range entries are set to 0.
    """
    filtered = shifts_vu.copy()
    too_large = np.abs(shifts_vu) > max_shifts

    # Fast path: nothing to reset, no warning to emit
    if not np.any(too_large):
        return filtered

    print(f"WARNING - Some shifts exceeded the maximum allowed magnitude (max: {max_shifts}). Setting them to 0.")
    filtered[too_large] = 0
    return filtered


class DetectorShiftsBase:
    """Compute the detector shifts for a given dataset."""

    data_vwu: NDArrayFloat
    angles_rad: NDArrayFloat

    def __init__(
        self,
        data_dvwu: NDArrayFloat,
        rot_angle_rad: Union[ArrayLike, NDArrayFloat],
        *,
        data_format: str = "dvwu",
        data_mask_dvwu: Optional[NDArray] = None,
        borders_dvwu: Optional[dict] = None,
        max_shifts: Union[float, NDArrayFloat, None] = None,
        precision_decimals: int = 2,
        verbose: bool = True,
    ):
        """Initialize the base class for detector shifts.

        Parameters
        ----------
        data_dvwu : NDArrayFloat
            The tomographic data.
        rot_angle_rad : ArrayLike | NDArrayFloat
            The rotation angles in radians.
        data_format : str, optional
            The data organization, by default "dvwu".
            Its last `data_dvwu.ndim` characters label the axes of the data.
        data_mask_dvwu : NDArray | None, optional
            The mask of the tomographic data, by default None
        borders_dvwu : dict | None, optional
            The number of pixels to crop at both ends of each axis, keyed by axis label ("d", "v", "w", "u").
            Missing keys, None values, and non-positive values mean no cropping. By default None (no cropping).
        max_shifts : float | NDArrayFloat | None, optional
            Maximum shifts allowed, by default None, which means half the data size along
            axes -3 and -1 (only -1 when the data has a single vertical row).
        precision_decimals : int, optional
            The precision of the results, by default 2
        verbose : bool, optional
            Whether to be verbose, by default True

        Raises
        ------
        ValueError
            Raised when passing incoherent data and angles, or a data format describing fewer axes than the data.
        """
        self.data_dims = len(data_dvwu.shape)
        if len(data_format) < self.data_dims:
            raise ValueError(
                f"Data format '{data_format}' labels {len(data_format)} axes,"
                f" but the data has {self.data_dims} dimensions."
            )

        # Default sizes for the axes that are not present in the data
        self.data_shapes = dict(u=0, v=1, w=0, d=1)
        for ii in range(-self.data_dims, 0):
            self.data_shapes[data_format[ii]] = data_dvwu.shape[ii]
        self.num_dets = self.data_shapes["d"]

        self.angles_rad = np.array(np.squeeze(rot_angle_rad), ndmin=1, dtype=np.float32)
        if self.data_shapes["w"] != len(self.angles_rad):
            raise ValueError(
                f"Mismatch between rotation angles ({len(self.angles_rad)}),"
                f" and number of projections ({self.data_shapes['w']})."
            )

        # Avoid a shared mutable default, and tolerate partially specified borders
        if borders_dvwu is None:
            borders_dvwu = {}
        borders = [borders_dvwu.get(data_format[ii]) for ii in range(-self.data_dims, 0)]
        self.slicing = tuple(slice(b, -b) if b is not None and b > 0 else slice(None) for b in borders)

        # Data axes whose sizes bound the default maximum shifts: V (-3) and U (-1), or only U for flat data
        if self.data_shapes["v"] > 1:
            self._align_coords = np.array([-3, -1])
        else:
            self._align_coords = np.array([-1])

        if max_shifts is None:
            max_shifts = np.array(data_dvwu.shape)[self._align_coords] / 2
        max_shifts = np.array(max_shifts, ndmin=1)
        # Stored as a column, i.e. with shape (num_coords, 1)
        self.max_shifts = max_shifts[:, None]

        self.decimals = precision_decimals
        self.verbose = verbose

        self.data_vwu = data_dvwu
        self.data_mask_vwu = data_mask_dvwu
class DetectorShiftsPRE(DetectorShiftsBase):
    """Compute the pre-alignment detector shifts for a given dataset."""

    def fit_v(
        self,
        use_derivative: bool = True,
        use_rfft: bool = True,
        normalize_fourier: bool = True,
    ) -> NDArrayFloat:
        """Compute the pre-alignment vertical shifts of a 3D dataset.

        The pre-alignment shifts are computed by cross-correlation of one projection against the others.
        The projections are integrated in the horizontal direction.
        In the vertical direction, it is suggested to use some high pass filter.
        The default option is to use the derivatives of the intensity profiles.

        Parameters
        ----------
        use_derivative : bool, optional
            Whether to use the derivative of the vertical profile, by default True
        use_rfft : bool, optional
            Whether to use the `rfft` transform for the cross-correlation, by default True
        normalize_fourier : bool, optional
            Whether to normalize the cross-correlation in Fourier space, by default True

        Returns
        -------
        NDArrayFloat
            The vertical shifts, one per projection, centered around zero.

        Raises
        ------
        ValueError
            If the dataset is 2D.
        """
        if self.data_shapes["v"] <= 1:
            raise ValueError("Vertical alignment not supported for 2D reconstructions.")

        if use_rfft:
            local_fft = np.fft.rfft
            local_ifft = np.fft.irfft
        else:
            local_fft = np.fft.fft
            local_ifft = np.fft.ifft

        # Integrate the projections along the horizontal (U) axis, optionally within the mask
        data_vwu = self.data_vwu[self.slicing]
        if self.data_mask_vwu is None:
            data_vw = np.mean(data_vwu, axis=-1)
        else:
            data_mask_vwu = self.data_mask_vwu[self.slicing].astype(data_vwu.dtype)
            data_vwu = data_vwu * data_mask_vwu
            mask_sum_vw = data_mask_vwu.sum(axis=-1)
            data_vw = data_vwu.sum(axis=-1) / (mask_sum_vw + (mask_sum_vw == 0))

        if self.num_dets > 1:
            data_vw = np.mean(data_vw, axis=-3)

        # Rescale each vertical profile to [0, 1]; flat profiles would otherwise produce 0/0 = NaN
        mins = data_vw.min(axis=-2, keepdims=True)
        maxs = data_vw.max(axis=-2, keepdims=True)
        data_vw_d = (data_vw - mins) / np.fmax(maxs - mins, eps)

        pad_size = ((data_vw.shape[-2] // 2,), (0,))
        if use_derivative:
            data_vw_d = np.diff(data_vw_d, axis=-2)
            data_vw_p = np.pad(data_vw_d, pad_width=pad_size, mode="constant")
        else:
            data_vw_p = np.pad(data_vw_d, pad_width=pad_size, mode="linear_ramp")
        num_samples_v = data_vw_p.shape[-2]

        data_vw_f = local_fft(data_vw_p, axis=-2)

        # The middle projection (by index) is the reference for all cross-correlations
        ref_angle = len(self.angles_rad) // 2
        ccs_f = data_vw_f[:, [ref_angle]] * data_vw_f.conj()
        if normalize_fourier:
            ccs_f /= np.fmax(np.abs(ccs_f).max(axis=-2, keepdims=True), eps)

        # `irfft` cannot infer an odd output length from the half spectrum: pass it explicitly
        cross_corr = local_ifft(ccs_f, n=num_samples_v, axis=-2).real

        cc_coords = np.fft.fftfreq(cross_corr.shape[-2], 1 / cross_corr.shape[-2])
        f_vals, fc_ax = fitting.extract_peak_regions_1d(cross_corr, axis=-2, cc_coords=cc_coords)
        shifts_v = fitting.refine_max_position_1d(f_vals, decimals=self.decimals) + fc_ax[1, :]

        shifts_v = _filter_shifts(shifts_v, self.max_shifts[0, :])
        shifts_v -= np.mean(shifts_v)
        shifts_v = np.around(shifts_v, decimals=self.decimals)

        if self.verbose:
            fig, axs = plt.subplots(2, 2, figsize=[10, 5], sharex=True)
            axs[0, 0].imshow(data_vw)
            axs[0, 0].set_title("Data VW")
            axs[0, 0].set_xlabel("Coord. W (angular)")
            axs[0, 0].set_ylabel("Coord. V (vertical)")
            axs[0, 1].plot(shifts_v)
            axs[0, 1].plot(np.zeros_like(shifts_v))
            axs[0, 1].grid()
            axs[0, 1].set_title(f"Shifts V (wrt angle n.{ref_angle})")
            axs[0, 1].set_xlabel("Coord. W (angular)")
            axs[0, 1].set_ylabel("Coord. V (vertical)")
            axs[1, 0].imshow(data_vw_p)
            axs[1, 0].set_title("Data used for cross-correlation")
            axs[1, 1].imshow(np.fft.fftshift(cross_corr, axes=(-2,)))
            axs[1, 1].set_title("Cross-correlation")
            fig.tight_layout()
            plt.show(block=False)

        return shifts_v

    def fit_u(
        self,
        fit_l1: bool = False,
        background: Union[float, NDArray, None] = None,
        method: str = "com",
    ) -> tuple[NDArrayFloat, float]:
        """Compute the pre-alignment shifts for the horizontal dimension.

        The pre-alignment shifts, and center-of-rotation (CoR) are computed by fitting a sinusoid
        to the centers of mass of each angle in the sinogram.
        The bias of the sinusoid corresponds to the CoR, while the deviations from the fitted curve
        correspond to the shifts.

        Parameters
        ----------
        fit_l1 : bool, optional
            Computes the l1-min fit of the sinusoid, by default False.
        background : float | NDArray | None, optional
            Removes the given background (negative results are clipped to 0), by default None.
        method : str, optional
            The method used for the identification of the fiducial marker position.
            Options are "com" (center-of-mass) | "max" (maximum value), case-insensitive, by default "com".

        Returns
        -------
        tuple[NDArrayFloat, float]
            The shifts and the CoR.

        Raises
        ------
        ValueError
            If `method` is not one of the supported options.
        """
        is_3d = self.data_shapes["v"] > 1

        data_vwu = self.data_vwu[self.slicing]
        if background is not None:
            data_vwu = data_vwu - background
            data_vwu = np.fmax(data_vwu, 0.0)

        # Average over detectors and over the vertical axis, to obtain a W-U sinogram
        if self.num_dets > 1:
            data_vwu = np.mean(data_vwu, axis=0)
        if is_3d:
            data_vwu = np.mean(data_vwu, axis=-3)

        fx_half_size = (data_vwu.shape[-1] - 1) / 2
        method_lc = method.lower()
        if method_lc == "com":
            fx = np.linspace(-fx_half_size, fx_half_size, data_vwu.shape[-1])
            ref_points = -np.sum(data_vwu * fx, axis=-1) / np.sum(data_vwu, axis=-1)
        elif method_lc == "max":
            ref_points = fx_half_size - np.argmax(data_vwu, axis=-1)
        else:
            raise ValueError(f"Unknown selected method {method}. Please choose one among: 'com' | 'max'")

        a, p, b = fitting.fit_sinusoid(self.angles_rad, ref_points, fit_l1=fit_l1)
        cor = np.around(b, decimals=self.decimals)

        if self.verbose:
            angles_deg = np.rad2deg(self.angles_rad)
            sort_angles_deg = np.sort(angles_deg)
            sort_angles_rad = np.sort(self.angles_rad)

            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            axs[0].scatter(angles_deg, ref_points, label="Centers of mass")
            axs[0].plot(sort_angles_deg, fitting.sinusoid(sort_angles_rad, a, p, b), c="C1", label="Fitted sinusoid")
            axs[0].plot(sort_angles_deg, np.ones_like(sort_angles_deg) * b, c="C2", label="Bias (CoR)")
            axs[0].legend()
            axs[0].grid()
            axs[0].set_xlabel("Coord. W (angular)")
            axs[0].set_ylabel("Coord. U (horizontal)")
            axs[0].xaxis.label.set_fontsize(16)
            axs[0].yaxis.label.set_fontsize(16)
            axs[1].imshow(data_vwu)
            axs[1].scatter(-ref_points + fx_half_size, np.arange(len(ref_points)), c="C1")
            axs[1].set_xlabel("Coord. U (horizontal)")
            axs[1].set_ylabel("Coord. W (angular)")
            axs[1].xaxis.label.set_fontsize(16)
            axs[1].yaxis.label.set_fontsize(16)
            fig.tight_layout()
            plt.show(block=False)

            print(f"amplitude = {a} (pix)")
            print(f"phase = {p} (rad)")
            print(f"bias = {b} (pix)")
            print(f" -> cor = {cor} (pix)")

        # The shifts are the deviations of the reference points from the fitted sinusoid
        shifts_u: NDArrayFloat = ref_points - fitting.sinusoid(self.angles_rad, a, p, b)
        shifts_u = _filter_shifts(shifts_u, self.max_shifts[-1, :])
        shifts_u = np.around(shifts_u, decimals=self.decimals)

        return shifts_u, float(cor)
class DetectorShiftsXC(DetectorShiftsBase):
    """Compute the center-of-rotation for a given dataset, by cross correlation."""

    def fit_vu_accum_drifts(self, ref_data_dvwu: Optional[NDArrayFloat] = None) -> NDArray:
        """Fit static image drifts.

        Parameters
        ----------
        ref_data_dvwu : Optional[NDArrayFloat], optional
            Reference image, by default None.
            If None, the first image in the data stack will be used.

        Returns
        -------
        NDArray
            The shifts of the image stack.

        Raises
        ------
        ValueError
            When the number of reference images is either too many or not enough.
        """
        if ref_data_dvwu is None:
            ref_data_dvwu = self.data_vwu[..., [0], :]

        img_inds = np.arange(self.data_shapes["w"])
        # A single reference image is compared against every image; otherwise they are paired one-to-one
        if ref_data_dvwu.shape[-2] == 1:
            ref_inds = np.zeros_like(img_inds)
        else:
            ref_inds = np.arange(ref_data_dvwu.shape[-2])

        if img_inds.size != ref_inds.size:
            raise ValueError(
                f"Reference images should either be 1 or as many as the data images"
                f" ({img_inds.size}), but {ref_inds.size} were passed instead"
            )

        is_3d = self.data_shapes["v"] > 1
        num_dims = 1 + is_3d

        rel_shifts_vu = np.zeros((num_dims, img_inds.size))
        for ii, (ind_ref, ind_img) in enumerate(zip(tqdm(ref_inds, desc="Computing drifts"), img_inds)):
            rel_shifts_vu[..., [ii]] = self.find_shifts_vu(
                ref_data_dvwu[..., [ind_ref], :], self.data_vwu[..., [ind_img], :]
            )

        return rel_shifts_vu

    def fit_vu(self, fit_l1: bool = False) -> NDArray:
        """Compute the pre-alignment vertical and horizontal shifts, using cross-correlation.

        Consecutive projections (in increasing angle order) are cross-correlated, and the relative
        shifts are accumulated. The vertical shifts are then centered around zero, while the horizontal
        shifts are returned as deviations from a fitted sinusoid.

        Parameters
        ----------
        fit_l1 : bool, optional
            Computes the l1-min fit of the sinusoid, by default False.

        Returns
        -------
        NDArray
            Pre-alignment shifts in VU coordinates, in the same order as the input angles.
        """
        is_3d = self.data_shapes["v"] > 1
        num_dims = 1 + is_3d

        angles_order = np.argsort(self.angles_rad)
        sorted_angles_rad = self.angles_rad[angles_order]

        rel_shifts_vu = np.zeros((num_dims, len(self.angles_rad)))
        desc = "Computing adjacent images shifts"
        for ii, (a_prev, a_curr) in enumerate(zip(tqdm(angles_order[:-1], desc=desc), angles_order[1:])):
            rel_shifts_vu[..., [ii + 1]] = self.find_shifts_vu(
                self.data_vwu[..., [a_prev], :], self.data_vwu[..., [a_curr], :]
            )
        # Shifts wrt the first projection, in sorted-angle order
        shifts_vu = np.cumsum(rel_shifts_vu, axis=-1)

        if is_3d:
            shifts_vu[0, :] -= np.mean(shifts_vu[0, :])
            shifts_vu[0, :] = np.around(shifts_vu[0, :], decimals=self.decimals)

        a, p, b = fitting.fit_sinusoid(sorted_angles_rad, shifts_vu[-1, ...], fit_l1=fit_l1)

        if self.verbose:
            coords = ["V", "U"] if is_3d else ["U"]
            fig, axs = plt.subplots(1, num_dims, figsize=[10, 5], sharex=True, squeeze=False)
            angles_deg = np.rad2deg(sorted_angles_rad)
            for ii, (s, coord) in enumerate(zip(shifts_vu, coords)):
                axs[0, ii].plot(angles_deg, s, label="Shifts")
                if coord == "V":
                    axs[0, ii].plot(angles_deg, np.zeros_like(s), label="Zero")
                else:
                    axs[0, ii].plot(angles_deg, np.ones_like(s) * b, label="Bias")
                    axs[0, ii].plot(angles_deg, fitting.sinusoid(sorted_angles_rad, a, p, b), label="Expected motion")
                axs[0, ii].legend()
                axs[0, ii].grid()
                axs[0, ii].set_title(f"Shifts {coord}")
            fig.tight_layout()
            plt.show(block=False)

        shifts_u: NDArrayFloat = shifts_vu[-1, ...] - fitting.sinusoid(sorted_angles_rad, a, p, b)
        shifts_u = _filter_shifts(shifts_u, self.max_shifts[-1, :])
        shifts_u = np.around(shifts_u, decimals=self.decimals)
        shifts_vu[-1] = shifts_u

        # Column k refers to projection angles_order[k]: map back to the input order with the inverse permutation
        inverse_order = np.argsort(angles_order)
        return shifts_vu[:, inverse_order]

    def fit_u_180(self) -> float:
        """Find the center-of-rotation, using the 0 and 180 degrees projections.

        The projection at the smallest angle is compared against the horizontally flipped projection
        closest to that angle plus 180 degrees. A warning is printed if no exactly opposite angle exists.

        Returns
        -------
        float
            The center-of-rotation.
        """
        angle_0 = self.angles_rad.min()
        angle_0_ind = np.argmin(np.abs(self.angles_rad - angle_0))
        angle_180_ind = np.argmin(np.abs(self.angles_rad - (angle_0 + np.pi)))
        angle_180 = self.angles_rad[angle_180_ind]
        if not np.isclose(angle_0 + np.pi, angle_180):
            print(
                f"WARNING - No opposite angles found ({np.rad2deg(angle_0)} and {np.rad2deg(angle_180)})."
                " Center-of-rotation will be inaccurate."
            )

        img_0 = self.data_vwu[..., [angle_0_ind], :]
        img_180 = np.flip(self.data_vwu[..., [angle_180_ind], :], axis=-1)

        shifts_vu = self.find_shifts_vu(img_0, img_180)
        # Half of the horizontal shift between the two mirrored projections; returned as a Python float
        return float(np.asarray(-shifts_vu[-1] / 2).item())

    def fit_u_360(self) -> float:
        """Find the center of rotation over a 360 degrees scan, by taking the average of the 0-180 over all pairs of angles.

        Returns
        -------
        float
            The center-of-rotation.
        """
        # NOTE: the scan is assumed to cover 360 degrees with angles sorted in increasing order;
        # this is not checked.
        angles_boundary = self.angles_rad[0] + np.pi
        num_angles = np.sum(self.angles_rad < angles_boundary)
        a1s = self.angles_rad[:num_angles]
        a2s = a1s + np.pi

        iis_1 = np.arange(num_angles)
        # For each first-half projection, the index of the projection closest to its opposite angle
        iis_2 = np.argmin(np.abs(self.angles_rad[None, :] - a2s[:, None]), axis=-1)

        shifts_vu = np.empty([2, len(iis_1)])
        for ii, (ii1, ii2) in enumerate(tqdm(zip(iis_1, iis_2), total=num_angles)):
            img_1 = self.data_vwu[..., [ii1], :]
            img_2 = np.flip(self.data_vwu[..., [ii2], :], axis=-1)
            shifts_vu[..., [ii]] = self.find_shifts_vu(img_1, img_2)
        cors = -shifts_vu[-1, :] / 2

        cor = np.around(np.mean(cors), decimals=self.decimals)

        if self.verbose:
            fig, axs = plt.subplots(1, 1, figsize=[10, 5])
            axs.plot(cors)
            axs.plot(np.ones_like(cors) * cor)
            axs.grid()
            axs.set_title("Centers of rotation")
            fig.tight_layout()
            plt.show(block=False)

        return float(cor)

    def find_shifts_vu(
        self,
        data_dvwu: NDArrayFloat,
        proj_dvwu: NDArrayFloat,
        use_derivative: bool = False,
        xc_opts: Optional[Mapping] = None,
    ) -> NDArrayFloat:
        """Find shifts between two images or sets of lines.

        Parameters
        ----------
        data_dvwu : NDArrayFloat
            The reference data.
        proj_dvwu : NDArrayFloat
            The other data.
        use_derivative : bool, optional
            Whether to use derivatives over the horizontal (U) coordinate, by default False.
        xc_opts : Mapping | None, optional
            Extra keyword options forwarded to `fitting.fit_shifts_vu_xc`.
            By default None, which means `dict(normalize_fourier=False)`.

        Returns
        -------
        NDArrayFloat
            The shifts in vertical (optional) and horizontal coordinates ([V]U), averaged over the detectors.
        """
        if xc_opts is None:
            xc_opts = dict(normalize_fourier=False)

        if self.num_dets == 1:
            # Add a leading detector axis, so that the loop below handles both cases
            data_dvwu = data_dvwu[None, ...]
            proj_dvwu = proj_dvwu[None, ...]

        shifts_vu_all = []
        for ii_d in range(self.num_dets):
            data_vwu = data_dvwu[ii_d]
            proj_vwu = proj_dvwu[ii_d]
            if use_derivative:
                data_vwu = np.diff(data_vwu, axis=-1)
                proj_vwu = np.diff(proj_vwu, axis=-1)

            # TODO: allow choosing other shift-finding functions (e.g. fitting.fit_shifts_u_sad)
            shifts_vu = fitting.fit_shifts_vu_xc(data_vwu, proj_vwu, decimals=self.decimals, **xc_opts)
            shifts_vu_all.append(_filter_shifts(shifts_vu, self.max_shifts))

        return np.mean(shifts_vu_all, axis=0)