# Source code for corrct.processing.post

"""
Post-processing routines.

Created on Tue Mar 24 15:25:14 2020

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Sequence

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from numpy.typing import ArrayLike, NDArray
from scipy.ndimage import convolve, zoom
from scipy.optimize import minimize
from scipy.signal.windows import hann
from scipy.special import gamma
from tqdm.auto import tqdm

from corrct.operators import BaseTransform, TransformIdentity
from corrct.processing.misc import azimuthal_integration, circular_mask, lines_intersection

# Machine epsilon of 32-bit floats: the gap between 1.0 and the next larger float32 value.
eps = np.finfo("float32").eps


[docs] def com(vol: NDArray, axes: ArrayLike | None = None) -> NDArray: """ Compute center-of-mass for given volume. Parameters ---------- vol : NDArray The input volume. axes : ArrayLike, optional Axes on which to compute center-of-mass. The default is None. Returns ------- NDArray The center-of-mass position. """ if axes is None: axes = np.arange(len(vol.shape)) else: axes = np.array(axes, ndmin=1) coords = [np.linspace(-(s - 1) / 2, (s - 1) / 2, s) for s in np.array(vol.shape)[list(axes)]] num_dims = len(vol.shape) center_of_mass = np.empty((len(axes),)) for ii, a in enumerate(axes): sum_axes = np.array(np.delete(np.arange(num_dims), a), ndmin=1, dtype=int) line = np.abs(vol).sum(axis=tuple(sum_axes)) center_of_mass[ii] = line.dot(coords[ii]) / line.sum() return center_of_mass
def power_spectrum(
    img: NDArray,
    axes: Sequence[int] | None = None,
    smooth: int | None = 5,
    taper_ratio: float | None = 0.05,
    power: int = 2,
) -> NDArray:
    """
    Compute the azimuthally-integrated power spectrum of an n-dimensional signal.

    The signal is optionally tapered at its edges and Fourier transformed over ``axes``.
    The Fourier magnitude is raised to ``power`` and integrated over rings/shells. The
    result is then divided by the same integration applied to an array of ones (presumably
    giving a per-ring average — depends on ``azimuthal_integration``).

    Parameters
    ----------
    img : NDArray
        The n-dimensional signal.
    axes : Sequence[int] | None, optional
        The axes over which the power spectrum is computed. By default None, meaning all axes.
    smooth : int | None, optional
        Size of the Hann window used to smooth the result along its last axis, by default 5.
        No smoothing is applied when None or not larger than 1.
    taper_ratio : float | None, optional
        Fraction of the mean size along ``axes`` over which the signal is tapered at the edges
        (useful for truncated signals), by default 0.05. No tapering when None.
    power : int, optional
        The exponent applied to the Fourier magnitude, by default 2.

    Returns
    -------
    NDArray
        The power spectrum, truncated to half of the smallest size along ``axes``.
    """
    shape = np.array(img.shape)
    if axes is None:
        # All axes of the signal, expressed as negative indices.
        axes = list(np.arange(-len(shape), 0))
    sizes_along_axes = shape[list(axes)]
    num_freqs = np.min(sizes_along_axes) // 2

    if taper_ratio is not None:
        taper_width = float(taper_ratio * np.mean(sizes_along_axes))
        edge_taper = circular_mask(shape, coords_ball=axes, radius_offset=-taper_width, taper_func="cos")
        img = img * edge_taper

    spectrum_nd = np.abs(np.fft.fftn(img, axes=axes)) ** power
    ring_totals = azimuthal_integration(spectrum_nd, axes=axes, domain="fourier")
    ring_weights = azimuthal_integration(np.ones_like(img), axes=axes, domain="fourier")
    result = ring_totals / ring_weights

    # Normalization: sqrt(N) ** power, with N = (smallest size along axes) ** (number of axes).
    dc_norm = np.sqrt(np.min(sizes_along_axes) ** len(sizes_along_axes)) ** power
    result /= dc_norm

    if smooth is not None and smooth > 1:
        kernel = hann(smooth)
        kernel /= np.sum(kernel)
        # Broadcastable kernel: smooths only along the last axis.
        kernel = kernel.reshape((1,) * (result.ndim - 1) + (-1,))
        result = convolve(result, kernel, mode="nearest")

    return result[..., :num_freqs]


[docs] def compute_frc( img1: NDArray, img2: NDArray | None, snrt: float = 0.2071, axes: Sequence[int] | None = None, smooth: int | None = 5, taper_ratio: float | None = 0.05, supersampling: int = 1, theo_threshold: bool = True, ) -> tuple[NDArray, NDArray]: """ Compute the FRC/FSC (Fourier ring/shell correlation) between two images / volumes. Please refer to the following article for more information: M. van Heel and M. Schatz, “Fourier shell correlation threshold criteria,” J. Struct. Biol., vol. 151, no. 3, pp. 250–262, Sep. 2005. Parameters ---------- img1 : NDArray First image / volume. img2 : NDArray | None Second image / volume. snrt : float, optional SNR to be used for generating the threshold curve for resolution definition. The SNR value of 0.4142 corresponds to the hald-bit curve for a full dataset. When splitting datasets in two sub-datasets, that value needs to be halved. The default is 0.2071, which corresponds to the half-bit threashold for half dataset. axes : Sequence[int] | None, optional The axes over which we want to compute the FRC/FSC. If None, all axes will be used The default is None. smooth : int | None, optional Size of the Hann smoothing window. The default is 5. taper_ratio : float | None, optional Ratio of the edge pixels to be tapered off. This is necessary when working with truncated volumes / local tomography, to avoid truncation artifacts. The default is 0.05. supersampling : int, optional Supersampling factor of the images. Larger values increase the high-frequency range of the FRC/FSC function. The default is 1, which corresponds to the Nyquist frequency. theo_threshold : bool, optional Whether to compute the theoretical threshold curve. The default is True. Raises ------ ValueError Error returned when not passing images of the same shape. Returns ------- NDArray The computed FRC/FSC. NDArray The threshold curve corresponding to the given threshod SNR. 
""" img1_shape = np.array(img1.shape) if axes is None: axes = list(np.arange(-len(img1_shape), 0)) if img2 is None: if np.any(img1_shape[axes] % 2 == 1): raise ValueError(f"Image shape {img1_shape} along the chosen axes {axes} needs to be even.") raise NotImplementedError("Self FRC not implemented, yet.") else: img2_shape = np.array(img2.shape) if len(img1_shape) != len(img2_shape) or np.any(img1_shape != img2_shape): raise ValueError( f"Image #1 size {img1_shape} and image #2 size {img2_shape} are different, while they should be equal." ) if img1.dtype != img2.dtype: print(f"WARNING: The two images have different dtype: img1 {img1.dtype}, img2 {img2.dtype}. Forcing the first.") img2 = img2.astype(img1.dtype) if supersampling > 1: img1 = zoom(img1, zoom=supersampling, order=5) img2 = zoom(img2, zoom=supersampling, order=5) img1_shape = np.array(img1.shape) axes_shape = img1_shape[list(axes)] cut_off = np.min(axes_shape) // 2 if taper_ratio is not None: taper_size = float(taper_ratio * np.mean(axes_shape)) vol_mask = circular_mask(img1_shape, coords_ball=axes, radius_offset=-taper_size, taper_func="cos") img1 = img1 * vol_mask img2 = img2 * vol_mask img1_f = np.fft.fftn(img1, axes=axes) img2_f = np.fft.fftn(img2, axes=axes) fc = img1_f * np.conj(img2_f) f1 = np.abs(img1_f) ** 2 f2 = np.abs(img2_f) ** 2 fc_r_int = azimuthal_integration(fc.real, axes=axes, domain="fourier") fc_i_int = azimuthal_integration(fc.imag, axes=axes, domain="fourier") fc_int = np.sqrt((fc_r_int**2) + (fc_i_int**2)) f1_int = azimuthal_integration(f1, axes=axes, domain="fourier") f2_int = azimuthal_integration(f2, axes=axes, domain="fourier") f1s_f2s = f1_int * f2_int f1s_f2s = f1s_f2s + (f1s_f2s == 0) f1s_f2s = np.sqrt(f1s_f2s) frc = fc_int / f1s_f2s if frc.ndim > 1: frc = frc.mean(axis=tuple(range(-frc.ndim, -1))) if theo_threshold: # The number of pixels in a ring is given by the surface. # We compute the n-dimensional hyper-sphere surface, where n is given by the number of axes. 
n = len(axes) num_surf = 2 * np.pi ** (n / 2) den_surf = gamma(n / 2) rings_size = np.concatenate(((1.0,), num_surf / den_surf * np.arange(1, len(frc)) ** (n - 1))) else: rings_size = azimuthal_integration(np.ones_like(img1), axes=axes, domain="fourier") t_num = snrt + (2 * np.sqrt(snrt) + 1) / np.sqrt(rings_size) t_den = snrt + 1 + 2 * np.sqrt(snrt) / np.sqrt(rings_size) t_hb = t_num / t_den if smooth is not None and smooth > 1: win = hann(smooth) win /= np.sum(win) win = win.reshape([*[1] * (frc.ndim - 1), -1]) frc = convolve(frc, win, mode="nearest") return frc[..., :cut_off], t_hb[..., :cut_off]
def estimate_resolution(frc: NDArray, t_hb: NDArray) -> tuple[float, float] | None:
    """Estimate the resolution or bandwidth, given an FRC and a threshold curve.

    Curves with more than one dimension are first averaged over all their leading axes.
    Each curve is reduced independently over its own leading axes, so both become 1-D
    along the last axis.

    Parameters
    ----------
    frc : NDArray
        The FRC curve.
    t_hb : NDArray
        The threshold curve.

    Returns
    -------
    tuple[float, float] | None
        The resolution or bandwidth, if a crossing point was found. Otherwise None.
    """
    frc = np.asarray(frc)
    t_hb = np.asarray(t_hb)

    # The two curves may have different ranks (e.g. a 1-D FRC with a threshold that still
    # carries extra leading axes). Reducing one with the other's axes would collapse a
    # 1-D FRC to a scalar, so each curve is reduced on its own.
    if frc.ndim > 1:
        frc = frc.mean(axis=tuple(range(frc.ndim - 1)))
    if t_hb.ndim > 1:
        t_hb = t_hb.mean(axis=tuple(range(t_hb.ndim - 1)))

    # The search starts from sample 1, skipping the DC component.
    return lines_intersection(frc, t_hb, x_lims=(1, None))


def plot_frcs(
    volume_pairs: Sequence[tuple[NDArray, NDArray]],
    labels: Sequence[str],
    title: str | None = None,
    smooth: int | None = 5,
    snrt: float = 0.2071,
    axes: Sequence[int] | None = None,
    taper_ratio: float | None = 0.05,
    supersampling: int = 1,
    verbose: bool = False,
) -> tuple[Figure, Axes]:
    """Compute and plot the FSCs / FRCs of some volumes.

    Parameters
    ----------
    volume_pairs : Sequence[tuple[NDArray, NDArray]]
        A list of pairs of volumes to compute the FRCs on.
    labels : Sequence[str]
        The labels associated with each pair.
    title : str | None, optional
        The axes title, by default None.
    smooth : int | None, optional
        The size of the smoothing window for the computed curves, by default 5.
    snrt : float, optional
        The SNR of the T curve, by default 0.2071 - as per half-dataset SNR.
    axes : Sequence[int] | None, optional
        The axes along which we want to compute the FRC. The unused axes will be averaged.
        The default is None.
    taper_ratio : float | None, optional
        Ratio of the edge pixels to be tapered off. This is necessary when working with
        truncated volumes / local tomography, to avoid truncation artifacts.
        The default is 0.05.
    supersampling : int, optional
        Supersampling factor of the images.
        Larger values increase the high-frequency range of the FRC/FSC function,
        but might also suffer from aliasing.
        The default is 1, which corresponds to the Nyquist frequency.
    verbose : bool, optional
        Whether to display verbose output, by default False.

    Raises
    ------
    ValueError
        When `volume_pairs` is empty.

    Returns
    -------
    tuple[Figure, Axes]
        The figure and axes objects.

    Notes
    -----
    A single threshold curve is plotted: the one computed for the last pair.
    """
    if len(volume_pairs) == 0:
        raise ValueError("At least one pair of volumes is needed to plot FRCs.")

    frcs: list[NDArray] = []
    xps: list[tuple[float, float] | None] = []
    for pair in tqdm(volume_pairs, desc="Computing FRCs", disable=not verbose):
        frc, t_hb = compute_frc(
            pair[0], pair[1], snrt=snrt, smooth=smooth, axes=axes, supersampling=supersampling, taper_ratio=taper_ratio
        )
        frcs.append(frc)
        xps.append(estimate_resolution(frc, t_hb))

    # Samples span [0, supersampling] in units of the Nyquist frequency.
    num_samples = len(frcs[0])
    nyquist = (num_samples - 1) / supersampling
    xx = np.linspace(0, supersampling, num_samples)

    fig, axs = plt.subplots(1, 1, sharex=True, sharey=True)
    for frc, label in zip(frcs, labels):
        axs.plot(xx, np.squeeze(frc), label=label)
    axs.plot(xx, np.squeeze(t_hb), label="T 1/2 bit", linestyle="dashed")
    for ii, p in enumerate(xps):
        if p is not None:
            # Convert the crossing position from sample index to a fraction of Nyquist.
            bndw = p[0] / nyquist
            axs.stem(bndw, p[1], label=f"{labels[ii]}, bandwidth: {bndw:.3}$f_N$", linefmt=f"C{ii}-.", markerfmt=f"C{ii}o")
    axs.axvline(1.0, linestyle="-.", color="k", alpha=0.5)
    axs.set_xlim(0, supersampling)
    axs.set_ylim(0, None)
    axs.legend(fontsize=12)
    axs.grid()
    axs.set_ylabel("Magnitude", fontdict=dict(fontsize=16))
    axs.set_xlabel("Spatial frequency / Nyquist ($f_N$)", fontdict=dict(fontsize=16))
    if title is not None:
        axs.set_title(title)
    for tl in axs.get_xticklabels():
        tl.set_fontsize(13)
    for tl in axs.get_yticklabels():
        tl.set_fontsize(13)
    fig.tight_layout()

    plt.show(block=False)
    return fig, axs


def fit_scale_bias(img_data: NDArray, prj_data: NDArray, prj: BaseTransform | None = None) -> tuple[float, float]:
    """Fit the scale and bias of an image, against its projection in a different space.

    Finds the scale ``a`` and bias ``b`` that minimize
    ``0.5 * sum(|prj(a * img_data + b) - prj_data| ** 2)``, where the sum runs over all
    elements of the projected data, whatever its number of dimensions.

    Parameters
    ----------
    img_data : NDArray
        The image data.
    prj_data : NDArray
        The projected data.
    prj : BaseTransform | None, optional
        The projection operator. The default is None, which uses the identity (TransformIdentity).

    Returns
    -------
    tuple[float, float]
        The scale and bias.
    """
    if prj is None:
        prj = TransformIdentity(img_data.shape)

    prj_x = prj(img_data)
    prj_1 = prj(np.ones_like(img_data))

    # Scalar products used by the analytic gradient. The gradient assumes a linear `prj`,
    # i.e. prj(a * x + b) == a * prj(x) + b * prj(1).
    m_y_dot_prj_x = -float(np.sum(prj_data * prj_x))
    m_y_dot_prj_1 = -float(np.sum(prj_data * prj_1))
    prj_x_2 = float(np.sum(prj_x**2))
    prj_1_2 = float(np.sum(prj_1**2))
    prj_1_dot_prj_x = float(np.sum(prj_1 * prj_x))

    def obj_func(ab: NDArray) -> tuple[float, NDArray]:
        residual = prj(img_data * ab[0] + ab[1]) - prj_data
        grad_a = m_y_dot_prj_x + prj_x_2 * ab[0] + prj_1_dot_prj_x * ab[1]
        grad_b = m_y_dot_prj_1 + prj_1_2 * ab[1] + prj_1_dot_prj_x * ab[0]
        # Squared Euclidean norm over ALL elements, consistent with the gradient above.
        # (`np.linalg.norm(residual, ord=2)` is the spectral norm for 2-D arrays and raises for >2-D.)
        cost = float(np.sum(np.abs(residual) ** 2)) / 2
        return cost, np.array((grad_a, grad_b))

    opt_res = minimize(obj_func, [1.0, 0.0], jac=True)
    return float(opt_res.x[0]), float(opt_res.x[1])