"""
Miscellaneous processing routines.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
import numpy as np
import scipy as sp
from typing import Optional, Union, Callable
from collections.abc import Sequence
from numpy.typing import NDArray, ArrayLike, DTypeLike
import matplotlib.pyplot as plt
# Machine epsilon of single-precision (float32) floating point numbers.
eps = np.finfo(np.float32).eps
# Type alias for NumPy arrays whose elements are signed integers.
NDArrayInt = NDArray[np.signedinteger]
[docs]
def circular_mask(
    vol_shape_zxy: Union[Sequence[int], NDArrayInt],
    radius_offset: float = 0,
    coords_ball: Union[Sequence[int], NDArrayInt, None] = None,
    vol_origin_zxy: Union[Sequence[float], NDArray, None] = None,
    taper_func: Optional[str] = None,
    taper_target: Union[str, float] = "edge",
    super_sampling: int = 1,
    squeeze: bool = True,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """
    Compute a circular mask for the reconstruction volume.

    Parameters
    ----------
    vol_shape_zxy : Sequence[int] | NDArrayInt
        The size of the volume.
    radius_offset : float, optional
        The offset of the mask radius with respect to the volume edge. The default is 0.
    coords_ball : Sequence[int] | NDArrayInt | None, optional
        The coordinates (axes) to consider for the non-masked region.
        The default is None, which selects the last two axes (or the only one for 1D volumes).
    vol_origin_zxy : Sequence[float] | NDArray | None, optional
        The origin of the coordinates in voxels. The default is None.
    taper_func : str | None, optional
        The mask tapering function. Allowed values: None | "const" | "cos".
        Both None and "const" produce a hard-edged mask. The default is None.
    taper_target : str | float, optional
        Where the taper ends. Allowed string values: "edge" | "diagonal".
        A float is interpreted as an offset with respect to the volume edge,
        and it is clipped from below to `radius_offset`. The default is "edge".
    super_sampling : int, optional
        The pixel super sampling to be used for the mask. The default is 1.
    squeeze : bool, optional
        Whether to squeeze the mask. The default is True.
    dtype : DTypeLike, optional
        The type of mask. The default is np.float32.

    Raises
    ------
    ValueError
        In case of unknown `taper_func` or `taper_target` value, or mismatching
        volume origin and shape.

    Returns
    -------
    NDArray
        The circular mask.
    """
    vol_shape = np.array(vol_shape_zxy, dtype=int)
    vol_shape_zxy_s = vol_shape * super_sampling

    # Voxel-centered coordinates (in units of non-super-sampled voxels)
    coords = [
        np.linspace(-(s - 1) / (2 * super_sampling), (s - 1) / (2 * super_sampling), s, dtype=dtype)
        for s in vol_shape_zxy_s
    ]
    if vol_origin_zxy is not None:
        if len(coords) != len(vol_origin_zxy):
            raise ValueError(f"The volume shape ({len(coords)}), and the origin shape ({len(vol_origin_zxy)}) should match")
        coords = [c + vol_origin_zxy[ii] for ii, c in enumerate(coords)]
    coords = np.stack(np.meshgrid(*coords, indexing="ij"), axis=0)

    if coords_ball is None:
        coords_ball = np.arange(-min(2, len(vol_shape_zxy_s)), 0, dtype=int)
    else:
        # ndmin=1 lets scalars and one-element sequences follow the same path below
        coords_ball = np.array(coords_ball, dtype=int, ndmin=1)

    max_radius = np.min(vol_shape_zxy_s[coords_ball]) / (2 * super_sampling) + radius_offset

    if coords_ball.size == 1:
        # Index with a plain integer, so that no spurious leading axis is left in the mask
        dists = np.abs(coords[int(coords_ball.flat[0]), ...])
    else:
        dists = np.sqrt(np.sum(coords[coords_ball, ...] ** 2, axis=0))

    if taper_func is None or (isinstance(taper_func, str) and taper_func.lower() == "const"):
        mask = (dists <= max_radius).astype(dtype)
    elif isinstance(taper_func, str):
        if isinstance(taper_target, str):
            if taper_target.lower() == "edge":
                cut_off_denom = 2
                cut_off_offset = 0
            elif taper_target.lower() == "diagonal":
                cut_off_denom = np.sqrt(2)
                cut_off_offset = 0
            else:
                raise ValueError(
                    f"When `taper_target` is str, it should be one of: 'edge' | 'diagonal', but {taper_target} passed instead."
                )
        else:
            cut_off_denom = 2
            if taper_target < radius_offset:
                print(f"WARNING: parameter `taper_target`={taper_target} is smaller than `radius_offset`={radius_offset}.")
            cut_off_offset = np.fmax(taper_target, radius_offset)

        if taper_func.lower() == "cos":
            cut_off_radius = np.min(vol_shape_zxy_s[coords_ball]) / (cut_off_denom * super_sampling) + cut_off_offset
            cut_off_size = cut_off_radius - max_radius
            if cut_off_size > 0:
                outer_vals = np.cos(np.fmax(dists - max_radius, 0) / cut_off_size * np.pi) / 2 + 0.5
                mask = (outer_vals * (dists < cut_off_radius)).astype(dtype)
            else:
                # No room for a taper: fall back to a hard edge at the cut-off radius.
                # This avoids the 0 / 0 division that would fill the mask with NaNs.
                mask = (dists < cut_off_radius).astype(dtype)
        else:
            raise ValueError(f"Unknown taper function: {taper_func}")
    else:
        raise ValueError("Parameter `taper_func` should either be a string or None.")

    if super_sampling > 1:
        # Split every axis into (voxel, sub-voxel) and average over the sub-voxel axes
        new_shape = np.stack([vol_shape, np.ones_like(vol_shape) * super_sampling], axis=1).flatten()
        mask = mask.reshape(new_shape)
        mask = np.mean(mask, axis=tuple(range(1, len(vol_shape) * 2, 2)))

    if squeeze:
        mask = np.squeeze(mask)

    return mask
[docs]
def ball(
    data_shape_vu: ArrayLike,
    radius: Union[int, float],
    super_sampling: int = 5,
    dtype: DTypeLike = np.float32,
    func: Optional[Callable] = None,
) -> ArrayLike:
    """
    Render a ball of the given radius, centered in an array of the requested shape.

    The ball is rasterized on a grid that is `super_sampling` times finer than
    the output, and then block-averaged to produce smooth edges.

    Parameters
    ----------
    data_shape_vu : ArrayLike
        Shape of the output array.
    radius : int | float
        Radius of the ball.
    super_sampling : int, optional
        Super-sampling for having smoother ball edges. The default is 5.
    dtype : DTypeLike, optional
        Type of the output. The default is np.float32.
    func : Optional[Callable], optional
        Point-wise function of the radial distance, multiplied onto the ball values.
        The default is None.

    Returns
    -------
    ArrayLike
        The ball.
    """
    fine_shape = np.array(data_shape_vu, dtype=int) * super_sampling
    num_axes = len(fine_shape)

    # FFT-ordered integer coordinates on the fine grid (origin at index 0)
    axis_coords = [np.fft.fftfreq(size, 1 / size) for size in fine_shape]
    grid = np.stack(np.meshgrid(*axis_coords, indexing="ij"), axis=0)
    radii = np.sqrt(np.sum(grid**2, axis=0)) / super_sampling

    fine_ball = (radii < radius).astype(dtype)
    if func is not None:
        fine_ball *= func(radii)

    # Shift the origin to the middle of the first block of fine pixels
    fine_ball = np.roll(fine_ball, super_sampling // 2, axis=tuple(range(num_axes)))

    # Interleave (coarse size, super_sampling) per axis, then average the fine sub-axes
    coarse_shape = fine_shape // super_sampling
    block_sizes = np.full_like(fine_shape, super_sampling)
    blocked_shape = np.column_stack((coarse_shape, block_sizes)).ravel()
    fine_axes = tuple(range(1, 2 * num_axes, 2))
    coarse_ball = np.mean(fine_ball.reshape(blocked_shape), axis=fine_axes)

    return np.fft.fftshift(coarse_ball)
[docs]
def azimuthal_integration(img: NDArray, axes: Sequence[int] = (-2, -1), domain: str = "direct") -> NDArray:
    """
    Compute the azimuthal integration of a n-dimensional image or a stack of them.

    Each pixel value is distributed over the two integer radial bins that
    surround its distance from the origin, using linear interpolation weights.

    Parameters
    ----------
    img : NDArray
        The image or stack of images.
    axes : tuple(int, int), optional
        Axes that need to be azimuthally integrated. The default is (-2, -1).
    domain : string, optional
        Domain of the integration. Options are: "direct" | "fourier". Default is "direct".
        With "direct" the origin is the center of the image, otherwise the origin
        is the first pixel, with FFT-ordered coordinates.

    Raises
    ------
    ValueError
        Error returned when not passing images or wrong axes.

    Returns
    -------
    NDArray
        The azimuthally integrated profile. Non-integrated axes are preserved
        as leading axes.
    """
    num_dims_int = len(axes)
    num_dims_img = len(img.shape)
    if num_dims_img < num_dims_int:
        raise ValueError(
            f"Input image ({num_dims_img}D) should be at least the same dimensionality"
            f" of the axes for the integration (#{num_dims_int})."
        )
    if num_dims_int == 0:
        raise ValueError("Input axes should be at least 1.")

    # Compute the coordinates of the pixels along the chosen axes
    img_axes_dims = np.array(np.array(img.shape)[list(axes)], ndmin=1)
    if domain.lower() == "direct":
        half_dims = (img_axes_dims - 1) / 2
        coords = [np.linspace(-h, h, d) for h, d in zip(half_dims, img_axes_dims)]
    else:
        coords = [np.fft.fftfreq(d, 1 / d) for d in img_axes_dims]
    coords = np.stack(np.meshgrid(*coords, indexing="ij"))
    r = np.sqrt(np.sum(coords**2, axis=0))

    # Reshape the volume to have the axes to be integrated as right-most axes
    all_axes = np.arange(num_dims_img)
    img_tr_op = np.concatenate((np.delete(all_axes, list(axes)), all_axes[list(axes)]))
    img = np.transpose(img, img_tr_op)

    # Always work on a flat stack of images: a single image is a stack of one
    img_old_shape = img.shape[:-num_dims_int]
    img = np.reshape(img, [-1, *img_axes_dims])

    # Compute the linear interpolation coefficients
    r_l = np.floor(r)
    r_u = r_l + 1
    w_l = (r_u - r) * img
    w_u = (r - r_l) * img

    # Do the azimuthal integration as a histogram operation
    r_all = np.concatenate((r_l.flatten(), r_u.flatten())).astype(int)
    az_img = np.array(
        [
            np.bincount(r_all, weights=np.concatenate((w_l[ii, ...].flatten(), w_u[ii, ...].flatten())))
            for ii in range(img.shape[0])
        ]
    )
    return np.reshape(az_img, (*img_old_shape, az_img.shape[-1]))
[docs]
def lines_intersection(
    line_1: NDArray,
    line_2: Union[float, NDArray],
    position: str = "first",
    x_lims: Optional[tuple[Optional[float], Optional[float]]] = None,
) -> Optional[tuple[float, float]]:
    """
    Compute the intersection point between two lines.

    The lines are sampled on the same integer grid (x = 0, 1, 2, ...), and the
    crossing is linearly interpolated between the two samples that enclose it.

    Parameters
    ----------
    line_1 : NDArray
        The first line.
    line_2 : float | NDArray
        The second line. It can be a scalar representing a horizontal line.
    position : str, optional
        The position of the point to select. Either "first" or "last".
        The default is "first".
    x_lims : tuple[float | None, float | None] | None, optional
        The (start, end) sample indices that restrict the region where the
        crossing is searched. Either of the two can be None, but not both.
        The default is None, which searches the whole line.

    Raises
    ------
    ValueError
        If position is neither "first" nor "last", or if both `x_lims` values are None.

    Returns
    -------
    Tuple[float, float] | None
        It returns either the requested crossing point, or None in case the
        point was not found.
    """
    line_1 = np.array(np.squeeze(line_1), ndmin=1)
    line_2 = np.array(np.squeeze(line_2), ndmin=1)

    # Find the transition points, by first finding where line_2 is above line_1
    crossing_points = np.where(line_2 > line_1, 0, 1)
    crossing_points = np.abs(np.diff(crossing_points))

    bias = 0
    if x_lims is not None:
        x_start, x_end = x_lims[0], x_lims[1]
        if x_start is None and x_end is None:
            raise ValueError("When passing `x_lims`, at least one of the values should not be None.")
        if x_start is not None:
            # Indices found in the cropped region need to be shifted back
            bias = x_start
        crossing_points = crossing_points[x_start:x_end]

    crossing_points = np.where(crossing_points)[0]
    if crossing_points.size == 0:
        print("No crossing found!")
        return None

    if position.lower() == "first":
        point_l = crossing_points[0] + bias
    elif position.lower() == "last":
        point_l = crossing_points[-1] + bias
    else:
        raise ValueError(f"Crossing position: {position} unknown. Please choose either 'first' or 'last'.")

    # Both segments are expressed in local coordinates, between x = 0 and x = 1
    x1 = 0.0
    x2 = 1.0
    y1 = line_1[point_l]
    y2 = line_1[point_l + 1]
    x3 = 0.0
    x4 = 1.0
    if line_2.size == 1:
        # Extract the scalar: converting 1-element arrays with float() is deprecated in NumPy
        y3 = y4 = line_2[0]
    else:
        y3 = line_2[point_l]
        y4 = line_2[point_l + 1]

    # From wikipedia: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection#Given_two_points_on_each_line
    p_den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    p_x_num = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)
    p_y_num = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)

    p_x = p_x_num / p_den + point_l
    p_y = p_y_num / p_den
    return float(p_x), float(p_y)
[docs]
def norm_cross_corr(
    img1: NDArray,
    img2: Optional[NDArray] = None,
    axes: Sequence[int] = (-2, -1),
    t_match: bool = False,
    mode_full: bool = True,
    compute_profile: bool = True,
    plot: bool = True,
) -> Union[NDArray, tuple[NDArray, NDArray]]:
    """
    Compute the normalized cross-correlation between two images.

    Parameters
    ----------
    img1 : NDArray
        The first image.
    img2 : NDArray, optional
        The second images. If None, it computes the auto-correlation. The default is None.
    axes : Sequence[int], optional
        Axes along which to compute the cross-correlation. The default is (-2, -1).
    t_match : bool, optional
        Whether to perform the cross-correlation for template matching. The default is False.
    mode_full : bool, optional
        Whether to return the "full" or "same" convolution size. The default is True.
    compute_profile : bool, optional
        Whether to compute the azimuthal integration of the cross-correlation or not. The default is True.
    plot : bool, optional
        Whether to plot the profile of the cross-correlation curve. It is only
        used when `compute_profile` is True. The default is True.

    Returns
    -------
    NDArray | tuple[NDArray, NDArray]
        The normalized cross-correlation, with the zero shift at the first
        position of the correlated axes. When `compute_profile` is True, the
        one-dimensional cross-correlation curve is returned as second output.
    """

    def local_sum(x: NDArray, axes: Sequence[int]) -> NDArray:
        # Sliding-window sums over the chosen axes, with the same size as a "full" convolution
        padding = np.zeros(len(x.shape), dtype=int)
        for a in axes:
            padding[a] = x.shape[a]

        # Explicit (before, after) pairs per axis: a flat vector is misread by np.pad
        y: NDArray[np.floating] = np.pad(x, [(int(p), int(p)) for p in padding])
        for a in axes:
            y = np.cumsum(y, axis=a)
            slicing1 = [slice(None)] * len(x.shape)
            slicing1[a] = slice(x.shape[a], -1)
            slicing2 = [slice(None)] * len(x.shape)
            slicing2[a] = slice(0, -x.shape[a] - 1)
            y = y[tuple(slicing1)] - y[tuple(slicing2)]
        return y

    if img2 is None:
        img2 = img1

    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)

    # Correlation is computed as a convolution with the flipped second image
    for a in axes:
        img2 = np.flip(img2, axis=a)
    cc = sp.signal.fftconvolve(img1, img2, mode="full", axes=axes)

    if not mode_full:
        # Crop the central region that has the same size of the first image
        slices = [slice(None)] * len(cc.shape)
        for a in axes:
            start_ind = (cc.shape[a] - img1.shape[a]) // 2
            end_ind = start_ind + img1.shape[a]
            slices[a] = slice(start_ind, end_ind)
        slices = tuple(slices)
        cc = cc[slices]

    if t_match:
        local_sums_img2 = local_sum(img2, axes=axes)
        local_sums_img2_2 = local_sum(img2**2, axes=axes)
        if not mode_full:
            local_sums_img2 = local_sums_img2[slices]
            local_sums_img2_2 = local_sums_img2_2[slices]

        cc_n = cc - local_sums_img2 * np.mean(img1)

        cc_n /= np.std(img1) * np.sqrt(np.prod(np.array(img1.shape)[list(axes)]))

        diff_local_sums = local_sums_img2_2 - (local_sums_img2**2) / np.prod(np.array(img2.shape)[list(axes)])
        cc_n /= np.sqrt(diff_local_sums.clip(0, None))
    else:
        cc_n = cc / (np.linalg.norm(img1) * np.linalg.norm(img2))

    # Only shift the correlated axes, to preserve the order of stacked images
    cc_n = np.fft.ifftshift(cc_n, axes=tuple(axes))

    if compute_profile:
        cc_l = azimuthal_integration(cc_n, axes=axes, domain="fourier")
        cc_o = azimuthal_integration(np.ones_like(cc_n), axes=axes, domain="fourier")

        cc_l /= cc_o
        cc_l = cc_l[: np.min(np.array(img1.shape)[list(axes)])]

        if plot:
            p_xy = lines_intersection(cc_l, 0.5, position="first")

            f, ax = plt.subplots()
            ax.plot(cc_l, label="Cross-correlation")
            ax.plot(np.ones_like(cc_l) * 0.5, label="Half-maximum")
            if p_xy is not None:
                ax.scatter(p_xy[0], p_xy[1])
                ax.plot([p_xy[0], p_xy[0]], [0, 1], label=f"Resolution: {p_xy[0]:.3} (pix)")
            ax.grid()
            ax.legend()
            ax.set_title("Cross-correlation")
            f.tight_layout()
            plt.show(block=False)

        return cc_n, cc_l
    else:
        return cc_n
[docs]
def inspect_fourier_img(img: NDArray, remove_zero: bool = False) -> None:
    """Show the 2D Fourier transform of an image, together with some line profiles.

    The top row of the figure shows the real part, imaginary part, and modulus
    of the centered spectrum. The bottom row plots the same three quantities
    along the first column, the first row, and the diagonal of the spectrum.

    Parameters
    ----------
    img : NDArray
        Input image.
    remove_zero : bool, optional
        Remove the zero frequency value. The default is False.
    """
    spectrum = np.squeeze(np.fft.fft2(img))
    if remove_zero is True:
        spectrum[0, 0] = 0

    spectrum_centered = np.fft.fftshift(spectrum)

    component_funcs = (np.real, np.imag, np.abs)

    fig, axes_grid = plt.subplots(2, 3)
    fig.suptitle("Fourier representation")

    # Top row: one image per component of the centered spectrum
    for ax, title, component in zip(axes_grid[0], ("Real", "Imag", "Abs"), component_funcs):
        ax.imshow(component(spectrum_centered))
        ax.set_title(title)

    # Bottom row: the three components along selected lines of the spectrum
    line_selectors = (
        ("Vertical profiles", lambda s: s[:, 0]),
        ("Horizontal profiles", lambda s: s[0, :]),
        ("Diagonal profiles", np.diag),
    )
    for ax, (title, select_line) in zip(axes_grid[1], line_selectors):
        line_centered = np.fft.fftshift(select_line(spectrum))
        ax.plot(np.stack([component(line_centered) for component in component_funcs], axis=1))
        ax.set_title(title)

    plt.show(block=False)