Source code for corrct.processing.pre

"""
Pre-processing routines.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Sequence
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import pywt
import scipy.ndimage as spimg
import skimage.transform as skt
from numpy.polynomial import Polynomial
from numpy.typing import DTypeLike, NDArray


# Machine epsilon of single-precision (float32) floating point numbers.
eps = np.finfo(np.float32).eps


def pad_sinogram(
    sinogram: NDArray, width: Union[int, Sequence[int], NDArray], pad_axis: int = -1, mode: str = "edge", **kwds
) -> NDArray:
    """
    Pad a sinogram along a single axis.

    Parameters
    ----------
    sinogram : NDArray
        The sinogram to pad.
    width : Union[int, Sequence[int], NDArray]
        The padding width. It is normally either a single int (same padding
        before and after) or a pair ``(before, after)``.
    pad_axis : int, optional
        The axis to pad. The default is -1.
    mode : str, optional
        The padding mode, as understood by `numpy.pad`. It is lower-cased
        before being passed on. The default is "edge".
    **kwds :
        Additional keyword arguments, forwarded to `numpy.pad`.

    Returns
    -------
    NDArray
        The padded sinogram.
    """
    widths = np.array(width, ndmin=1)

    # One row per sinogram axis: every axis gets zero padding, except `pad_axis`.
    pad_table = np.zeros((sinogram.ndim, len(widths)), dtype=int)
    pad_table[pad_axis] = widths

    return np.pad(sinogram, pad_table, mode=mode.lower(), **kwds)  # type: ignore
def apply_flat_field(
    projs_wvu: NDArray,
    flats_wvu: NDArray,
    darks_wvu: Optional[NDArray] = None,
    crop: Optional[Sequence[int]] = None,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """
    Apply the flat-field correction to a set of projections.

    Parameters
    ----------
    projs_wvu : NDArray
        The projections.
    flats_wvu : NDArray
        The flat fields. A 3-dimensional stack is averaged over its first axis.
    darks_wvu : Optional[NDArray], optional
        The dark images, subtracted from both projections and flats.
        The default is None.
    crop : Optional[Sequence[int]], optional
        Crop region applied to the last two axes, used as
        ``[..., crop[0]:crop[2], crop[1]:crop[3]]``. The default is None.
    dtype : DTypeLike, optional
        Data type of the processed data. The default is np.float32.

    Returns
    -------
    NDArray
        The flat-field corrected projections (no logarithm is applied here).
    """
    projs = np.ascontiguousarray(projs_wvu, dtype=dtype)
    flats = np.ascontiguousarray(flats_wvu, dtype=dtype)
    darks = darks_wvu

    if crop is not None:
        # The same region of interest is cut out of every input
        roi = (..., slice(crop[0], crop[2]), slice(crop[1], crop[3]))
        projs = projs[roi]
        flats = flats[roi]
        if darks is not None:
            darks = darks[roi]

    if darks is not None:
        darks = np.ascontiguousarray(darks, dtype=dtype)
        projs = projs - darks
        flats = flats - darks

    if flats.ndim == 3:
        # Stack of flats => use their average
        flats = np.mean(flats, axis=0)

    return projs / flats
def apply_minus_log(projs: NDArray, lower_limit: float = -np.inf) -> NDArray:
    """
    Apply the negative logarithm to the projections.

    Parameters
    ----------
    projs : NDArray
        The projections.
    lower_limit : float, optional
        Lower bound for the output values: the result is the element-wise
        `numpy.fmax` of ``-log(projs)`` and this limit. The default is -inf.

    Returns
    -------
    NDArray
        The linearized projections.
    """
    minus_log = np.negative(np.log(projs))
    return np.fmax(minus_log, lower_limit)
def rotate_proj_stack(data_vwu: NDArray, rot_angle_deg: float) -> NDArray:
    """
    Rotate every projection of a projection stack by the same angle.

    Parameters
    ----------
    data_vwu : NDArray
        The projection stack, with dimensions [v, w, u] (vertical, omega /
        sample rotation, horizontal).
    rot_angle_deg : float
        The rotation angle in degrees. Its opposite is what gets passed to
        `skimage.transform.rotate`.

    Returns
    -------
    NDArray
        The rotated projection stack, with the same shape and dtype as the input.
    """
    angle_deg = -rot_angle_deg
    rotated_vwu = np.empty_like(data_vwu)

    # Each [v, u] image (one per rotation angle `w`) is rotated independently
    for ind_w in range(data_vwu.shape[-2]):
        proj_vu = data_vwu[:, ind_w, :]
        rotated_vwu[:, ind_w, :] = skt.rotate(proj_vu, angle_deg, clip=False)

    return rotated_vwu
def shift_proj_stack(data_vwu: NDArray, shifts: NDArray, use_fft: bool = False) -> NDArray:
    """Shift each projection in a stack of projections, by projection dependent shifts.

    Parameters
    ----------
    data_vwu : NDArray
        The projection stack. The second to last axis indexes the projections.
    shifts : NDArray
        The shifts. The last axis indexes the projections: ``shifts[..., ii]``
        is the shift applied to projection ``ii``.
    use_fft : bool, optional
        Whether to use fft shift or not, by default False.
        The Fourier shift is periodic (the content wraps around), while the
        real-space shift uses linear interpolation and replicates the borders.

    Returns
    -------
    NDArray
        The shifted stack, with the same shape and dtype as the input.
    """
    new_data = np.empty_like(data_vwu)

    for ii in range(data_vwu.shape[-2]):
        img = data_vwu[..., ii, :]
        img_shifts = shifts[..., ii]

        if use_fft:
            img_axes = tuple(range(img.ndim))
            img_f = np.fft.rfftn(img, axes=img_axes)
            img_f = spimg.fourier_shift(img_f, img_shifts, n=img.shape[-1])
            # The output shape must be given explicitly: without it, the inverse
            # real FFT assumes an even size of the last axis, and images with an
            # odd number of pixels would come back one pixel short.
            new_data[..., ii, :] = np.fft.irfftn(img_f, s=img.shape, axes=img_axes)
        else:
            new_data[..., ii, :] = spimg.shift(img, img_shifts, order=1, mode="nearest")

    return new_data
def bin_imgs(imgs: NDArray, binning: Union[int, float], auto_crop: bool = False, verbose: bool = True) -> NDArray:
    """Bin a stack of images.

    Integer binning factors (Python ints and NumPy integer scalars alike)
    average non-overlapping ``binning x binning`` blocks of pixels. Any other
    binning factor rescales the images by ``1 / binning`` with
    `skimage.transform.rescale`.

    Parameters
    ----------
    imgs : NDArray
        The stack of images. The last two axes are the image axes.
    binning : int | float
        The binning factor.
    auto_crop : bool, optional
        Whether to automatically crop the images to match, by default False.
        The excess pixels (image size modulo binning) are removed from both
        sides of each image axis.
    verbose : bool, optional
        Whether to print the image shapes, by default True

    Returns
    -------
    NDArray
        The binned images
    """
    if auto_crop:
        imgs_shape = imgs.shape
        excess_pixels_vu = (np.array(imgs.shape[-2:]) % binning).astype(int)
        # The excess pixels are split between the two sides of each axis
        crop_vu = (excess_pixels_vu - excess_pixels_vu // 2, np.array(imgs.shape[-2:]) - excess_pixels_vu // 2)
        imgs = imgs[..., crop_vu[0][0] : crop_vu[1][0], crop_vu[0][1] : crop_vu[1][1]]
        if verbose:
            print(f"Auto-cropping {crop_vu}: {imgs_shape} => {imgs.shape}")

    imgs_shape = imgs.shape

    # NumPy integer scalars (e.g. coming from shape arithmetic) are not
    # instances of `int`, but they should still get the exact block averaging.
    if isinstance(binning, (int, np.integer)):
        binning = int(binning)
        binned_shape = (*imgs_shape[:-2], imgs_shape[-2] // binning, imgs_shape[-1] // binning)
        # Each image axis is split into (binned size, binning), and the
        # `binning` sub-axes are then averaged away.
        imgs = imgs.reshape([*binned_shape[:-1], binning, binned_shape[-1], binning])
        imgs = imgs.mean(axis=(-3, -1))
    else:
        imgs = imgs.reshape([-1, *imgs_shape[-2:]])
        imgs = skt.rescale(imgs, 1 / binning, channel_axis=0)
        binned_shape = [*imgs_shape[:-2], *imgs.shape[-2:]]
        imgs = imgs.reshape(binned_shape)

    if verbose:
        print(f"Binning {binning}: {imgs_shape} => {binned_shape}")

    return imgs
def background_from_margin(
    data_vwu: NDArray, margin: Union[int, Sequence[int], NDArray[np.integer]] = 4, poly_order: int = 0, plot: bool = False
) -> NDArray:
    """Compute background of the projection data, from the margins of the projections.

    Parameters
    ----------
    data_vwu : NDArray
        The projection data in the format [V]WU.
    margin : int | Sequence[int] | NDArray[np.integer], optional
        The size of the margin, by default 4.
        A single value is used on both sides of the U axis, while a pair is
        interpreted as ``(left, right)``.
    poly_order : int, optional
        The order of the interpolation polynomial, by default 0.
        With order 0, the background is the average of the margin pixels.
        Orders larger than 0 require both margins to be larger than 0,
        otherwise a warning is printed and order 0 is used.
    plot : bool, optional
        Whether to plot the fitted background of the first line, together with
        the margin data it was fitted on, by default False.
        It only has an effect when `poly_order` is larger than 0.

    Returns
    -------
    NDArray
        The computed background. For 2-dimensional data it has the shape of
        the data. For 3-dimensional data, the background is averaged over V:
        with `poly_order` > 0 it is replicated along V (data shape), while
        with order 0 the V axis has length 1 (broadcastable against the data).

    Raises
    ------
    NotImplementedError
        Different margins per line are not supported, at the moment.
    ValueError
        In case the margins are larger than the image size in U, they are all
        zero, or they are not given as a single value or a pair.
    """
    data_shape_u = data_vwu.shape[-1]
    data_shape_w = data_vwu.shape[-2]

    margin = np.array(margin, dtype=int, ndmin=1)
    if margin.ndim > 1:
        raise NotImplementedError("Complex masks support has not been implemented, yet.")
    if margin.size == 1:
        margin = np.tile(margin, 2)
    if margin.size != 2:
        raise ValueError(f"Margin should either be a single value or a pair of values, but {margin.size} were given")

    if margin.sum() > data_shape_u:
        raise ValueError(f"Margin size ({margin}) should be smaller than the image size in U ({data_shape_u})")
    if margin.sum() == 0:
        raise ValueError(f"Margin size ({margin}) should be larger than 0 on at least one side")

    if poly_order > 0 and np.any(margin == 0):
        print("WARNING: parameter `poly_order` cannot be greater than 0 if one of the margins is 0")
        poly_order = 0

    # The right margin is selected through its explicit start position: the
    # slice `-margin[1]:` would select the whole line for a margin of 0.
    right_start = data_shape_u - margin[1]
    data_left = data_vwu[..., : margin[0]]
    data_right = data_vwu[..., right_start:]

    if poly_order > 0:
        ydata = np.concatenate([data_left, data_right], axis=-1)
        xdata = np.concatenate([np.arange(0, margin[0]), np.arange(right_start, data_shape_u)])

        if data_vwu.ndim > 2:
            ydata = ydata.mean(axis=-3)

        xfull = np.arange(data_shape_u)
        background = np.empty([data_shape_w, data_shape_u], dtype=data_vwu.dtype)
        # One polynomial per line (rotation angle), evaluated over the whole line
        for ii_w in range(data_shape_w):
            poly = Polynomial.fit(xdata, ydata[ii_w], deg=poly_order)
            background[ii_w, :] = poly(xfull)

        if plot:
            fig, axs = plt.subplots(1, 1)
            axs.plot(background[0])
            axs.scatter(xdata, ydata[0])
            axs.grid()
            axs.set_ylim(0)
            fig.tight_layout()
            plt.show(block=False)

        if data_vwu.ndim > 2:
            background = np.tile(background[None, ...], [data_vwu.shape[-3], 1, 1])

        return background
    else:
        sum_vals = data_left.sum(axis=-1) + data_right.sum(axis=-1)
        background = sum_vals / margin.sum()

        if data_vwu.ndim > 2:
            background = background.mean(axis=-2, keepdims=True)

        # Replicating the per-line average over the whole U axis
        return np.tile(background[..., None], [*np.ones(background.ndim, dtype=int), data_shape_u])
def destripe_wlf_vwu(
    data: NDArray,
    sigma: float = 0.005,
    level: int = 1,
    wavelet: str = "bior2.2",
    angle_axis: int = -2,
    other_axes: Union[Sequence[int], NDArray, None] = None,
) -> NDArray:
    """Remove stripes from sinogram, using the Wavelet-Fourier method.

    The data is decomposed with a stationary wavelet transform along
    `other_axes`. All the coefficients, except the pure approximation ones, are
    then high-pass filtered in Fourier space along `angle_axis`, before
    reconstructing the data.

    Parameters
    ----------
    data : NDArray
        The data to de-stripe
    sigma : float, optional
        Fourier space filter coefficient, by default 0.005
    level : int, optional
        The wavelet level to use, by default 1
    wavelet : str, optional
        The type of wavelet to use, by default "bior2.2"
    angle_axis : int, optional
        The axis of the Fourier transform, by default -2
    other_axes : Union[Sequence[int], NDArray, None], optional
        The axes of the wavelet decomposition, by default None.
        When None, all the axes except `angle_axis` are used. If `angle_axis`
        is among the given axes, it is removed from them.

    Returns
    -------
    NDArray
        The de-striped data.

    Raises
    ------
    ValueError
        If `angle_axis` is not a valid axis of the data.
    """
    num_dims = data.ndim
    if not -num_dims <= angle_axis < num_dims:
        raise ValueError(f"Axis {angle_axis} is out of bounds for data with {num_dims} dimensions")

    # Axes are normalized to non-negative values, so that the angle axis can be
    # recognized (and excluded) by value, whichever sign convention was used.
    angle_axis = angle_axis % num_dims
    if other_axes is None:
        other_axes = np.arange(num_dims)
    else:
        other_axes = np.array(other_axes, dtype=int, ndmin=1) % num_dims
    wl_axes = [int(axis) for axis in other_axes if axis != angle_axis]

    # The stationary wavelet transform needs sizes that are multiples of 2**level
    level_power = 2**level
    data_shape = np.array(data.shape)
    target_shape = data_shape.copy()
    target_shape[wl_axes] = np.ceil(data_shape[wl_axes] / level_power) * level_power
    diff_size = target_shape - data_shape
    padding = np.stack((diff_size - diff_size // 2, diff_size // 2), axis=-1)
    data = np.pad(data, pad_width=padding, mode="edge")

    coeffs = pywt.swtn(data, wavelet=wavelet, axes=wl_axes, level=level)

    # Shape that lets the 1D filter broadcast along `angle_axis`, wherever it is
    filt_shape = [1] * num_dims
    filt_shape[angle_axis] = -1
    approx_label = "a" * len(wl_axes)

    for coeffs_l in coeffs:
        for wl_label, coeffs_l_wl in list(coeffs_l.items()):
            if wl_label == approx_label:
                continue

            num_angles = coeffs_l_wl.shape[angle_axis]
            coeff_f = np.fft.rfft(coeffs_l_wl, axis=angle_axis)
            # Inverted Gaussian: suppresses the lowest frequencies along the angle axis
            filt_f = 1 - np.exp(-(np.fft.rfftfreq(num_angles) ** 2) / (2 * sigma**2))
            coeff_f *= filt_f.reshape(filt_shape)
            coeffs_l[wl_label] = np.fft.irfft(coeff_f, axis=angle_axis, n=num_angles)

    data = pywt.iswtn(coeffs, wavelet=wavelet, axes=wl_axes)

    # Removing the padding
    slicing = tuple(slice(pad_b, data.shape[ii] - pad_a) for ii, (pad_b, pad_a) in enumerate(padding))
    return data[slicing]
def compute_eigen_flats(
    trans_wvu: NDArray,
    flats_wvu: Optional[NDArray] = None,
    darks_wvu: Optional[NDArray] = None,
    ndim: int = 2,
    plot: bool = False,
) -> tuple[NDArray, NDArray]:
    """Compute the eigen flats of a stack of transmission images.

    The logarithm of the images (transmissions and, when given, flats) is
    decomposed with an SVD. The first singular component gives the eigen
    flats, and all the remaining components give the sample transmissions, so
    that the product of the two outputs gives back the (dark-subtracted) input
    transmission images.

    Parameters
    ----------
    trans_wvu : NDArray
        The stack of transmission images.
    flats_wvu : Optional[NDArray], optional
        The flats without sample, by default None
    darks_wvu : Optional[NDArray], optional
        The darks, by default None. Stacks of darks are averaged to one image.
    ndim : int, optional
        The number of dimensions of the images, by default 2
    plot : bool, optional
        Whether to plot the results, by default False

    Returns
    -------
    Tuple[NDArray, NDArray]
        The decomposition of the transmissions of the sample and the flats.
        Both have the same shape as `trans_wvu`.
    """
    trans_shape = trans_wvu.shape
    img_shape = trans_shape[-ndim:]
    # The integer dtype is needed for the single image case: the product of an
    # empty shape would otherwise be a float, which cannot be used for slicing.
    trans_num = int(np.prod(trans_shape[:-ndim], dtype=int))
    img_size = int(np.prod(img_shape, dtype=int))

    stack_imgs = [trans_wvu.reshape([-1, *img_shape])]
    if flats_wvu is not None:
        stack_imgs.append(flats_wvu.reshape([-1, *img_shape]))
    stack_imgs = np.concatenate(stack_imgs)

    if darks_wvu is not None:
        if darks_wvu.ndim > ndim:
            darks_wvu = darks_wvu.mean(axis=tuple(range(darks_wvu.ndim - ndim)))
        stack_imgs = stack_imgs - darks_wvu

    # Non-positive intensities would turn into -inf / NaN in the logarithm, and
    # make the SVD fail. This needs to happen also when no darks are given.
    stack_imgs = np.fmax(stack_imgs, np.finfo(np.float32).eps)

    # One image per column
    stack_imgs = stack_imgs.reshape([-1, img_size]).transpose()
    stack_imgs = np.log(stack_imgs)

    mat_u, sigma, mat_v_h = np.linalg.svd(stack_imgs, full_matrices=False)

    # Everything but the first component: the sample
    eigen_projs: NDArray = (mat_u[..., 1:] * sigma[..., None, 1:]) @ mat_v_h[..., 1:, :]
    eigen_projs = np.exp(eigen_projs)
    eigen_projs = eigen_projs.transpose().reshape([-1, *img_shape])[:trans_num]

    # The first component: the flats
    eigen_flats: NDArray = (mat_u[..., 0:1:] * sigma[..., None, 0:1:]) @ mat_v_h[..., 0:1:, :]
    eigen_flats = np.exp(eigen_flats)
    eigen_flats = eigen_flats.transpose().reshape([-1, *img_shape])[:trans_num]

    if plot:
        fig, axs = plt.subplots(1, 3, figsize=[10, 3.75])
        axs[0].plot(sigma)
        axs[0].grid()
        axs[0].set_title("Singular values")
        axs[1].imshow(mat_u[:, 0].reshape(img_shape))
        axs[1].set_title("Highest value component")
        axs[2].plot(eigen_flats.mean(axis=(-2, -1)).flatten())
        axs[2].grid()
        axs[2].set_title("Eigen intensities")
        fig.tight_layout()
        plt.show(block=False)

    return eigen_projs.reshape(trans_shape), eigen_flats.reshape(trans_shape)