# Source code for corrct.processing.misc

"""
Miscellaneous processing routines.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Callable, Sequence

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from numpy.typing import ArrayLike, DTypeLike, NDArray

# Machine epsilon of single-precision (np.float32) floating point numbers.
eps = np.finfo(np.float32).eps

# Type alias for NumPy arrays of signed integers, used in type annotations.
NDArrayInt = NDArray[np.signedinteger]


def circular_mask(
    vol_shape_zxy: Sequence[int] | NDArrayInt,
    radius_offset: float = 0,
    coords_ball: Sequence[int] | NDArrayInt | None = None,
    ball_norm: float = 2,
    vol_origin_zxy: Sequence[float] | NDArray | None = None,
    taper_func: str | None = None,
    taper_target: str | float = "edge",
    super_sampling: int = 1,
    squeeze: bool = True,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """
    Compute a circular mask for the reconstruction volume.

    Parameters
    ----------
    vol_shape_zxy : Sequence[int] | NDArrayInt
        The size of the volume.
    radius_offset : float, optional
        The offset of the mask radius with respect to the volume edge. The default is 0.
    coords_ball : Sequence[int] | NDArrayInt | None, optional
        The axes to consider for the non-masked region. When None, the last two axes
        (or the only axis, for 1D volumes) are used. The default is None.
    ball_norm : float, optional
        The norm of the ball. The default is 2.
    vol_origin_zxy : Sequence[float] | NDArray | None, optional
        The origin of the coordinates in voxels. The default is None.
    taper_func : str | None, optional
        The mask tapering function. Allowed values: None or "const" (hard edge),
        and "cos" (cosine roll-off). The default is None.
    taper_target : str | float, optional
        Where the taper reaches zero. Allowed values: "edge" | "diagonal", or a number
        interpreted as an offset with respect to the volume edge (never smaller than
        `radius_offset`). The default is "edge".
    super_sampling : int, optional
        The pixel super sampling to be used for the mask. The default is 1.
    squeeze : bool, optional
        Whether to squeeze the mask. The default is True.
    dtype : DTypeLike, optional
        The type of mask. The default is np.float32.

    Raises
    ------
    ValueError
        In case of unknown `taper_func` or `taper_target` values, or mismatching volume origin and shape.

    Returns
    -------
    NDArray
        The circular mask.
    """
    vol_shape_zxy_s = np.array(vol_shape_zxy, dtype=int) * super_sampling

    # Coordinates of the (sub-)pixel centers, expressed in units of the original pixels
    coords = [
        np.linspace(-(s - 1) / (2 * super_sampling), (s - 1) / (2 * super_sampling), s, dtype=dtype)
        for s in vol_shape_zxy_s
    ]
    if vol_origin_zxy is not None:
        if len(coords) != len(vol_origin_zxy):
            raise ValueError(f"The volume shape ({len(coords)}), and the origin shape ({len(vol_origin_zxy)}) should match")
        coords = [c + vol_origin_zxy[ii] for ii, c in enumerate(coords)]
    coords = np.stack(np.meshgrid(*coords, indexing="ij"), axis=0)

    if coords_ball is None:
        coords_ball = np.arange(-np.fmin(2, len(vol_shape_zxy_s)), 0, dtype=int)
    else:
        coords_ball = np.array(coords_ball, dtype=int)

    min_ball_size = np.min(vol_shape_zxy_s[coords_ball])
    max_radius = min_ball_size / (2 * super_sampling) + radius_offset

    if coords_ball.size == 1:
        # Indexing with a 1-element array keeps a leading singleton axis: drop it, to match the volume shape
        dists = np.abs(coords[coords_ball, ...]).reshape(coords.shape[1:])
    else:
        dists = np.linalg.norm(coords[coords_ball, ...], axis=0, ord=ball_norm)

    if taper_func is not None and not isinstance(taper_func, str):
        raise ValueError("Parameter `taper_func` should either be a string or None.")

    if taper_func is None or taper_func.lower() == "const":
        mask = (dists <= max_radius).astype(dtype)
    elif taper_func.lower() == "cos":
        if isinstance(taper_target, str):
            if taper_target.lower() == "edge":
                cut_off_denom = 2
            elif taper_target.lower() == "diagonal":
                cut_off_denom = np.sqrt(2)
            else:
                raise ValueError(
                    f"When `taper_target` is str, it should be one of: 'edge' | 'diagonal', but {taper_target} passed instead."
                )
            cut_off_offset = 0
        else:
            cut_off_denom = 2
            if taper_target < radius_offset:
                print(f"WARNING: parameter `taper_target`={taper_target} is smaller than `radius_offset`={radius_offset}.")
            cut_off_offset = np.fmax(taper_target, radius_offset)

        cut_off_radius = min_ball_size / (cut_off_denom * super_sampling) + cut_off_offset
        cut_off_size = cut_off_radius - max_radius
        if cut_off_size > 0:
            # Cosine roll-off from 1 (at max_radius) to 0 (at cut_off_radius)
            outer_vals = np.cos(np.fmax(dists - max_radius, 0) / cut_off_size * np.pi) / 2 + 0.5
            mask = (outer_vals * (dists < cut_off_radius)).astype(dtype)
        else:
            # Degenerate taper (zero or negative width): a hard cut-off, avoiding 0 / 0 = NaN inside the mask
            mask = (dists < cut_off_radius).astype(dtype)
    else:
        raise ValueError(f"Unknown taper function: {taper_func}")

    if super_sampling > 1:
        # Group the sub-pixels of each pixel on their own axes, and average over them
        vol_shape = np.array(vol_shape_zxy, dtype=int)
        new_shape = np.stack([vol_shape, np.full_like(vol_shape, super_sampling)], axis=1).flatten()
        mask = mask.reshape(new_shape)
        mask = np.mean(mask, axis=tuple(range(1, 2 * len(vol_shape), 2)))

    if squeeze:
        mask = np.squeeze(mask)

    return mask
[docs] def ball( data_shape_vu: ArrayLike, radius: int | float, super_sampling: int = 5, dtype: DTypeLike = np.float32, func: Callable | None = None, ) -> ArrayLike: """ Compute a ball with specified radius. Parameters ---------- data_shape_vu : ArrayLike Shape of the output array. radius : int | float Radius of the ball. super_sampling : int, optional Super-sampling for having smoother ball edges. The default is 5. dtype : DTypeLike, optional Type of the output. The default is np.float32. func : Callable | None, optional Point-wise function for the local values. The default is None. Returns ------- ArrayLike The ball. """ data_shape_vu = np.array(data_shape_vu, dtype=int) * super_sampling # coords = [np.linspace(-(s - 1) / 2, (s - 1) / 2, s, dtype=np.float32) for s in data_shape_vu] coords = [np.fft.fftfreq(d, 1 / d) for d in data_shape_vu] coords = np.stack(np.meshgrid(*coords, indexing="ij"), axis=0) r = np.sqrt(np.sum(coords**2, axis=0)) / super_sampling probe = (r < radius).astype(dtype) if func is not None: probe *= func(r) probe = np.roll(probe, super_sampling // 2, axis=tuple(np.arange(len(data_shape_vu)))) new_shape = np.stack([data_shape_vu // super_sampling, np.ones_like(data_shape_vu) * super_sampling], axis=1).flatten() probe = probe.reshape(new_shape) probe = np.mean(probe, axis=tuple(np.arange(1, len(data_shape_vu) * 2, 2, dtype=int))) return np.fft.fftshift(probe)
def azimuthal_integration(img: NDArray, axes: Sequence[int] = (-2, -1), domain: str = "direct") -> NDArray:
    """
    Compute the azimuthal integration of a n-dimensional image or a stack of them.

    Each pixel value is split between the two integer radii that bracket its distance
    from the origin (linear interpolation), and the contributions are summed for each
    integer radius (they are not normalized by the number of contributing pixels).

    Parameters
    ----------
    img : NDArray
        The image or stack of images.
    axes : Sequence[int], optional
        Axes that need to be azimuthally integrated. The default is (-2, -1).
    domain : str, optional
        Domain of the integration. Options are: "direct" (origin at the center of the
        axes) | "fourier" (origin at index 0, FFT ordering). Default is "direct".

    Raises
    ------
    ValueError
        When passing no axes, more axes than the image dimensions, or an unknown domain.

    Returns
    -------
    NDArray
        The azimuthally integrated profile, with the integer radius along the last axis.
        For stacks, the leading axes are the non-integrated axes of `img`.
    """
    num_dims_int = len(axes)
    num_dims_img = len(img.shape)

    if num_dims_img < num_dims_int:
        raise ValueError(
            f"Input image ({num_dims_img}D) should be at least the same dimensionality"
            f" of the axes for the integration (#{num_dims_int})."
        )
    if num_dims_int == 0:
        raise ValueError("Input axes should be at least 1.")

    # Compute the radial coordinate of the pixels along the chosen axes
    img_axes_dims = np.array(np.array(img.shape)[list(axes)], ndmin=1)
    domain_l = domain.lower()
    if domain_l == "direct":
        half_dims = (img_axes_dims - 1) / 2
        coords = [np.linspace(-h, h, d) for h, d in zip(half_dims, img_axes_dims)]
    elif domain_l == "fourier":
        coords = [np.fft.fftfreq(d, 1 / d) for d in img_axes_dims]
    else:
        raise ValueError(f"Unknown domain: {domain}. Options are: 'direct' | 'fourier'.")
    coords = np.stack(np.meshgrid(*coords, indexing="ij"))
    r = np.sqrt(np.sum(coords**2, axis=0))

    # Move the axes to be integrated to the right-most positions, and flatten the remaining ones
    img_tr_op = np.arange(num_dims_img)
    img_tr_op = np.concatenate((np.delete(img_tr_op, obj=axes), img_tr_op[list(axes)]))
    img = np.transpose(img, img_tr_op)
    is_stack = num_dims_img > num_dims_int
    if is_stack:
        img_old_shape = img.shape[:-num_dims_int]
        img = np.reshape(img, [-1, *img_axes_dims])

    # Linear interpolation coefficients between the two bracketing integer radii
    r_l = np.floor(r)
    r_u = r_l + 1
    w_l = (r_u - r) * img
    w_u = (r - r_l) * img

    # Do the azimuthal integration as a (weighted) histogram operation
    r_all = np.concatenate((r_l.flatten(), r_u.flatten())).astype(int)
    if not is_stack:
        return np.bincount(r_all, weights=np.concatenate((w_l.flatten(), w_u.flatten())))

    az_img = np.array(
        [
            np.bincount(r_all, weights=np.concatenate((w_l_ii.flatten(), w_u_ii.flatten())))
            for w_l_ii, w_u_ii in zip(w_l, w_u)
        ]
    )
    return np.reshape(az_img, (*img_old_shape, az_img.shape[-1]))
[docs] def lines_intersection( line_1: NDArray, line_2: float | NDArray, position: str = "first", x_lims: tuple[float | None, float | None] | None = None, ) -> tuple[float, float] | None: """ Compute the intersection point between two lines. Parameters ---------- line_1 : NDArray The first line. line_2 : float | NDArray The second line. It can be a scalar representing a horizontal line. position : str, optional The position of the point to select. Either "first" or "last". The default is "first". Raises ------ ValueError If position is neither "first" nor "last". Returns ------- Tuple[float, float] | None It returns either the requested crossing point, or None in case the point was not found. """ line_1 = np.array(np.squeeze(line_1), ndmin=1) line_2 = np.array(np.squeeze(line_2), ndmin=1) # Find the transition points, by first finding where line_2 is above line_1 crossing_points = np.where(line_2 > line_1, 0, 1) crossing_points = np.abs(np.diff(crossing_points)) if x_lims is not None: if x_lims[0] is None: if x_lims[1] is None: raise ValueError("When passing `x_lims`, at least one of the values should not be None.") else: bias = 0 crossing_points = crossing_points[: x_lims[1]] else: bias = x_lims[0] if x_lims[1] is None: crossing_points = crossing_points[x_lims[0] :] else: crossing_points = crossing_points[x_lims[0] : x_lims[1]] else: bias = 0 crossing_points = np.where(crossing_points)[0] if crossing_points.size == 0: print("No crossing found!") return None if position.lower() == "first": point_l = crossing_points[0] + bias elif position.lower() == "last": point_l = crossing_points[-1] + bias else: raise ValueError(f"Crossing position: {position} unknown. 
Please choose either 'first' or 'last'.") x1 = 0.0 x2 = 1.0 y1 = line_1[point_l] y2 = line_1[point_l + 1] x3 = 0.0 x4 = 1.0 if line_2.size == 1: y3 = line_2 y4 = line_2 else: y3 = line_2[point_l] y4 = line_2[point_l + 1] # From wikipedia: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection#Given_two_points_on_each_line p_den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) p_x_num = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4) p_y_num = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4) p_x = p_x_num / p_den + point_l p_y = p_y_num / p_den return float(p_x), float(p_y)
[docs] def norm_cross_corr( img1: NDArray, img2: NDArray | None = None, axes: Sequence[int] = (-2, -1), t_match: bool = False, mode_full: bool = True, compute_profile: bool = True, plot: bool = True, ) -> NDArray | tuple[NDArray, NDArray]: """ Compute the normalized cross-correlation between two images. Parameters ---------- img1 : NDArray The first image. img2 : NDArray, optional The second images. If None, it computes the auto-correlation. The default is None. axes : Sequence[int], optional Axes along which to compute the cross-correlation. The default is (-2, -1). t_match : bool, optional Whether to perform the cross-correlation for template matching. The default is False. mode_full : bool, optional Whether to return the "full" or "same" convolution size. The default is True. compute_profile : bool, optional Whether to compute the azimuthal integration of the cross-correlation or not. The default is True. plot : bool, optional Whether to plot the profile of the cross-correlation curve. The default is True. Returns ------- NDArray The one-dimensional cross-correlation curve. 
""" def local_sum(x: NDArray, axes: Sequence[int]) -> NDArray: padding = np.zeros(len(x.shape), dtype=int) for a in axes: padding[a] = x.shape[a] y: NDArray[np.floating] = np.pad(x, padding) for a in axes: y = np.cumsum(y, axis=a) slicing1 = [slice(None)] * len(x.shape) slicing1[a] = slice(x.shape[a], -1) slicing2 = [slice(None)] * len(x.shape) slicing2[a] = slice(0, -x.shape[a] - 1) y = y[tuple(slicing1)] - y[tuple(slicing2)] return y if img2 is None: is_autocorrelation = True img2 = img1 else: is_autocorrelation = False img1 = img1.astype(np.float32) img2 = img2.astype(np.float32) for a in axes: img2 = np.flip(img2, axis=a) cc = sp.signal.fftconvolve(img1, img2, mode="full", axes=axes) if not mode_full: slices = [slice(None)] * len(cc.shape) for a in axes: start_ind = (cc.shape[a] - img1.shape[a]) // 2 end_ind = start_ind + img1.shape[a] slices[a] = slice(start_ind, end_ind) slices = tuple(slices) cc = cc[slices] if t_match: local_sums_img2 = local_sum(img2, axes=axes) local_sums_img2_2 = local_sum(img2**2, axes=axes) if not mode_full: local_sums_img2 = local_sums_img2[slices] local_sums_img2_2 = local_sums_img2_2[slices] cc_n = cc - local_sums_img2 * np.mean(img1) cc_n /= np.std(img1) * np.sqrt(np.prod(np.array(img1.shape)[list(axes)])) diff_local_sums = local_sums_img2_2 - (local_sums_img2**2) / np.prod(np.array(img2.shape)[list(axes)]) cc_n /= np.sqrt(diff_local_sums.clip(0, None)) else: cc_n = cc / (np.linalg.norm(img1) * np.linalg.norm(img2)) cc_n = np.fft.ifftshift(cc_n) if compute_profile: cc_l = azimuthal_integration(cc_n, axes=axes, domain="fourier") cc_o = azimuthal_integration(np.ones_like(cc_n), axes=axes, domain="fourier") cc_l /= cc_o cc_l = cc_l[: np.min(np.array(img1.shape)[list(axes)])] if plot: p_xy = lines_intersection(cc_l, 0.5, position="first") fig, axs = plt.subplots(1, 1, figsize=(7, 3.5)) axs.plot(cc_l, label="Auto-correlation" if is_autocorrelation else "Cross-correlation") axs.plot(np.ones_like(cc_l) * 0.5, label="Half-maximum") if p_xy 
is not None: axs.scatter(p_xy[0], p_xy[1]) axs.plot([p_xy[0], p_xy[0]], [0, 1], label=f"Resolution: {p_xy[0]:.3} pix") axs.grid() axs.legend(fontsize=13) axs.tick_params(labelsize=16) axs.set_title("Cross-correlation") fig.tight_layout() plt.show(block=False) return cc_n, cc_l else: return cc_n
def inspect_fourier_img(img: NDArray, remove_zero: bool = False) -> None:
    """
    Display the Fourier representation of the input image.

    The top row shows the real part, the imaginary part, and the magnitude of the
    (centered) 2D Fourier transform. The bottom row shows the same three quantities
    along the zero-frequency column, the zero-frequency row, and the main diagonal.

    Parameters
    ----------
    img : NDArray
        Input image.
    remove_zero : bool, optional
        Remove the zero frequency value. The default is False.
    """
    img_f = np.squeeze(np.fft.fft2(img))
    if remove_zero:
        img_f[0, 0] = 0
    img_f_sh = np.fft.fftshift(img_f)

    fig, axs = plt.subplots(2, 3)
    fig.suptitle("Fourier representation")

    # Top row: 2D images of the centered transform
    for ax, (part, title) in zip(axs[0], ((np.real, "Real"), (np.imag, "Imag"), (np.abs, "Abs"))):
        ax.imshow(part(img_f_sh))
        ax.set_title(title)

    # Bottom row: 1D profiles, each centered independently
    profiles = (
        (img_f[:, 0], "Vertical profiles"),
        (img_f[0, :], "Horizontal profiles"),
        (np.diag(img_f), "Diagonal profiles"),
    )
    for ax, (profile, title) in zip(axs[1], profiles):
        profile_sh = np.fft.fftshift(profile)
        ax.plot(np.stack((np.real(profile_sh), np.imag(profile_sh), np.abs(profile_sh)), axis=1))
        ax.set_title(title)

    plt.show(block=False)