Source code for corrct.processing.post

"""
Post-processing routines.

Created on Tue Mar 24 15:25:14 2020

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Sequence
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from numpy.typing import ArrayLike, NDArray
from scipy.optimize import minimize

from tqdm.auto import tqdm

from corrct.operators import BaseTransform, TransformIdentity
from corrct.processing.misc import azimuthal_integration, circular_mask, lines_intersection


eps = np.finfo(np.float32).eps


def com(vol: NDArray, axes: Optional[ArrayLike] = None) -> NDArray:
    """
    Compute center-of-mass for given volume.

    Parameters
    ----------
    vol : NDArray
        The input volume.
    axes : ArrayLike, optional
        Axes on which to compute center-of-mass. The default is None.

    Returns
    -------
    NDArray
        The center-of-mass position.
    """
    if axes is None:
        axes = np.arange(len(vol.shape))
    else:
        axes = np.array(axes, ndmin=1)

    coords = [np.linspace(-(s - 1) / 2, (s - 1) / 2, s) for s in np.array(vol.shape)[list(axes)]]

    num_dims = len(vol.shape)
    center_of_mass = np.empty((len(axes),))
    for ii, a in enumerate(axes):
        sum_axes = np.array(np.delete(np.arange(num_dims), a), ndmin=1, dtype=int)
        line = np.abs(vol).sum(axis=tuple(sum_axes))
        center_of_mass[ii] = line.dot(coords[ii]) / line.sum()

    return center_of_mass
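
# Usage sketch (illustrative addition, not part of the original module): a
# minimal check of `com` on a synthetic volume. The helper name `_demo_com`
# and the input values are assumptions chosen for demonstration.
def _demo_com() -> None:
    # A single bright voxel at index (3, 7) of an 11x11 volume: with centered
    # coordinates spanning [-5, 5] on each axis, the center-of-mass is (-2, 2).
    vol = np.zeros((11, 11))
    vol[3, 7] = 1.0
    print(com(vol))  # approximately [-2.  2.]
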
def power_spectrum(
    img: NDArray,
    axes: Optional[Sequence[int]] = None,
    smooth: Optional[int] = 5,
    taper_ratio: Optional[float] = 0.05,
    power: int = 2,
) -> NDArray:
    """
    Compute the power spectrum of an n-dimensional signal.

    Parameters
    ----------
    img : NDArray
        The n-dimensional signal.
    axes : Optional[Sequence[int]], optional
        The axes over which we want to compute the power spectrum, by default None.
    smooth : Optional[int], optional
        The smoothing kernel size, by default 5.
    taper_ratio : Optional[float], optional
        The ratio of the signal edges to taper off (for truncated signals), by default 0.05.
    power : int, optional
        The exponent to use, by default 2.

    Returns
    -------
    NDArray
        The power spectrum.
    """
    img_shape = np.array(img.shape)
    if axes is None:
        axes = list(np.arange(-len(img_shape), 0))
    axes_shape = img_shape[list(axes)]
    cut_off = np.min(axes_shape) // 2

    if taper_ratio is not None:
        taper_size = float(taper_ratio * np.mean(axes_shape))
        vol_mask = circular_mask(img_shape, coords_ball=axes, radius_offset=-taper_size, taper_func="cos")
        img = img * vol_mask

    img_f = np.fft.fftn(img, axes=axes)

    f1 = np.abs(img_f) ** power
    f1_int = azimuthal_integration(f1, axes=axes, domain="fourier")

    rings_size = azimuthal_integration(np.ones_like(img), axes=axes, domain="fourier")
    ps = f1_int / rings_size

    dc_val = np.sqrt(np.min(axes_shape) ** len(axes_shape)) ** power
    ps /= dc_val

    if smooth is not None and smooth > 1:
        win = sp.signal.windows.hann(smooth)
        win /= np.sum(win)
        win = win.reshape([*[1] * (ps.ndim - 1), -1])
        ps = sp.ndimage.convolve(ps, win, mode="nearest")

    return ps[..., :cut_off]
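
# Usage sketch (illustrative addition, not part of the original module): the
# power spectrum of white noise is roughly flat, which makes it a convenient
# sanity check. The helper name `_demo_power_spectrum` and the values below
# are demo assumptions.
def _demo_power_spectrum() -> None:
    rng = np.random.default_rng(0)
    img = rng.standard_normal((64, 64))
    ps = power_spectrum(img)
    # One value per radial frequency bin, up to the Nyquist cut-off (32 here).
    print(ps.shape, ps.mean())
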
def compute_frc(
    img1: NDArray,
    img2: Optional[NDArray],
    snrt: float = 0.2071,
    axes: Optional[Sequence[int]] = None,
    smooth: Optional[int] = 5,
    taper_ratio: Optional[float] = 0.05,
    supersampling: int = 1,
    theo_threshold: bool = True,
) -> tuple[NDArray, NDArray]:
    """
    Compute the FRC/FSC (Fourier ring/shell correlation) between two images / volumes.

    Please refer to the following article for more information:
    M. van Heel and M. Schatz, "Fourier shell correlation threshold criteria,"
    J. Struct. Biol., vol. 151, no. 3, pp. 250-262, Sep. 2005.

    Parameters
    ----------
    img1 : NDArray
        First image / volume.
    img2 : NDArray
        Second image / volume.
    snrt : float, optional
        SNR to be used for generating the threshold curve for resolution definition.
        The SNR value of 0.4142 corresponds to the half-bit curve for a full dataset.
        When splitting datasets in two sub-datasets, that value needs to be halved.
        The default is 0.2071, which corresponds to the half-bit threshold for half dataset.
    axes : Sequence[int], optional
        The axes over which we want to compute the FRC/FSC.
        If None, all axes will be used. The default is None.
    smooth : Optional[int], optional
        Size of the Hann smoothing window. The default is 5.
    taper_ratio : Optional[float], optional
        Ratio of the edge pixels to be tapered off. This is necessary when working
        with truncated volumes / local tomography, to avoid truncation artifacts.
        The default is 0.05.
    supersampling : int, optional
        Supersampling factor of the images. Larger values increase the high-frequency
        range of the FRC/FSC function. The default is 1, which corresponds to the Nyquist frequency.
    theo_threshold : bool, optional
        Whether to use the theoretical number of pixels per ring (hyper-sphere surface)
        for the threshold curve, instead of the azimuthally integrated ring sizes.
        The default is True.

    Raises
    ------
    ValueError
        Error returned when not passing images of the same shape.

    Returns
    -------
    NDArray
        The computed FRC/FSC.
    NDArray
        The threshold curve corresponding to the given threshold SNR.
    """
    img1_shape = np.array(img1.shape)

    if axes is None:
        axes = list(np.arange(-len(img1_shape), 0))

    if img2 is None:
        if np.any(img1_shape[axes] % 2 == 1):
            raise ValueError(f"Image shape {img1_shape} along the chosen axes {axes} needs to be even.")
        raise NotImplementedError("Self FRC not implemented, yet.")
    else:
        img2_shape = np.array(img2.shape)
        if len(img1_shape) != len(img2_shape) or np.any(img1_shape != img2_shape):
            raise ValueError(
                f"Image #1 size {img1_shape} and image #2 size {img2_shape} are different, while they should be equal."
            )

    if img1.dtype != img2.dtype:
        print(f"WARNING: The two images have different dtype: img1 {img1.dtype}, img2 {img2.dtype}. Forcing the first.")
        img2 = img2.astype(img1.dtype)
    dtype = img1.dtype

    if supersampling > 1:
        # Bodge to make interpolation work with recent scipy: the cython implementation does not compile for float32
        dtype = float
        img1 = img1.astype(dtype)
        img2 = img2.astype(dtype)

        base_grid = [np.linspace(-(d - 1) / 2, (d - 1) / 2, d, dtype=dtype) for d in img1_shape]
        interp_grid = [np.linspace(-(d - 1) / 2, (d - 1) / 2, d, dtype=dtype) for d in img1_shape]
        for a in axes:
            d = img1_shape[a] * 2
            interp_grid[a] = np.linspace(-(d - 1) / 4, (d - 1) / 4, d, dtype=dtype)
        interp_grid = np.meshgrid(*interp_grid, indexing="ij")
        interp_grid = np.transpose(interp_grid, [*range(1, len(img1_shape) + 1), 0])

        img1 = sp.interpolate.interpn(base_grid, img1, interp_grid, bounds_error=False, fill_value=None)
        img2 = sp.interpolate.interpn(base_grid, img2, interp_grid, bounds_error=False, fill_value=None)

        img1_shape = np.array(img1.shape)

    axes_shape = img1_shape[list(axes)]
    cut_off = np.min(axes_shape) // 2

    if taper_ratio is not None:
        taper_size = float(taper_ratio * np.mean(axes_shape))
        vol_mask = circular_mask(img1_shape, coords_ball=axes, radius_offset=-taper_size, taper_func="cos")
        img1 = img1 * vol_mask
        img2 = img2 * vol_mask

    img1_f = np.fft.fftn(img1, axes=axes)
    img2_f = np.fft.fftn(img2, axes=axes)

    fc = img1_f * np.conj(img2_f)
    f1 = np.abs(img1_f) ** 2
    f2 = np.abs(img2_f) ** 2

    fc_r_int = azimuthal_integration(fc.real, axes=axes, domain="fourier")
    fc_i_int = azimuthal_integration(fc.imag, axes=axes, domain="fourier")
    fc_int = np.sqrt((fc_r_int**2) + (fc_i_int**2))

    f1_int = azimuthal_integration(f1, axes=axes, domain="fourier")
    f2_int = azimuthal_integration(f2, axes=axes, domain="fourier")

    f1s_f2s = f1_int * f2_int
    f1s_f2s = f1s_f2s + (f1s_f2s == 0)
    f1s_f2s = np.sqrt(f1s_f2s)

    frc = fc_int / f1s_f2s

    if theo_threshold:
        # The number of pixels in a ring is given by the surface.
        # We compute the n-dimensional hyper-sphere surface, where n is given by the number of axes.
        n = len(axes)
        num_surf = 2 * np.pi ** (n / 2)
        den_surf = sp.special.gamma(n / 2)
        rings_size = np.concatenate(((1.0,), num_surf / den_surf * np.arange(1, len(frc)) ** (n - 1)))
    else:
        rings_size = azimuthal_integration(np.ones_like(img1), axes=axes, domain="fourier")

    t_num = snrt + (2 * np.sqrt(snrt) + 1) / np.sqrt(rings_size)
    t_den = snrt + 1 + 2 * np.sqrt(snrt) / np.sqrt(rings_size)
    t_hb = t_num / t_den

    if smooth is not None and smooth > 1:
        win = sp.signal.windows.hann(smooth)
        win /= np.sum(win)
        win = win.reshape([*[1] * (frc.ndim - 1), -1])
        frc = sp.ndimage.convolve(frc, win, mode="nearest")

    return frc[..., :cut_off], t_hb[..., :cut_off]
def estimate_resolution(frc: NDArray, t_hb: NDArray) -> Optional[tuple[float, float]]:
    """Estimate the resolution or bandwidth, given an FRC and a threshold curve.

    Parameters
    ----------
    frc : NDArray
        The FRC curve.
    t_hb : NDArray
        The threshold curve.

    Returns
    -------
    tuple[float, float] | None
        The resolution or bandwidth, if a crossing point was found. Otherwise None.
    """
    if t_hb.ndim > 1:
        reduce_axes = tuple(np.arange(t_hb.ndim - 1))
        frc = frc.mean(axis=reduce_axes)
        t_hb = t_hb.mean(axis=reduce_axes)
    return lines_intersection(frc, t_hb, x_lims=(1, None))
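
# Usage sketch (illustrative addition, not part of the original module): with a
# monotonically decreasing FRC and a flat threshold, the two curves cross
# exactly once. The helper name and curve values are demo assumptions.
def _demo_estimate_resolution() -> None:
    frc = np.linspace(1.0, 0.0, 64)
    t_hb = np.full(64, 0.2071)
    # Expected crossing near bin 50, where 1 - x / 63 == 0.2071.
    print(estimate_resolution(frc, t_hb))
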
def plot_frcs(
    volume_pairs: Sequence[tuple[NDArray, NDArray]],
    labels: Sequence[str],
    title: Optional[str] = None,
    smooth: Optional[int] = 5,
    snrt: float = 0.2071,
    axes: Optional[Sequence[int]] = None,
    supersampling: int = 1,
    verbose: bool = False,
) -> tuple[Figure, Axes]:
    """Compute and plot the FSCs / FRCs of some volumes.

    Parameters
    ----------
    volume_pairs : Sequence[tuple[NDArray, NDArray]]
        A list of pairs of volumes to compute the FRCs on.
    labels : Sequence[str]
        The labels associated with each pair.
    title : Optional[str], optional
        The axes title, by default None.
    smooth : Optional[int], optional
        The size of the smoothing window for the computed curves, by default 5.
    snrt : float, optional
        The SNR of the T curve, by default 0.2071 - as per half-dataset SNR.
    axes : Sequence[int] | None, optional
        The axes along which we want to compute the FRC. The unused axes will be averaged.
        The default is None.
    supersampling : int, optional
        Supersampling factor of the images, by default 1.
    verbose : bool, optional
        Whether to display verbose output, by default False.

    Returns
    -------
    tuple[Figure, Axes]
        The figure and axes of the plot.
    """
    frcs = [np.array([])] * len(volume_pairs)
    xps: list[Optional[tuple[float, float]]] = [(0.0, 0.0)] * len(volume_pairs)

    for ii, pair in enumerate(tqdm(volume_pairs, desc="Computing FRCs", disable=not verbose)):
        frcs[ii], t_hb = compute_frc(pair[0], pair[1], snrt=snrt, smooth=smooth, axes=axes, supersampling=supersampling)
        xps[ii] = estimate_resolution(frcs[ii], t_hb)

    nyquist = len(frcs[0])
    xx = np.linspace(0, 1, nyquist)

    fig, axs = plt.subplots(1, 1, sharex=True, sharey=True)
    for f, l in zip(frcs, labels):
        axs.plot(xx, np.squeeze(f), label=l)
    axs.plot(xx, np.squeeze(t_hb), label="T 1/2 bit", linestyle="dashed")
    for ii, p in enumerate(xps):
        if p is not None:
            res = p[0] / (nyquist - 1)
            axs.stem(res, p[1], label=f"Resolution ({labels[ii]}): {res:.3}", linefmt=f"C{ii}-.", markerfmt=f"C{ii}o")
    axs.set_xlim(0, 1)
    axs.set_ylim(0, None)
    axs.legend(fontsize=12)
    axs.grid()
    axs.set_ylabel("Magnitude", fontdict=dict(fontsize=16))
    axs.set_xlabel("Spatial frequency / Nyquist", fontdict=dict(fontsize=16))
    if title is not None:
        axs.set_title(title)
    for tl in axs.get_xticklabels():
        tl.set_fontsize(13)
    for tl in axs.get_yticklabels():
        tl.set_fontsize(13)
    fig.tight_layout()
    plt.show(block=False)
    return fig, axs
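
# Usage sketch (illustrative addition, not part of the original module):
# plotting the FRC of a single pair of noisy realizations of the same signal.
# The helper name and inputs are demo assumptions.
def _demo_plot_frcs() -> None:
    rng = np.random.default_rng(0)
    signal = sp.ndimage.gaussian_filter(rng.standard_normal((128, 128)), sigma=2.0)
    pair = (signal + 0.05 * rng.standard_normal(signal.shape), signal + 0.05 * rng.standard_normal(signal.shape))
    fig, ax = plot_frcs([pair], labels=["noisy pair"], title="FRC demo")
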
def fit_scale_bias(img_data: NDArray, prj_data: NDArray, prj: Optional[BaseTransform] = None) -> tuple[float, float]:
    """Fit the scale and bias of an image, against its projection in a different space.

    Parameters
    ----------
    img_data : NDArray
        The image data.
    prj_data : NDArray
        The projected data.
    prj : BaseTransform | None, optional
        The projection operator. The default is None, which uses the identity (TransformIdentity).

    Returns
    -------
    tuple[float, float]
        The scale and bias.
    """
    if prj is None:
        prj = TransformIdentity(img_data.shape)

    prj_x = prj(img_data)
    prj_1 = prj(np.ones_like(img_data))

    m_y_dot_prj_x = -float(np.sum(prj_data * prj_x))
    m_y_dot_prj_1 = -float(np.sum(prj_data * prj_1))
    prj_x_2 = float(np.sum(prj_x**2))
    prj_1_2 = float(np.sum(prj_1**2))
    prj_1_dot_prj_x = float(np.sum(prj_1 * prj_x))

    def obj_func(ab: NDArray) -> tuple[float, NDArray]:
        residual = prj(img_data * ab[0] + ab[1]) - prj_data
        grad_a = m_y_dot_prj_x + prj_x_2 * ab[0] + prj_1_dot_prj_x * ab[1]
        grad_b = m_y_dot_prj_1 + prj_1_2 * ab[1] + prj_1_dot_prj_x * ab[0]
        return float(np.linalg.norm(residual, ord=2) ** 2) / 2, np.array((grad_a, grad_b))

    opt_res = minimize(obj_func, [1.0, 0.0], jac=True)
    return float(opt_res.x[0]), float(opt_res.x[1])
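
# Usage sketch (illustrative addition, not part of the original module): with
# the default identity projector, the fitted scale and bias should recover the
# values used to synthesize the data. The helper name and values are demo
# assumptions.
def _demo_fit_scale_bias() -> None:
    rng = np.random.default_rng(0)
    img = rng.standard_normal((32, 32))
    prj_data = 2.5 * img + 0.3  # known ground-truth scale and bias
    scale, bias = fit_scale_bias(img, prj_data)
    print(scale, bias)  # approximately 2.5 and 0.3
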