"""
Detector shifts finding classes.
@author: Nicola VIGANĂ’, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from collections.abc import Mapping
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from tqdm.auto import tqdm
from . import fitting
eps = np.finfo(np.float32).eps
NDArrayFloat = NDArray[np.floating]
def _filter_shifts(shifts_vu: NDArrayFloat, max_shifts: NDArrayFloat) -> NDArrayFloat:
invalid_shifts = np.abs(shifts_vu) > max_shifts
shifts_vu_filt = shifts_vu.copy()
if np.any(invalid_shifts):
print(f"WARNING - Some shifts exceeded the maximum allowed magnitude (max: {max_shifts}). Setting them to 0.")
shifts_vu_filt[invalid_shifts] = 0
return shifts_vu_filt
[docs]
class DetectorShiftsBase:
"""Compute the detector shifts for a given dataset."""
data_vwu: NDArrayFloat
angles_rad: NDArrayFloat
def __init__(
self,
data_dvwu: NDArrayFloat,
rot_angle_rad: Union[ArrayLike, NDArrayFloat],
*,
data_format: str = "dvwu",
data_mask_dvwu: Optional[NDArray] = None,
borders_dvwu: dict = {"d": None, "v": None, "w": None, "u": None},
max_shifts: Union[float, NDArrayFloat, None] = None,
precision_decimals: int = 2,
verbose: bool = True,
):
"""Initialize the base class for detector shifts.
Parameters
----------
data_dvwu : NDArrayFloat
The tomographic data.
rot_angle_rad : ArrayLike | NDArrayFloat
The rotation angles in radians.
data_format : str, optional
The data organization, by default "dvwu"
data_mask_dvwu : NDArray | None, optional
The mask of the tomographic data, by default None
borders_dvwu : dict, optional
The borders of the tomographic data, by default {"d": None, "v": None, "w": None, "u": None}
max_shifts : float | NDArrayFloat | None, optional
Maximum shifts allowed, by default None
precision_decimals : int, optional
The precision of the results, by default 2
verbose : bool, optional
Whether to be verbose, by default True
Raises
------
ValueError
Raised when passing incoherent data and angles.
"""
self.data_dims = len(data_dvwu.shape)
self.data_shapes = dict(u=0, v=1, w=0, d=1)
for ii in range(-self.data_dims, 0):
self.data_shapes[data_format[ii]] = data_dvwu.shape[ii]
self.num_dets = self.data_shapes["d"]
self.angles_rad = np.array(np.squeeze(rot_angle_rad), ndmin=1, dtype=np.float32)
if self.data_shapes["w"] != len(self.angles_rad):
raise ValueError(
f"Mismatch between rotation angles ({len(self.angles_rad)}),"
f" and number of projections ({self.data_shapes['w']})."
)
slicing = [borders_dvwu[data_format[ii]] for ii in range(-self.data_dims, 0)]
slicing = [slice(s, -s) if s is not None and s > 0 else slice(None) for s in slicing]
self.slicing = tuple(slicing)
if self.data_shapes["v"] > 1:
self._align_coords = np.array([-3, -1])
else:
self._align_coords = np.array([-1])
if max_shifts is None:
max_shifts = np.array(data_dvwu.shape)[self._align_coords] / 2
max_shifts = np.array(max_shifts, ndmin=1)
self.max_shifts = max_shifts[:, None]
self.decimals = precision_decimals
self.verbose = verbose
self.data_vwu = data_dvwu
self.data_mask_vwu = data_mask_dvwu
[docs]
class DetectorShiftsPRE(DetectorShiftsBase):
"""Compute the pre-alignment detector shifts for a given dataset."""
[docs]
def fit_v(
self,
use_derivative: bool = True,
use_rfft: bool = True,
normalize_fourier: bool = True,
) -> NDArrayFloat:
"""Compute the pre-alignment vertical shifts of a 3D dataset.
The pre-alignment shifts are computed by cross-correlation of one
projection against the others.
The projections are integrated in the horizontal direction.
In the vertical direction, it is suggested to use some high pass filter.
The default option is to use the derivates of the intensity profiles.
Parameters
----------
use_derivative : bool, optional
Whether to use the derivate of the vertical profile, by default True
use_rfft : bool, optional
Whether to use the `rfft` transform for the cross-correlation, by default True
normalize_fourier : bool, optional
Whether to normalize the cross-correlation in Fourier space, by default True
Returns
-------
NDArrayFloat
The vertical shifts.
Raises
------
ValueError
If the dataset is 2D.
"""
if self.data_shapes["v"] <= 1:
raise ValueError("Vertical alignment not supported for 2D reconstructions.")
if use_rfft:
local_fft = np.fft.rfft
local_ifft = np.fft.irfft
else:
local_fft = np.fft.fft
local_ifft = np.fft.ifft
data_vwu = self.data_vwu[self.slicing]
if self.data_mask_vwu is None:
data_vw = np.mean(data_vwu, axis=-1)
else:
data_mask_vwu = self.data_mask_vwu[self.slicing].astype(data_vwu.dtype)
data_vwu = data_vwu * data_mask_vwu
mask_sum_vw = data_mask_vwu.sum(axis=-1)
data_vw = data_vwu.sum(axis=-1) / (mask_sum_vw + (mask_sum_vw == 0))
if self.num_dets > 1:
data_vw = np.mean(data_vw, axis=-3)
mins = data_vw.min(axis=-2, keepdims=True)
maxs = data_vw.max(axis=-2, keepdims=True)
data_vw_d = (data_vw - mins) / (maxs - mins)
pad_size = ((data_vw.shape[-2] // 2,), (0,))
if use_derivative:
data_vw_d = np.diff(data_vw_d, axis=-2)
data_vw_p = np.pad(data_vw_d, pad_width=pad_size, mode="constant")
else:
data_vw_p = np.pad(data_vw_d, pad_width=pad_size, mode="linear_ramp")
data_vw_f = local_fft(data_vw_p, axis=-2)
ref_angle = len(self.angles_rad) // 2
ccs_f = data_vw_f[:, [ref_angle]] * data_vw_f.conj()
if normalize_fourier:
ccs_f /= np.fmax(np.abs(ccs_f).max(axis=-2, keepdims=True), eps)
cross_corr = local_ifft(ccs_f, axis=-2).real
cc_coords = np.fft.fftfreq(cross_corr.shape[-2], 1 / cross_corr.shape[-2])
f_vals, fc_ax = fitting.extract_peak_regions_1d(cross_corr, axis=-2, cc_coords=cc_coords)
shifts_v = fitting.refine_max_position_1d(f_vals, decimals=self.decimals) + fc_ax[1, :]
shifts_v = _filter_shifts(shifts_v, self.max_shifts[0, :])
shifts_v -= np.mean(shifts_v)
shifts_v = np.around(shifts_v, decimals=self.decimals)
if self.verbose:
fig, axs = plt.subplots(2, 2, figsize=[10, 5], sharex=True)
axs[0, 0].imshow(data_vw)
axs[0, 0].set_title("Data VW")
axs[0, 0].set_xlabel("Coord. W (angular)")
axs[0, 0].set_ylabel("Coord. V (vertical)")
axs[0, 1].plot(shifts_v)
axs[0, 1].plot(np.zeros_like(shifts_v))
axs[0, 1].grid()
axs[0, 1].set_title(f"Shifts V (wrt angle n.{ref_angle})")
axs[0, 1].set_xlabel("Coord. W (angular)")
axs[0, 1].set_ylabel("Coord. V (vertical)")
axs[1, 0].imshow(data_vw_p)
axs[1, 0].set_title("Data used for cross-correlation")
axs[1, 1].imshow(np.fft.fftshift(cross_corr, axes=(-2,)))
axs[1, 1].set_title("Cross-correlation")
fig.tight_layout()
plt.show(block=False)
return shifts_v
[docs]
def fit_u(
self,
fit_l1: bool = False,
background: Union[float, NDArray, None] = None,
method: str = "com",
) -> tuple[NDArrayFloat, float]:
"""Compute the pre-alignment shifts for the horizontal dimension.
The pre-alignment shifts, and center-of-rotation (CoR) are computed by
fitting a sinusoid to the centers of mass of each angle in the sinogram.
The bias of the sinusoid corresponds to the CoR, while the deviations
from the fitted curve correspond to the shifts.
Parameters
----------
fit_l1 : bool, optional
Computes the l1-min fit of the sinusoid, by default False.
background : float | NDArray | None, optional
Removes the given background, by default None.
method : str, optional
The method used for the identification of the fiducial marker position.
Options are "com" (center-of-mass) | "max" (maximum value), by default "com".
Returns
-------
Tuple[NDArrayFloat, float]
The shifts and the CoR.
"""
is_3d = self.data_shapes["v"] > 1
data_vwu = self.data_vwu[self.slicing]
if background is not None:
data_vwu = data_vwu - background
data_vwu = np.fmax(data_vwu, 0.0)
if self.num_dets > 1:
data_vwu = np.mean(data_vwu, axis=0)
if is_3d:
data_vwu = np.mean(data_vwu, axis=-3)
fx_half_size = (data_vwu.shape[-1] - 1) / 2
if method.lower() == "com":
fx = np.linspace(-fx_half_size, fx_half_size, data_vwu.shape[-1])
ref_points = -np.sum(data_vwu * fx, axis=-1) / np.sum(data_vwu, axis=-1)
elif method.lower() == "max":
ref_points = fx_half_size - np.argmax(data_vwu, axis=-1)
else:
raise ValueError(f"Unkown selected method {method}. Please choose one among: 'com' | 'max'")
a, p, b = fitting.fit_sinusoid(self.angles_rad, ref_points, fit_l1=fit_l1)
cor = np.around(b, decimals=self.decimals)
if self.verbose:
angles_deg = np.rad2deg(self.angles_rad)
sort_angles_deg = np.sort(angles_deg)
sort_angles_rad = np.sort(self.angles_rad)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].scatter(angles_deg, ref_points, label="Centers of mass")
axs[0].plot(sort_angles_deg, fitting.sinusoid(sort_angles_rad, a, p, b), c="C1", label="Fitted sinusoid")
axs[0].plot(sort_angles_deg, np.ones_like(sort_angles_deg) * b, c="C2", label="Bias (CoR)")
axs[0].legend()
axs[0].grid()
axs[0].set_xlabel("Coord. W (angular)")
axs[0].set_ylabel("Coord. U (horizontal)")
axs[0].xaxis.label.set_fontsize(16)
axs[0].yaxis.label.set_fontsize(16)
axs[1].imshow(data_vwu)
axs[1].scatter(-ref_points + fx_half_size, np.arange(len(ref_points)), c="C1")
axs[1].set_xlabel("Coord. U (horizontal)")
axs[1].set_ylabel("Coord. W (angular)")
axs[1].xaxis.label.set_fontsize(16)
axs[1].yaxis.label.set_fontsize(16)
fig.tight_layout()
plt.show(block=False)
print(f"amplitude = {a} (pix)")
print(f"phase = {p} (rad)")
print(f"bias = {b} (pix)")
print(f" -> cor = {cor} (pix)")
shifts_u: NDArrayFloat = ref_points - fitting.sinusoid(self.angles_rad, a, p, b)
shifts_u = _filter_shifts(shifts_u, self.max_shifts[-1, :])
shifts_u = np.around(shifts_u, decimals=self.decimals)
return shifts_u, float(cor)
[docs]
class DetectorShiftsXC(DetectorShiftsBase):
"""Compute the center-of-rotation for a given dataset, by cross correlation."""
[docs]
def fit_vu_accum_drifts(self, ref_data_dvwu: Optional[NDArrayFloat] = None) -> NDArray:
"""Fit static image drifts.
Parameters
----------
ref_data_dvwu : Optional[NDArrayFloat], optional
Reference image, by default None. If None, the first image in the
data stack will be used.
Returns
-------
NDArray
The shifts of the image stack.
Raises
------
ValueError
When the number of reference images is either too many or not enough.
"""
if ref_data_dvwu is None:
ref_data_dvwu = self.data_vwu[..., [0], :]
img_inds = np.arange(self.data_shapes["w"])
if ref_data_dvwu.shape[-2] == 1:
ref_inds = np.zeros_like(img_inds)
else:
ref_inds = np.arange(ref_data_dvwu.shape[-2])
if img_inds.size != ref_inds.size:
raise ValueError(
f"Reference images should either be 1 or as many as the data images"
f" ({img_inds.size}), but {ref_inds.size} were passed instead"
)
is_3d = self.data_shapes["v"] > 1
num_dims = 1 + is_3d
rel_shifts_vu = np.zeros((num_dims, img_inds.size))
for ii, (ind_ref, ind_img) in enumerate(zip(tqdm(ref_inds, desc="Computing drifts"), img_inds)):
rel_shifts_vu[..., [ii]] = self.find_shifts_vu(ref_data_dvwu[..., [ind_ref], :], self.data_vwu[..., [ind_img], :])
return rel_shifts_vu
[docs]
def fit_vu(self, fit_l1: bool = False) -> NDArray:
"""Compute the pre-alignment vertical and horizontal shifts, using cross-correlation.
Parameters
----------
fit_l1 : bool, optional
Computes the l1-min fit of the sinusoid, by default False.
Returns
-------
NDArray
Pre-alignment shifts in VU coordinates.
"""
is_3d = self.data_shapes["v"] > 1
num_dims = 1 + is_3d
angles_order = np.argsort(self.angles_rad)
sorted_angles_rad = self.angles_rad[angles_order]
rel_shifts_vu = np.zeros((num_dims, len(self.angles_rad)))
desc = "Computing adjacent images shifts"
for ii, (a_prev, a_curr) in enumerate(zip(tqdm(angles_order[:-1], desc=desc), angles_order[1:])):
rel_shifts_vu[..., [ii + 1]] = self.find_shifts_vu(
self.data_vwu[..., [a_prev], :], self.data_vwu[..., [a_curr], :]
)
shifts_vu = np.cumsum(rel_shifts_vu, axis=-1)
if is_3d:
shifts_vu[0, :] -= np.mean(shifts_vu[0, :])
shifts_vu[0, :] = np.around(shifts_vu[0, :], decimals=self.decimals)
a, p, b = fitting.fit_sinusoid(sorted_angles_rad, shifts_vu[-1, ...], fit_l1=fit_l1)
if self.verbose:
coords = ["V", "U"] if is_3d else ["U"]
fig, axs = plt.subplots(1, num_dims, figsize=[10, 5], sharex=True, squeeze=False)
angles_deg = np.rad2deg(sorted_angles_rad)
for ii, (s, coord) in enumerate(zip(shifts_vu, coords)):
axs[0, ii].plot(angles_deg, s, label="Shifts")
if coord == "V":
axs[0, ii].plot(angles_deg, np.zeros_like(s), label="Zero")
else:
axs[0, ii].plot(angles_deg, np.ones_like(s) * b, label="Bias")
axs[0, ii].plot(angles_deg, fitting.sinusoid(sorted_angles_rad, a, p, b), label="Expected motion")
axs[0, ii].legend()
axs[0, ii].grid()
axs[0, ii].set_title(f"Shifts {coord}")
fig.tight_layout()
plt.show(block=False)
shifts_u: NDArrayFloat = shifts_vu[-1, ...] - fitting.sinusoid(sorted_angles_rad, a, p, b)
shifts_u = _filter_shifts(shifts_u, self.max_shifts[-1, :])
shifts_u = np.around(shifts_u, decimals=self.decimals)
shifts_vu[-1] = shifts_u
return shifts_vu[:, list(angles_order)]
[docs]
def fit_u_180(self) -> float:
"""Find the center-of-rotation, using the 0 and 180 degrees projections.
Returns
-------
float
The center-of-rotation.
"""
angle_0 = self.angles_rad.min()
angle_0_ind = np.argmin(np.abs(self.angles_rad - angle_0))
angle_180_ind = np.argmin(np.abs(self.angles_rad - (angle_0 + np.pi)))
angle_180 = self.angles_rad[angle_180_ind]
if not np.isclose(angle_0 + np.pi, angle_180):
print(
f"WARNING - No opposite angles found ({np.rad2deg(angle_0)} and {np.rad2deg(angle_180)})."
" Center-of-rotation will be inaccurate."
)
img_0 = self.data_vwu[..., [angle_0_ind], :]
img_180 = np.flip(self.data_vwu[..., [angle_180_ind], :], axis=-1)
# upsample_factor = 1 / (10 ** (-self.decimals))
# return skr.phase_cross_correlation(img_0, img_180, upsample_factor=upsample_factor, return_error=False)
shifts_vu = self.find_shifts_vu(img_0, img_180)
return -shifts_vu[-1] / 2
[docs]
def fit_u_360(self) -> float:
"""Find the center of rotation over a 360 degrees scan, by taking the average of the 0-180 over all pairs of angles.
Returns
-------
float
The center-of-rotation.
"""
# We should be checking whether the scan is really 360 or not.
angles_boundary = self.angles_rad[0] + np.pi
num_angles = np.sum(self.angles_rad < angles_boundary)
a1s = self.angles_rad[:num_angles]
a2s = a1s + np.pi
iis_1 = np.arange(num_angles)
iis_2 = np.argmin(np.abs(self.angles_rad[None, :] - a2s[:, None]), axis=-1)
shifts_vu = np.empty([2, len(iis_1)])
for ii, (ii1, ii2) in enumerate(tqdm(zip(iis_1, iis_2), total=num_angles)):
img_1 = self.data_vwu[..., [ii1], :]
img_2 = np.flip(self.data_vwu[..., [ii2], :], axis=-1)
shifts_vu[..., [ii]] = self.find_shifts_vu(img_1, img_2)
cors = -shifts_vu[-1, :] / 2
# upsample_factor = int(1 / (10 ** (-self.decimals)))
# cors = np.empty_like(a1s)
# for ii_1, a1 in enumerate(tqdm(a1s)):
# ii_2 = np.argmin(np.abs(self.angles_rad - (a1 + np.pi)))
# # We should be handling non-redundant scans.
# img_0 = self.data_vwu[..., ii_1, :]
# img_180 = np.flip(self.data_vwu[..., ii_2, :], axis=-1)
# s = skr.phase_cross_correlation(img_0, img_180, upsample_factor=upsample_factor, return_error=False)
# cors[ii_1] = -s[-1] / 2
cor = np.around(np.mean(cors), decimals=self.decimals)
if self.verbose:
fig, axs = plt.subplots(1, 1, figsize=[10, 5])
axs.plot(cors)
axs.plot(np.ones_like(cors) * cor)
axs.grid()
axs.set_title("Centers of rotation")
fig.tight_layout()
plt.show(block=False)
return float(cor)
[docs]
def find_shifts_vu(
self,
data_dvwu: NDArrayFloat,
proj_dvwu: NDArrayFloat,
use_derivative: bool = False,
xc_opts: Mapping = dict(normalize_fourier=False),
) -> NDArrayFloat:
"""Find shifts between two images or sets of lines.
Parameters
----------
data_dvwu : NDArrayFloat
The reference data.
proj_dvwu : NDArrayFloat
The other data.
use_derivative : bool, optional
Whether to use derivatives over the horizontal (U) coordinate, by default False.
Returns
-------
NDArrayFloat
The shifts in vertical (optional) and horizontal coordinates ([V]U).
"""
if self.num_dets == 1:
data_dvwu = data_dvwu[None, ...]
proj_dvwu = proj_dvwu[None, ...]
shifts_vu_all = [np.array([])] * self.num_dets
for ii_d in range(self.num_dets):
data_vwu = data_dvwu[ii_d]
proj_vwu = proj_dvwu[ii_d]
if use_derivative:
data_vwu = np.diff(data_vwu, axis=-1)
proj_vwu = np.diff(proj_vwu, axis=-1)
# Allow to choose different shift finding functions
shifts_vu = fitting.fit_shifts_vu_xc(data_vwu, proj_vwu, decimals=self.decimals, **xc_opts)
# shifts_vu = fitting.fit_shifts_u_sad(data_vwu, proj_vwu, decimals=self.decimals)
shifts_vu_all[ii_d] = _filter_shifts(shifts_vu, self.max_shifts)
return np.mean(shifts_vu_all, axis=0)