"""
Pre-processing routines.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from collections.abc import Sequence
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import pywt
import skimage.transform as skt
from numpy.polynomial import Polynomial
from numpy.typing import DTypeLike, NDArray
from scipy.ndimage import convolve, fourier_shift, gaussian_filter, shift
from skimage.measure import block_reduce
from tqdm.auto import trange
eps = np.finfo(np.float32).eps
def pad_sinogram(
sinogram: NDArray, width: Union[int, Sequence[int], NDArray], pad_axis: int = -1, mode: str = "edge", **kwds
) -> NDArray:
"""
Pad the sinogram.
Parameters
----------
sinogram : NDArray
The sinogram to pad.
width : Union[int, Sequence[int], NDArray]
The width of the padding. Normally, it should either be an int or a tuple(int, int).
pad_axis : int, optional
The axis to pad. The default is -1.
mode : str, optional
The padding type (from numpy.pad). The default is "edge".
**kwds :
The numpy.pad arguments.
Returns
-------
NDArray
The padded sinogram.
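Examples
--------
A minimal sketch of the intended usage; the shapes below are illustrative:
>>> import numpy as np
>>> sino = np.ones((4, 8), dtype=np.float32)
>>> pad_sinogram(sino, 2).shape
(4, 12)
>>> pad_sinogram(sino, (2, 4)).shape
(4, 14)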
"""
width = np.array(width, ndmin=1)
pad_size = np.zeros((len(sinogram.shape), len(width)), dtype=int)
pad_size[pad_axis, :] = width
return np.pad(sinogram, pad_size, mode=mode.lower(), **kwds) # type: ignore
def apply_flat_field(
projs_wvu: NDArray,
flats_wvu: NDArray,
darks_wvu: Optional[NDArray] = None,
crop: Union[NDArray, Sequence[int], None] = None,
cap_intensity: Optional[float] = None,
dtype: DTypeLike = np.float32,
) -> NDArray:
"""
Apply flat field.
Parameters
----------
projs_wvu : NDArray
Projections.
flats_wvu : NDArray
Flat fields.
darks_wvu : Optional[NDArray], optional
Dark noise. The default is None.
crop : Union[NDArray, Sequence[int], None], optional
Crop region, given as (start_v, start_u, end_v, end_u). The default is None.
cap_intensity: float | None, optional
Cap the intensity to a given value. The default is None.
dtype : DTypeLike, optional
Data type of the processed data. The default is np.float32.
Returns
-------
NDArray
Flat-field corrected and linearized projections.
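Examples
--------
A small synthetic sketch, assuming constant projections and flats (values are illustrative):
>>> import numpy as np
>>> projs = np.full((10, 4, 8), 50.0, dtype=np.float32)
>>> flats = np.full((2, 4, 8), 100.0, dtype=np.float32)
>>> trans = apply_flat_field(projs, flats)
>>> trans.shape
(10, 4, 8)
>>> bool(np.allclose(trans, 0.5))
True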
"""
projs_wvu = np.ascontiguousarray(projs_wvu, dtype=dtype)
flats_wvu = np.ascontiguousarray(flats_wvu, dtype=dtype)
if crop is not None:
projs_wvu = projs_wvu[..., crop[0] : crop[2], crop[1] : crop[3]]
flats_wvu = flats_wvu[..., crop[0] : crop[2], crop[1] : crop[3]]
if darks_wvu is not None:
darks_wvu = darks_wvu[..., crop[0] : crop[2], crop[1] : crop[3]]
if darks_wvu is not None:
darks_wvu = np.ascontiguousarray(darks_wvu, dtype=dtype)
projs_wvu = projs_wvu - darks_wvu
flats_wvu = flats_wvu - darks_wvu
if flats_wvu.ndim == 3:
flats_wvu = np.mean(flats_wvu, axis=0)
projs_wvu = projs_wvu / flats_wvu
if cap_intensity is not None:
projs_wvu = np.fmin(projs_wvu, cap_intensity)
return projs_wvu
def apply_minus_log(projs: NDArray, lower_limit: float = -np.inf) -> NDArray:
"""
Apply -log.
Parameters
----------
projs : NDArray
Projections.
lower_limit : float, optional
The lower clipping limit of the linearized (minus-log) values. The default is -inf (no clipping).
Returns
-------
NDArray
Linearized projections.
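Examples
--------
A small synthetic sketch, assuming a constant transmission of 0.5:
>>> import numpy as np
>>> trans = np.full((4, 8), 0.5, dtype=np.float32)
>>> bool(np.allclose(apply_minus_log(trans), np.log(2.0)))
True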
"""
return np.fmax(-np.log(projs), lower_limit)
def rotate_proj_stack(data_vwu: NDArray, rot_angle_deg: float) -> NDArray:
"""
Rotate the projection stack.
Parameters
----------
data_vwu : NDArray
The projection stack, with dimensions [v, w, u] (vertical, omega / sample rotation, horizontal).
rot_angle_deg : float
The rotation angle in degrees.
Returns
-------
NDArray
The rotated projection stack.
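Examples
--------
A shape-only sketch on random data (the stack dimensions are illustrative):
>>> import numpy as np
>>> stack_vwu = np.random.rand(16, 10, 16)
>>> rotate_proj_stack(stack_vwu, 2.5).shape
(16, 10, 16)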
"""
data_vwu_r = np.empty_like(data_vwu)
for ii in range(data_vwu.shape[-2]):
data_vwu_r[:, ii, :] = skt.rotate(data_vwu[:, ii, :], -rot_angle_deg, clip=False)
return data_vwu_r
def shift_proj_stack(data_vwu: NDArray, shifts: NDArray, use_fft: bool = False) -> NDArray:
"""Shift each projection in a stack of projections, by projection dependent shifts.
Parameters
----------
data_vwu : NDArray
The projection stack, with the projection (angle) axis in position -2.
shifts : NDArray
The per-projection shifts, stacked along the last axis (e.g. shape [2, num_angles] for [v, u] shifts of a 3D stack).
use_fft : bool, optional
Whether to use fft shift or not, by default False
Returns
-------
NDArray
The shifted stack
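Examples
--------
A shape-only sketch, assuming zero [v, u] shifts for a small random stack:
>>> import numpy as np
>>> stack_vwu = np.random.rand(8, 10, 16)
>>> shifts_vu = np.zeros((2, 10))  # one (v, u) shift per projection
>>> shift_proj_stack(stack_vwu, shifts_vu).shape
(8, 10, 16)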
"""
new_data = np.empty_like(data_vwu)
for ii in range(data_vwu.shape[-2]):
if use_fft:
img = data_vwu[..., ii, :]
img_f = np.fft.rfftn(img)
img_f = fourier_shift(img_f, shifts[..., ii], n=img.shape[-1])
new_data[..., ii, :] = np.fft.irfftn(img_f)
else:
new_data[..., ii, :] = shift(data_vwu[..., ii, :], shifts[..., ii], order=1, mode="nearest")
return new_data
def bin_imgs(
imgs: NDArray, binning: Union[int, float], axes: Sequence[int] = (-2, -1), auto_crop: bool = False, verbose: bool = True
) -> NDArray:
"""Bin a stack of images.
Parameters
----------
imgs : NDArray
The stack of images.
binning : int | float
The binning factor.
axes : Sequence[int], optional
The axes to bin, by default (-2, -1)
auto_crop : bool, optional
Whether to automatically crop the images to a size that is a multiple of the binning factor, by default False
verbose : bool, optional
Whether to print the image shapes, by default True
Returns
-------
NDArray
The binned images
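Examples
--------
A small sketch with integer binning (shapes are illustrative):
>>> import numpy as np
>>> imgs = np.arange(2 * 4 * 6, dtype=np.float32).reshape(2, 4, 6)
>>> bin_imgs(imgs, 2, verbose=False).shape
(2, 2, 3)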
"""
if auto_crop:
imgs_shape = np.array(imgs.shape)
excess_pixels = (imgs_shape[list(axes)] % binning).astype(int)
crop_vu = (excess_pixels - excess_pixels // 2, imgs_shape[list(axes)] - excess_pixels // 2)
slicing = [slice(None)] * len(imgs_shape)
for ii, ax in enumerate(axes):
if excess_pixels[ii] > 0:
slicing[ax] = slice(crop_vu[0][ii], crop_vu[1][ii])
imgs = imgs[tuple(slicing)]
if verbose:
print(f"Auto-cropping {crop_vu}: {imgs_shape} => {imgs.shape}")
imgs_shape = imgs.shape
if isinstance(binning, int):
binning_shape = np.ones_like(imgs_shape)
for ax in axes:
binning_shape[ax] = binning
imgs = block_reduce(imgs, tuple(binning_shape), np.mean)
binned_shape = imgs.shape
else:
imgs = imgs.reshape([-1, *imgs_shape[-2:]])
imgs = skt.rescale(imgs, 1 / binning, channel_axis=0 if len(imgs_shape) > 2 else None)
binned_shape = [*imgs_shape[:-2], *imgs.shape[-2:]]
imgs = imgs.reshape(binned_shape)
if verbose:
print(f"Binning {binning}: {imgs_shape} => {binned_shape}")
return imgs
def background_from_margin(
data_vwu: NDArray, margin: Union[int, Sequence[int], NDArray[np.integer]] = 4, poly_order: int = 0, plot: bool = False
) -> NDArray:
"""Compute background of the projection data, from the margins of the projections.
Parameters
----------
data_vwu : NDArray
The projection data in the format [V]WU.
margin : int | Sequence[int] | NDArray[np.integer], optional
The size of the margin, by default 4
poly_order : int, optional
The order of the interpolation polynomial, by default 0
plot : bool, optional
Whether to plot the fitted background, by default False
Returns
-------
NDArray
The computed background.
Raises
------
NotImplementedError
Different margins per line are not supported at the moment.
ValueError
In case the margins are larger than the image size in U.
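Examples
--------
A sketch on a flat synthetic sinogram, where the estimated background equals the data:
>>> import numpy as np
>>> sino = np.ones((10, 32), dtype=np.float32)
>>> bg = background_from_margin(sino, margin=4)
>>> bg.shape
(10, 32)
>>> bool(np.allclose(bg, 1.0))
True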
"""
data_shape_u = data_vwu.shape[-1]
data_shape_w = data_vwu.shape[-2]
margin = np.array(margin, dtype=int, ndmin=1)
if margin.size == 1:
margin = np.tile(margin, [*np.ones(margin.ndim - 1), 2])
if margin.ndim > 1:
raise NotImplementedError("Complex masks support has not been implemented, yet.")
if margin.sum() > data_shape_u:
raise ValueError(f"Margin size ({margin}) should be smaller than the image size in U ({data_shape_u})")
if poly_order > 0 and np.any(margin == 0):
print("WARNING: parameter `poly_order` cannot be greater than 0 if one of the margins is 0")
poly_order = 0
if poly_order > 0:
ydata = np.concatenate([data_vwu[..., : margin[0]], data_vwu[..., -margin[1] :]], axis=-1)
xdata = np.concatenate([np.arange(0, margin[0]), np.arange(data_shape_u - margin[1], data_shape_u)])
if data_vwu.ndim > 2:
ydata = ydata.mean(axis=-3)
background = np.empty([data_shape_w, data_shape_u], dtype=data_vwu.dtype)
for ii_w in range(data_shape_w):
poly = Polynomial.fit(xdata, ydata[ii_w], deg=poly_order)
background[ii_w, :] = poly(np.arange(data_shape_u))
if plot:
fig, axs = plt.subplots(1, 1)
axs.plot(background[0])
axs.scatter(xdata, ydata[0])
axs.grid()
axs.set_ylim(0)
fig.tight_layout()
plt.show(block=False)
if data_vwu.ndim > 2:
background = np.tile(background[None, ...], [data_vwu.shape[-3], 1, 1])
return background
else:
sum_vals = data_vwu[..., : margin[0]].sum(axis=-1) + data_vwu[..., -margin[1] :].sum(axis=-1)
background: NDArray = sum_vals / margin.sum(axis=-1)
if data_vwu.ndim > 2:
background = background.mean(axis=-2, keepdims=True)
return np.tile(background[..., None], [*np.ones(background.ndim, dtype=int), data_shape_u])
def snip(
img: NDArray, kernel_dims: Union[int, None] = None, iterations: int = 1000, window: int = 3, verbose: bool = False
) -> NDArray:
"""
Apply the SNIP algorithm to an image to estimate the background.
Parameters
----------
img : NDArray
The input image to process.
kernel_dims : Union[int, None], optional
The number of dimensions to apply the SNIP algorithm to. If None, it defaults to the number of dimensions of the image.
iterations : int, optional
The number of iterations to run the SNIP algorithm.
window : int, optional
The size of the window for the convolution kernel.
verbose : bool, optional
If True, display a progress bar during the iterations.
Returns
-------
NDArray
The background-estimated image.
Raises
------
ValueError
If `kernel_dims` is not between 1 and the number of dimensions of the image.
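Examples
--------
A sketch on random data; by construction, the estimated background never exceeds the input:
>>> import numpy as np
>>> img = np.random.rand(16, 16).astype(np.float32)
>>> bckgnd = snip(img, iterations=10)
>>> bckgnd.shape
(16, 16)
>>> bool(np.all(bckgnd <= img))
True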
"""
if kernel_dims is None:
kernel_dims = img.ndim
elif kernel_dims > img.ndim or kernel_dims < 1:
raise ValueError(f"Kernel dimensions (#{kernel_dims}) should be between [1, {img.ndim}] (number of image dimensions)")
kernel_shape = [1 if ii < -kernel_dims else window for ii in range(-img.ndim, 0)]
kern = np.ones(kernel_shape, dtype=img.dtype)
slices = [slice(None) if ii < -kernel_dims else slice(1, -1) for ii in range(-img.ndim, 0)]
kern[tuple(slices)] = 0.0
kern /= kern.sum()
bckgnd_img = img.copy()
for _ in trange(iterations, disable=not verbose):
conv_img = convolve(bckgnd_img, kern, mode="nearest")
bckgnd_img = np.fmin(conv_img, bckgnd_img)
return bckgnd_img
def background_from_snip(data_vwu: NDArray, snip_iterations: int = 6, smooth_std: float = 0.0) -> NDArray:
"""
Fit the background of the projection data using the SNIP algorithm.
Parameters
----------
data_vwu : NDArray
The input dataset to process.
snip_iterations : int, optional
The number of iterations to run the SNIP algorithm.
smooth_std : float, optional
The standard deviation for Gaussian smoothing. If 0.0, no smoothing is applied.
Returns
-------
NDArray
The fitted background of the dataset.
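Examples
--------
A shape-only sketch on random, strictly positive data (a progress bar is displayed):
>>> import numpy as np
>>> data_vwu = np.random.rand(8, 12, 16) + 1.0
>>> background_from_snip(data_vwu, snip_iterations=3).shape
(8, 12, 16)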
"""
bg_data_vwu = np.empty_like(data_vwu)
for ii_a in trange(data_vwu.shape[-2], desc="Angles"):
img = data_vwu[..., ii_a, :]
if smooth_std > 0.0:
img = gaussian_filter(img, sigma=smooth_std)
elif img.ndim > 2:
img = img.copy()  # The slice above is a view: copy it, to avoid writing the per-slice backgrounds back into the input data
if img.ndim > 2:
for ii_d in range(img.shape[0]):
img[ii_d] = snip(img[ii_d], iterations=snip_iterations)
bg_data_vwu[..., ii_a, :] = img
else:
bg_data_vwu[..., ii_a, :] = snip(img, iterations=snip_iterations)
return bg_data_vwu
def destripe_wlf_vwu(
data: NDArray,
sigma: float = 0.005,
level: int = 1,
wavelet: str = "bior2.2",
angle_axis: int = -2,
other_axes: Union[Sequence[int], NDArray, None] = None,
) -> NDArray:
"""Remove stripes from sinogram, using the Wavelet-Fourier method.
Parameters
----------
data : NDArray
The data to de-stripe
sigma : float, optional
Fourier space filter coefficient, by default 0.005
level : int, optional
The wavelet level to use, by default 1
wavelet : str, optional
The type of wavelet to use, by default "bior2.2"
angle_axis : int, optional
The axis of the Fourier transform, by default -2
other_axes : Union[Sequence[int], NDArray, None], optional
The axes of the wavelet decomposition, by default None
Returns
-------
NDArray
The de-striped data.
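Examples
--------
A shape-only sketch on a random sinogram (angles along axis -2, detector along axis -1):
>>> import numpy as np
>>> sino_wu = np.random.rand(32, 32).astype(np.float32)
>>> destripe_wlf_vwu(sino_wu, level=2).shape
(32, 32)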
"""
if other_axes is None:
other_axes = np.arange(-data.ndim, 0)
else:
other_axes = np.array(other_axes)
if angle_axis in other_axes:
other_axes = other_axes[other_axes != angle_axis]
level_power = 2**level
data_shape = np.array(data.shape)
target_shape = data_shape.copy()
target_shape[list(other_axes)] = np.ceil(data_shape[list(other_axes)] / level_power) * level_power
diff_size = target_shape - data_shape
padding = np.stack((diff_size - diff_size // 2, diff_size // 2), axis=-1)
data = np.pad(data, pad_width=padding, mode="edge")
coeffs = pywt.swtn(data, wavelet=wavelet, axes=other_axes, level=level)
for ii_l in range(level):
for wl_label, coeffs_l_wl in coeffs[ii_l].items():
if wl_label == "a" * len(other_axes):
continue
coeff_f = np.fft.rfft(coeffs_l_wl, axis=angle_axis)
filt_f = 1 - np.exp(-(np.fft.rfftfreq(coeffs_l_wl.shape[angle_axis]) ** 2) / (2 * sigma**2))
coeff_f *= filt_f[:, None]
coeffs[ii_l][wl_label] = np.fft.irfft(coeff_f, axis=angle_axis, n=coeffs_l_wl.shape[angle_axis])
data = pywt.iswtn(coeffs, wavelet=wavelet, axes=other_axes)
slicing = [slice(padding[ii, 0], data.shape[ii] - padding[ii, 1]) for ii in range(data.ndim)]
return data[tuple(slicing)]
def compute_eigen_flats(
trans_wvu: NDArray,
flats_wvu: Optional[NDArray] = None,
darks_wvu: Optional[NDArray] = None,
ndim: int = 2,
plot: bool = False,
) -> tuple[NDArray, NDArray]:
"""Compute the eigen flats of a stack of transmission images.
Parameters
----------
trans_wvu : NDArray
The stack of transmission images.
flats_wvu : Optional[NDArray], optional
The flats without sample. The default is None.
darks_wvu : Optional[NDArray], optional
The darks. The default is None.
ndim : int, optional
The number of dimensions of the images, by default 2
plot : bool, optional
Whether to plot the results, by default False
Returns
-------
tuple[NDArray, NDArray]
The decomposition of the transmission images: the eigen-projections of the sample and the corresponding eigen-flats.
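Examples
--------
A shape-only sketch on random, strictly positive transmissions (no flats or darks given):
>>> import numpy as np
>>> trans = np.random.rand(10, 8, 8) + 0.5
>>> eig_projs, eig_flats = compute_eigen_flats(trans)
>>> eig_projs.shape, eig_flats.shape
((10, 8, 8), (10, 8, 8))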
"""
trans_shape = trans_wvu.shape
trans_num = np.prod(trans_shape[:-ndim])
img_shape = trans_shape[-ndim:]
stack_imgs = [trans_wvu.reshape([-1, *img_shape])]
if flats_wvu is not None:
stack_imgs.append(flats_wvu.reshape([-1, *img_shape]))
stack_imgs = np.concatenate(stack_imgs)
if darks_wvu is not None:
if darks_wvu.ndim > 2:
darks_wvu = darks_wvu.mean(axis=tuple(np.arange(darks_wvu.ndim - 2)))
stack_imgs = np.fmax(stack_imgs - darks_wvu, np.finfo(np.float32).eps)
stack_imgs = stack_imgs.reshape([-1, np.prod(img_shape)]).transpose()
stack_imgs = np.log(stack_imgs)
mat_u, sigma, mat_v_h = np.linalg.svd(stack_imgs, full_matrices=False)
eigen_projs: NDArray = (mat_u[..., 1:] * sigma[..., None, 1:]) @ mat_v_h[..., 1:, :]
eigen_projs = np.exp(eigen_projs)
eigen_projs = eigen_projs.transpose().reshape([-1, *img_shape])[:trans_num]
eigen_flats: NDArray = (mat_u[..., 0:1:] * sigma[..., None, 0:1:]) @ mat_v_h[..., 0:1:, :]
eigen_flats = np.exp(eigen_flats)
eigen_flats = eigen_flats.transpose().reshape([-1, *img_shape])[:trans_num]
if plot:
fig, axs = plt.subplots(1, 3, figsize=[10, 3.75])
axs[0].plot(sigma)
axs[0].grid()
axs[0].set_title("Singular values")
axs[1].imshow(mat_u[:, 0].reshape(img_shape))
axs[1].set_title("Highest value component")
axs[2].plot(eigen_flats.mean(axis=(-2, -1)).flatten())
axs[2].grid()
axs[2].set_title("Eigen intensities")
fig.tight_layout()
plt.show(block=False)
return eigen_projs.reshape(trans_shape), eigen_flats.reshape(trans_shape)