Source code for corrct.operators

#!/usr/bin/env python3
"""
Operators module.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

import copy as cp
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Callable, Optional, Union
import numpy as np
import scipy.signal as spsig
from numpy.typing import ArrayLike, NDArray
from scipy.sparse.linalg import LinearOperator


def _version_tuple(version: str) -> tuple[int, ...]:
    """Convert a dotted version string into a tuple of integers.

    Only the leading digits of the first three dot-separated fields are used,
    so suffixes such as ``rc1`` or ``.dev0+abc`` are ignored.

    Parameters
    ----------
    version : str
        Version string, e.g. ``"1.0.10"``.

    Returns
    -------
    tuple[int, ...]
        The numeric components, e.g. ``(1, 0, 10)``.
    """
    numbers = []
    for token in version.split(".")[:3]:
        digits = ""
        for char in token:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


try:
    import pywt

    has_pywt = True
    # Compare numerically: a plain string comparison would rank "1.0.10" below "1.0.2".
    use_swtn = _version_tuple(pywt.version.version) >= (1, 0, 2)  # type: ignore
    if not use_swtn:
        print("WARNING - pywavelets version is too old (<1.0.2)")
except ImportError:
    has_pywt = False
    use_swtn = False
    print("WARNING - pywt was not found")


NDArrayInt = NDArray[np.integer]


class BaseTransform(LinearOperator, ABC):
    """Base operator class.

    It implements the linear operator behavior that can be used with the solvers in the
    `.solvers` module, and by the solvers in `scipy.sparse.linalg`.

    Attributes
    ----------
    dir_shape : NDArrayInt
        Shape of the direct space.
    adj_shape : NDArrayInt
        Shape of the adjoint space.
    is_dir_operator : bool
        Flag indicating if the operator is a direct operator.

    Notes
    -----
    It assumes that the fields `dir_shape` and `adj_shape` have been set during the
    initialization of the derived classes, before `BaseTransform.__init__` is called.
    """

    dir_shape: "NDArrayInt"
    adj_shape: "NDArrayInt"

    def __init__(self) -> None:
        """Initialize the base operator class."""
        num_cols = int(np.prod(self.dir_shape))
        num_rows = int(np.prod(self.adj_shape))
        super().__init__(np.float32, (num_rows, num_cols))
        self.is_dir_operator = True

    def _matvec(self, x: NDArray) -> NDArray:
        """Implement the direct operator for column vectors from the right.

        Parameters
        ----------
        x : NDArray
            Either row from the left or column from the right.

        Returns
        -------
        NDArray
            Flattened result of applying the operator.
        """
        if self.is_dir_operator:
            x = x.reshape(self.dir_shape)
            return self._op_direct(x).flatten()
        else:
            x = x.reshape(self.adj_shape)
            return self._op_adjoint(x).flatten()

    def rmatvec(self, x: NDArray) -> NDArray:
        """Implement the direct operator for row vectors from the left.

        Parameters
        ----------
        x : NDArray
            Either row from the left or column from the right on transpose.

        Returns
        -------
        NDArray
            Flattened result of applying the operator for row vectors.
        """
        if self.is_dir_operator:
            x = x.reshape(self.adj_shape)
            return self._op_adjoint(x).flatten()
        else:
            x = x.reshape(self.dir_shape)
            return self._op_direct(x).flatten()

    def _transpose(self) -> "BaseTransform":
        """Create the transpose operator.

        Returns
        -------
        BaseTransform
            A shallow copy of this operator with swapped shape and flipped direction.
        """
        transposed = cp.copy(self)
        transposed.shape = (self.shape[1], self.shape[0])
        # Toggle instead of forcing False, so that transposing twice gives back
        # an operator that applies the direct transform again.
        transposed.is_dir_operator = not self.is_dir_operator
        return transposed

    def _adjoint(self) -> "BaseTransform":
        """Create the adjoint operator.

        Returns
        -------
        BaseTransform
            The adjoint operator (same as the transpose).
        """
        return self._transpose()

    def absolute(self) -> "BaseTransform":
        """Return the absolute value of the operator.

        Returns
        -------
        BaseTransform
            The absolute value operator. The base implementation returns `self`.
        """
        return self

    def explicit(self) -> NDArray:
        """Return the explicit transformation matrix associated with the operator.

        The matrix is assembled column by column, by applying the operator to
        each vector of the canonical basis of its input space.

        Returns
        -------
        NDArray
            The explicit transformation matrix, as a NumPy array.
        """
        input_shape = self.dir_shape if self.is_dir_operator else self.adj_shape
        dim_size = int(np.prod(input_shape))

        matrix = np.empty(self.shape, dtype=self.dtype)
        for ii in range(dim_size):
            basis_vec = np.zeros((dim_size,))
            basis_vec[ii] = 1
            matrix[:, ii] = self * basis_vec
        return matrix

    def __call__(self, x: NDArray) -> NDArray:
        """Apply the operator to the input vector `x`.

        Unlike `matvec`, the input is neither reshaped nor is the output flattened.

        Parameters
        ----------
        x : NDArray
            Input vector.

        Returns
        -------
        result : NDArray
            The result of applying the operator to `x`.
        """
        if self.is_dir_operator:
            return self._op_direct(x)
        else:
            return self._op_adjoint(x)

    @abstractmethod
    def _op_direct(self, x: NDArray) -> NDArray:
        """Apply the operator to the data.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """

    @abstractmethod
    def _op_adjoint(self, x: NDArray) -> NDArray:
        """Apply the adjoint operator to the data.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
class TransformFunctions(BaseTransform):
    """Operator whose direct and adjoint actions are delegated to callables."""

    def __init__(
        self,
        dir_shape: Union[ArrayLike, NDArray],
        adj_shape: Union[ArrayLike, NDArray],
        A: Callable[[NDArray], NDArray],
        At: Optional[Callable[[NDArray], NDArray]] = None,
    ) -> None:
        """Set up the callable-based transform.

        When no adjoint callable is supplied, `A` is used for the adjoint as well,
        i.e. the function is considered symmetric.

        Parameters
        ----------
        dir_shape : ArrayLike
            Shape of the direct space.
        adj_shape : ArrayLike
            Shape of the adjoint space.
        A : Callable[[NDArray], NDArray]
            The transform function.
        At : Optional[Callable[[NDArray], NDArray]], optional
            The adjoint transform function, by default None
        """
        self.A = A
        self.At = At

        self.dir_shape = np.array(dir_shape, ndmin=1, dtype=int)
        self.adj_shape = np.array(adj_shape, ndmin=1, dtype=int)

        super().__init__()

    def _op_direct(self, x: NDArray) -> NDArray:
        """Forward the data to the direct callable.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
        return self.A(x)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        """Forward the data to the adjoint callable, or to `A` when none was given.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
        adjoint_function = self.A if self.At is None else self.At
        return adjoint_function(x)

    def absolute(self) -> "TransformFunctions":
        """Refuse to compute the absolute value of the operator.

        Raises
        ------
        AttributeError
            Always, because the operation is not supported for callables.
        """
        raise AttributeError("Callable transform class does not support computing its absolute value.")
class ProjectorOperator(BaseTransform):
    """Base projector class that fixes the projection interface.

    The direct operator is the forward-projection `fp`, and the adjoint operator
    is the back-projection `bp`.
    """

    @property
    def vol_shape(self) -> NDArrayInt:
        """Expose the direct space shape as volume shape.

        Returns
        -------
        NDArray
            The volume shape.
        """
        return self.dir_shape

    @vol_shape.setter
    def vol_shape(self, new_shape: Union[Sequence[int], NDArray]) -> None:
        # Stored as an at-least-1D integer array, like the other shape fields.
        self.dir_shape = np.array(new_shape, ndmin=1, dtype=int)

    @property
    def prj_shape(self) -> NDArrayInt:
        """Expose the adjoint space shape as projection shape.

        Returns
        -------
        NDArray
            The projection shape.
        """
        return self.adj_shape

    @prj_shape.setter
    def prj_shape(self, new_shape: Union[Sequence[int], NDArray]) -> None:
        # Stored as an at-least-1D integer array, like the other shape fields.
        self.adj_shape = np.array(new_shape, ndmin=1, dtype=int)

    @abstractmethod
    def fp(self, x: NDArray) -> NDArray:
        """Define the interface of the forward-projection.

        Parameters
        ----------
        x : NDArray
            Input volume.

        Returns
        -------
        NDArray
            The projection data.
        """

    @abstractmethod
    def bp(self, x: NDArray) -> NDArray:
        """Define the interface for the back-projection.

        Parameters
        ----------
        x : NDArray
            Input projection data.

        Returns
        -------
        NDArray
            The back-projected volume.
        """

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.fp(x)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        return self.bp(x)

    def get_pre_weights(self) -> Union[NDArray, None]:
        """Compute the pre-weights of the projector geometry (notably for cone-beam geometries).

        Returns
        -------
        Union[NDArray, None]
            The computed detector weights. The base implementation returns None.
        """
        return None
class TransformIdentity(BaseTransform):
    """Identity operator: both the direct and the adjoint return their input."""

    def __init__(self, x_shape: Union[ArrayLike, NDArray]):
        """Create the identity operator.

        Parameters
        ----------
        x_shape : ArrayLike | NDArray
            Shape of the data. It is used for both the direct and adjoint spaces.
        """
        self.dir_shape, self.adj_shape = (np.array(x_shape, ndmin=1, dtype=int) for _ in range(2))
        super().__init__()

    def _op_direct(self, x: NDArray) -> NDArray:
        return x

    def _op_adjoint(self, x: NDArray) -> NDArray:
        return x
class TransformDiagonalScaling(BaseTransform):
    """Diagonal scaling operator: element-wise multiplication by `scale`."""

    scale: NDArray

    def __init__(self, x_shape: Union[ArrayLike, NDArray], scale: Union[ArrayLike, NDArray]):
        """Create the diagonal scaling operator.

        Parameters
        ----------
        x_shape : ArrayLike
            Shape of the data. It is used for both the direct and adjoint spaces.
        scale : float or ArrayLike
            Operator diagonal.
        """
        self.scale = np.array(scale)

        self.dir_shape = np.array(x_shape, ndmin=1, dtype=int)
        self.adj_shape = np.array(x_shape, ndmin=1, dtype=int)

        super().__init__()

    def absolute(self) -> "TransformDiagonalScaling":
        """Return the scaling operator that uses the absolute value of the coefficients.

        Returns
        -------
        TransformDiagonalScaling
            The absolute value operator
        """
        abs_scale = np.abs(self.scale)
        return TransformDiagonalScaling(self.dir_shape, abs_scale)

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.scale * x

    def _op_adjoint(self, x: NDArray) -> NDArray:
        # The operator is diagonal, so the adjoint applies the same scaling.
        return self.scale * x
class TransformConvolution(BaseTransform):
    """
    Convolution operator.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the direct space.
    kernel : ArrayLike
        The convolution kernel.
    pad_mode: str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    is_symm : bool, optional
        Whether the operator is symmetric or not. The default is True.
        When False, the adjoint returns its input unchanged.
    flip_adjoint : bool, optional
        Whether the adjoint kernel should be flipped. The default is False.
        This is useful when the kernel is not symmetric.
    """

    kernel: NDArray
    pad_mode: str
    is_symm: bool
    flip_adjoint: bool

    def __init__(
        self, x_shape: ArrayLike, kernel: ArrayLike, pad_mode: str = "edge", is_symm: bool = True, flip_adjoint: bool = False
    ):
        self.dir_shape = np.array(x_shape, ndmin=1, dtype=int)
        self.adj_shape = np.array(x_shape, ndmin=1, dtype=int)

        # The kernel gets at least as many dimensions as the data.
        self.kernel = np.array(kernel, ndmin=len(self.dir_shape))
        self.pad_mode = pad_mode.lower()
        self.is_symm = is_symm
        self.flip_adjoint = flip_adjoint

        super().__init__()

    def absolute(self) -> "TransformConvolution":
        """
        Return the convolution operator using the absolute value of the kernel coefficients.

        The padding mode, symmetry flag and adjoint flipping flag are carried
        over to the new operator.

        Returns
        -------
        TransformConvolution
            The absolute value of the convolution operator.
        """
        return TransformConvolution(
            self.dir_shape,
            np.abs(self.kernel),
            pad_mode=self.pad_mode,
            is_symm=self.is_symm,
            flip_adjoint=self.flip_adjoint,
        )

    def _pad_valid(self, x: NDArray) -> tuple[NDArray, NDArray]:
        """Pad the data by half the kernel extent on both sides of each axis.

        Returns
        -------
        tuple[NDArray, NDArray]
            The padded data, and the per-axis padding width that was applied.
        """
        pad_width = (np.array(self.kernel.shape) - 1) // 2
        # Column shape (ndim, 1) broadcasts to the same width before and after.
        return np.pad(x, pad_width=pad_width[:, None], mode=self.pad_mode), pad_width  # type: ignore

    def _crop_valid(self, x: NDArray, pad_width: NDArray) -> NDArray:
        """Remove the padding added by `_pad_valid`."""
        # A zero width must map to an open slice: `-0` would select nothing.
        slices = [slice(pw if pw else None, -pw if pw else None) for pw in pad_width]
        return x[tuple(slices)]

    def _op_direct(self, x: NDArray) -> NDArray:
        x, pw = self._pad_valid(x)
        x = spsig.convolve(x, self.kernel, mode="same")
        return self._crop_valid(x, pw)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        if not self.is_symm:
            # NOTE(review): the non-symmetric case is a pass-through, not a true
            # adjoint - behavior kept as is, confirm that this is intended.
            return x

        x, pw = self._pad_valid(x)
        adj_kernel = np.flip(self.kernel) if self.flip_adjoint else self.kernel
        x = spsig.convolve(x, adj_kernel, mode="same")
        return self._crop_valid(x, pw)
class BaseWaveletTransform(BaseTransform, ABC):
    """Base Wavelet transform, holding the filter-bank setup shared by the wavelet operators."""

    axes: NDArrayInt
    wavelet: str
    labels: list[str]

    def _initialize_filter_bank(self) -> None:
        """Build the detail sub-band labels and the per-label filter multipliers.

        It expects `self.axes` and `self.wavelet` to be already set. Labels are
        strings of "a" and "d", one character per axis, and exclude the
        all-"a" combination.
        """
        num_axes = len(self.axes)
        bits_to_bands = str.maketrans("01", "ad")
        self.labels = [format(ii, "b").zfill(num_axes).translate(bits_to_bands) for ii in range(1, 2**num_axes)]

        self.w = pywt.Wavelet(self.wavelet)  # type: ignore
        # l1-norm of each of the four filters in the bank, indexed 0 to 3.
        filt_bank_l1norm = np.linalg.norm(self.w.filter_bank, ord=1, axis=-1)

        def _band_multipliers(norm_a: float, norm_d: float) -> NDArray:
            # One factor per axis: `norm_a` for each "a", `norm_d` for each "d".
            return np.array([(norm_a ** lab.count("a")) * (norm_d ** lab.count("d")) for lab in self.labels])

        self.wlet_dec_filter_mult = _band_multipliers(filt_bank_l1norm[0], filt_bank_l1norm[1])
        self.wlet_rec_filter_mult = _band_multipliers(filt_bank_l1norm[2], filt_bank_l1norm[3])
class TransformDecimatedWavelet(BaseWaveletTransform):
    """Decimated wavelet Transform operator."""

    # `numpy.pad` mode names that PyWavelets spells differently. Only names that
    # are not PyWavelets modes themselves are translated, so every mode that was
    # accepted before keeps its exact meaning.
    _NUMPY_TO_PYWT_MODES = {"edge": "constant", "wrap": "periodic"}

    def __init__(
        self,
        x_shape: Union[ArrayLike, NDArray],
        wavelet: str,
        level: int,
        axes: Optional[ArrayLike] = None,
        pad_on_demand: str = "edge",
    ):
        """
        Decimated wavelet Transform operator.

        Parameters
        ----------
        x_shape : ArrayLike
            Shape of the data to be wavelet transformed.
        wavelet : str
            Wavelet type.
        level : int
            Number of wavelet decomposition levels.
        axes : int or tuple of int, optional
            Axes along which to do the transform. Defaults to None, which
            selects all the axes.
        pad_on_demand : str, optional
            Signal extension mode, passed to PyWavelets. Defaults to 'edge'.
            The `numpy.pad` names 'edge' and 'wrap' are translated to the
            PyWavelets names 'constant' and 'periodic'. Any other value is
            passed as is.

        Raises
        ------
        ValueError
            If the pywavelets package is not available.
        """
        x_shape = np.array(x_shape, ndmin=1, dtype=int)

        if not has_pywt:
            raise ValueError("Cannot use Wavelet transform because pywavelets is not installed.")

        self.wavelet = wavelet
        self.level = level

        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)

        self.pad_on_demand = pad_on_demand

        self._initialize_filter_bank()

        num_axes = len(self.axes)

        self.dir_shape = x_shape

        self.sub_band_shapes = pywt.wavedecn_shapes(
            self.dir_shape, self.wavelet, mode=self._pywt_mode, level=self.level, axes=self.axes
        )
        # Along each transformed axis, the coefficient array is as long as the
        # approximation band plus the all-detail band of every level.
        self.adj_shape = self.dir_shape.copy()
        for ax in self.axes:
            self.adj_shape[ax] = self.sub_band_shapes[0][ax] + np.sum(
                [self.sub_band_shapes[x]["d" * num_axes][ax] for x in range(1, self.level + 1)]
            )
        self.slicing_info = None

        super().__init__()

    @property
    def _pywt_mode(self):
        """Return `pad_on_demand` expressed as a PyWavelets mode."""
        mode = self.pad_on_demand
        if isinstance(mode, str):
            return self._NUMPY_TO_PYWT_MODES.get(mode, mode)
        return mode

    def direct_dwt(self, x: NDArray) -> list:
        """
        Perform the direct wavelet transform.

        Parameters
        ----------
        x : NDArray
            Data to transform.

        Returns
        -------
        list
            Transformed data.
        """
        return pywt.wavedecn(x, wavelet=self.wavelet, axes=self.axes, mode=self._pywt_mode, level=self.level)

    def inverse_dwt(self, y: list) -> NDArray:
        """
        Perform the inverse wavelet transform.

        Parameters
        ----------
        y : list
            Data to anti-transform.

        Returns
        -------
        NDArray
            Anti-transformed data, cropped to the direct space shape if needed.
        """
        rec = pywt.waverecn(y, wavelet=self.wavelet, axes=self.axes, mode=self._pywt_mode)
        if not np.all(rec.shape == self.dir_shape):
            # The reconstruction can differ in size from the input: keep the leading part.
            slices = [slice(0, s) for s in self.dir_shape]
            rec = rec[tuple(slices)]
        return rec

    def _op_direct(self, x: NDArray) -> NDArray:
        c = self.direct_dwt(x)
        # The slicing information is stored for the adjoint to unpack the array.
        y, self.slicing_info = pywt.coeffs_to_array(c, axes=self.axes)
        return y

    def _op_adjoint(self, y: NDArray) -> NDArray:
        if self.slicing_info is None:
            # No direct transform was computed yet: run one on zeros to get the slicing information.
            _ = self._op_direct(np.zeros(self.dir_shape))

        c = pywt.array_to_coeffs(y, self.slicing_info)
        return self.inverse_dwt(c)
class TransformStationaryWavelet(BaseWaveletTransform):
    """Stationary wavelet Transform operator."""

    def __init__(
        self,
        x_shape: ArrayLike,
        wavelet: str,
        level: int,
        axes: Optional[ArrayLike] = None,
        pad_on_demand: str = "edge",
        normalized: bool = True,
    ):
        """Stationary wavelet Transform operator.

        Parameters
        ----------
        x_shape : ArrayLike
            The shape of the data to be wavelet transformed.
        wavelet : str
            The type of wavelet to use.
        level : int
            Number of wavelet decomposition levels.
        axes : int or tuple of int, optional
            Axes along which to perform the transform. Default is None, which
            selects all the axes.
        pad_on_demand : str, optional
            The padding type to fit the `2 ** level` shape requirements.
            Default is 'edge'. Options are all the `numpy.pad` padding modes.
        normalized : bool, optional
            Whether to use a normalized transform. Default is True.

        Raises
        ------
        ValueError
            If the pywavelets package is not available or its version is not adequate.
        """
        if not has_pywt:
            raise ValueError("Cannot use Wavelet transform because pywavelets is not installed.")
        if not use_swtn:
            raise ValueError("Cannot use Wavelet transform because pywavelets is too old (<1.0.2).")

        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)

        self.wavelet = wavelet
        self.level = level
        self.normalized = normalized
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.pad_on_demand = pad_on_demand

        self._initialize_filter_bank()

        self.dir_shape = x_shape

        if self.pad_on_demand is None:
            adj_x_shape = self.dir_shape
        else:
            # Each transformed axis is grown to the next multiple of `2 ** level`.
            alignment = 2**self.level
            x_axes = np.array(self.dir_shape)[np.array(self.axes)]
            self.pad_axes = (alignment - x_axes % alignment) % alignment

            adj_x_shape = cp.deepcopy(self.dir_shape)
            adj_x_shape[np.array(self.axes)] += self.pad_axes

        # One approximation band, plus `2 ** num_axes - 1` detail bands per level.
        num_bands = self.level * (2 ** len(self.axes) - 1) + 1
        self.adj_shape = np.array((num_bands, *adj_x_shape))

        super().__init__()

    def direct_swt(self, x: NDArray) -> list:
        """
        Perform the direct wavelet transform.

        The data is first padded, when needed, to fit the `2 ** level` shape
        requirements.

        Parameters
        ----------
        x : NDArray
            Data to transform.

        Returns
        -------
        list
            Transformed data.
        """
        if self.pad_on_demand is not None and np.any(self.pad_axes):
            for ii_ax in np.nonzero(self.pad_axes)[0]:
                # Odd padding puts the extra element at the beginning.
                pad_l = np.ceil(self.pad_axes[ii_ax] / 2).astype(int)
                pad_h = np.floor(self.pad_axes[ii_ax] / 2).astype(int)

                pad_width = [(0, 0)] * len(x.shape)
                pad_width[self.axes[ii_ax]] = (pad_l, pad_h)
                x = np.pad(x, pad_width, mode=self.pad_on_demand)

        return pywt.swtn(x, wavelet=self.wavelet, axes=self.axes, norm=self.normalized, level=self.level, trim_approx=True)

    def inverse_swt(self, y: list) -> NDArray:
        """
        Perform the inverse wavelet transform.

        The padding added by the direct transform is cropped from the result.

        Parameters
        ----------
        y : list
            Data to anti-transform.

        Returns
        -------
        NDArray
            Anti-transformed data.
        """
        x = pywt.iswtn(y, wavelet=self.wavelet, axes=self.axes, norm=self.normalized)

        if self.pad_on_demand is not None and np.any(self.pad_axes):
            for ii_ax in np.nonzero(self.pad_axes)[0]:
                pad_l = np.ceil(self.pad_axes[ii_ax] / 2).astype(int)
                pad_h = np.floor(self.pad_axes[ii_ax] / 2).astype(int)

                data_axis = self.axes[ii_ax]
                slices = [slice(None)] * len(x.shape)
                slices[data_axis] = slice(pad_l, x.shape[data_axis] - pad_h, 1)
                x = x[tuple(slices)]

        return x

    def _op_direct(self, x: NDArray) -> NDArray:
        coeffs = self.direct_swt(x)
        # Flatten the per-level dictionaries into one stack: approximation first,
        # then the detail bands of each level, in the order given by the labels.
        bands = [coeffs[0]]
        for lvl in range(1, self.level + 1):
            bands.extend(coeffs[lvl][label] for label in self.labels)
        return np.array(bands)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        bands_per_level = 2 ** len(self.axes) - 1

        # Rebuild the list of per-level dictionaries from the stacked bands.
        coeffs = [y[0]]
        for ii_lvl in range(self.level):
            first_band = ii_lvl * bands_per_level + 1
            coeffs.append({label: y[ii_lbl + first_band, ...] for ii_lbl, label in enumerate(self.labels)})
        return self.inverse_swt(coeffs)
class TransformGradient(BaseTransform):
    """
    Gradient operator.

    The adjoint space has one leading dimension, with one entry per axis
    along which the gradient is computed.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the data to be transformed.
    axes : Optional[ArrayLike], optional
        Axes along which to do the gradient. The default is None, which
        selects all the axes.
    pad_mode : str, optional
        Padding mode of the gradient. The default is "edge".
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None, pad_mode: str = "edge"):
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)

        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)
        self.pad_mode = pad_mode.lower()

        self.dir_shape = x_shape
        self.adj_shape = np.array((len(self.axes), *self.dir_shape), ndmin=1, dtype=int)

        super().__init__()

    def gradient(self, x: NDArray) -> NDArray:
        """
        Compute the gradient.

        Parameters
        ----------
        x : NDArray
            Input data.

        Returns
        -------
        NDArray
            Gradient of data, with the components stacked along the first axis.
        """
        components = []
        for ax in self.axes:
            # One padded element at the end keeps the difference as long as the input.
            padding = np.zeros((self.ndims, 2), dtype=int)
            padding[ax, 1] = 1
            padded_x = np.pad(x, padding, mode=self.pad_mode)  # type: ignore
            components.append(np.diff(padded_x, n=1, axis=ax))
        return np.stack(components, axis=0)

    def divergence(self, x: NDArray) -> NDArray:
        """
        Compute the divergence - transpose of gradient.

        Parameters
        ----------
        x : NDArray
            Input data, with one component per axis along the first dimension.

        Returns
        -------
        NDArray
            Divergence of data.
        """
        components = []
        for ii, ax in enumerate(self.axes):
            # One padded element at the beginning keeps the difference as long as the input.
            padding = np.zeros((self.ndims, 2), dtype=int)
            padding[ax, 0] = 1
            padded_x = np.pad(x[ii, ...], padding, mode=self.pad_mode)  # type: ignore
            components.append(np.diff(padded_x, n=1, axis=ax))
        stacked = np.stack(components, axis=0)
        return np.sum(stacked, axis=0)

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.gradient(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return -self.divergence(y)
class TransformFourier(BaseTransform):
    """Fourier transform operator.

    The complex coefficients are stored as real data: the adjoint space has a
    leading dimension of size 2, holding the real and imaginary parts.
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None):
        """
        Fourier transform.

        Parameters
        ----------
        x_shape : ArrayLike
            Shape of the data to be Fourier transformed.
        axes : Optional[ArrayLike], optional
            Axes along which to do the Fourier transform. The default is None,
            which selects all the axes.
        """
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)

        self.dir_shape = x_shape
        self.adj_shape = np.array((2, *self.dir_shape), ndmin=1, dtype=int)

        super().__init__()

    def fft(self, x: NDArray) -> NDArray:
        """
        Compute the fft.

        Parameters
        ----------
        x : NDArray
            Input data.

        Returns
        -------
        NDArray
            FFT of data: real part at index 0 and imaginary part at index 1 of
            the first axis. The dtype is the one of `x` when it is a floating
            point or complex type, and float64 otherwise.
        """
        x_f = np.fft.fftn(x, axes=tuple(self.axes), norm="ortho")

        # Integer or boolean inputs would truncate the coefficients if the
        # output was allocated with their dtype.
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64

        d = np.empty(self.adj_shape, dtype=out_dtype)
        d[0, ...] = x_f.real
        d[1, ...] = x_f.imag
        return d

    def ifft(self, x: NDArray) -> NDArray:
        """Compute the inverse of the fft.

        Parameters
        ----------
        x : NDArray
            Input data, with real and imaginary parts along the first axis.

        Returns
        -------
        NDArray
            Real part of the iFFT of data.
        """
        d = x[0, ...] + 1j * x[1, ...]
        return np.fft.ifftn(d, axes=tuple(self.axes), norm="ortho").real

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.fft(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return self.ifft(y)
class TransformLaplacian(BaseTransform):
    """
    Laplacian transform operator.

    The same computation is used for both the direct and the adjoint operators.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the data to be transformed.
    axes : ArrayLike, optional
        Axes along which to do the Laplacian. The default is None, which
        selects all the axes.
    pad_mode : str, optional
        Padding mode of the Laplacian. The default is "edge".
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None, pad_mode: str = "edge"):
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)

        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)
        self.pad_mode = pad_mode.lower()

        # Direct and adjoint spaces share the same shape array.
        self.dir_shape = self.adj_shape = x_shape

        super().__init__()

    def laplacian(self, x: NDArray) -> NDArray:
        """Compute the laplacian.

        Parameters
        ----------
        x : NDArray
            Input data.

        Returns
        -------
        NDArray
            Laplacian of the input data.
        """
        second_diffs = []
        for ax in self.axes:
            # One padded element per side keeps the second difference as long as the input.
            padding = np.zeros((self.ndims, 2), dtype=int)
            padding[ax, :] = 1
            padded_x = np.pad(x, padding, mode=self.pad_mode)  # type: ignore
            second_diffs.append(np.diff(padded_x, n=2, axis=ax))
        return np.sum(second_diffs, axis=0)

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.laplacian(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return self.laplacian(y)
class TransformSVD(BaseTransform):
    """Singular value decomposition based decomposition operator."""

    U: Optional[NDArray]
    Vt: Optional[NDArray]

    def __init__(self, x_shape, axes_rows=(0,), axes_cols=(-1,), rescale: bool = False):
        """
        Singular value decomposition operator.

        The SVD decomposition will be done over the flattened rows vs flattened cols.
        This means that the channels should always be the rows (expected to be only
        one dimension, usually), while the volume dimensions should always be the
        columns (expected to be the last two or three dimensions).

        Parameters
        ----------
        x_shape : numpy.array_like
            Shape of the data to be decomposed.
        axes_rows : tuple of int, optional
            Axes expanded in rows of the SVD. Defaults to (0, ).
        axes_cols : tuple of int, optional
            Axes expanded in cols of the SVD. Defaults to (-1, ).
        rescale : bool, optional
            Whether the direct operator rescales the singular values and the
            stored bases by the square roots of the bases sizes. Defaults to False.

        Notes
        -----
        The axes are wrapped with a modulo on the number of dimensions of `x_shape`.
        """
        self.dir_shape = np.array(x_shape, ndmin=1, dtype=int)
        num_dims = len(self.dir_shape)

        self.axes_rows = np.atleast_1d(axes_rows) % num_dims
        self.axes_cols = np.atleast_1d(axes_cols) % num_dims

        # Dimensions to decompose
        self.append_dims = np.concatenate((self.axes_rows, self.axes_cols)).astype(int)
        # Dimensions to NOT decompose
        all_dims = np.arange(num_dims, dtype=int)
        self.invariant_dims = np.delete(all_dims, self.append_dims)
        self.invariant_dims_shape = self.dir_shape[self.invariant_dims]

        # Transpose operation to prepare data for decomposition
        self.fwd_transpose = np.concatenate((self.invariant_dims, self.append_dims))
        # Transpose operation to recover data after re-composition
        self.bwd_transpose = np.argsort(self.fwd_transpose)

        self.axes_rows_shape = self.dir_shape[self.axes_rows]
        self.axes_cols_shape = self.dir_shape[self.axes_cols]
        self.axes_rows_size = (np.prod(self.axes_rows_shape),)
        self.axes_cols_size = (np.prod(self.axes_cols_shape),)

        # Reshape operation to prepare data for decomposition
        self.fwd_shape = np.concatenate((self.invariant_dims_shape, self.axes_rows_size, self.axes_cols_size))
        # Reshape operation to recover data after re-composition
        self.bwd_shape = np.concatenate((self.invariant_dims_shape, self.axes_rows_shape, self.axes_cols_shape))

        # Number of singular values: the smaller of the two flattened sizes.
        num_singular_values = np.fmin(self.axes_rows_size, self.axes_cols_size)
        self.adj_shape = np.concatenate((self.invariant_dims_shape, num_singular_values))

        self.U = None
        self.Vt = None

        self.rescale = rescale

        super().__init__()

    def direct_svd(self, x):
        """
        Performs the SVD decomposition.

        Parameters
        ----------
        x : `numpy.array_like`
            Data to transform.

        Returns
        -------
        tuple(U, s, Vt)
            Transformed data.
        """
        return np.linalg.svd(x, full_matrices=False)

    def inverse_svd(self, U, s, Vt):
        """
        Performs the inverse SVD decomposition.

        Parameters
        ----------
        U : `numpy.array_like`
            Rows of the SVD decomposition.
        s : `numpy.array_like`
            Singular values.
        Vt : `numpy.array_like`
            Columns of the SVD decomposition.

        Returns
        -------
        `numpy.array_like`
            Anti-transformed data.
        """
        scaled_vt = s[..., None] * Vt
        return np.matmul(U, scaled_vt)

    def _op_direct(self, x):
        x = np.reshape(np.transpose(x, self.fwd_transpose), self.fwd_shape)

        # The bases are stored, because the adjoint needs them to re-compose the data.
        (self.U, s, self.Vt) = self.direct_svd(x)

        if self.rescale:
            num_cols = self.Vt.shape[-1]
            num_rows = self.U.shape[-2]
            s /= np.sqrt(num_cols * num_rows)
            self.Vt *= np.sqrt(num_cols)
            self.U *= np.sqrt(num_rows)

        return s

    def _op_adjoint(self, y):
        if self.U is None or self.Vt is None:
            raise ValueError("Operator not initialized!")

        x = self.inverse_svd(self.U, y, self.Vt)
        return np.transpose(np.reshape(x, self.bwd_shape), self.bwd_transpose)
if __name__ == "__main__":
    # Manual smoke test: build the explicit matrices of a few operators on a
    # small 10x10 volume. The results are neither printed nor checked.
    test_vol = np.zeros((10, 10), dtype=np.float32)

    # Stationary wavelet transform ("db1", 2 levels): transposed and direct matrices.
    H = TransformStationaryWavelet(test_vol.shape, "db1", 2)
    Htw = H.T.explicit()
    Hw = H.explicit()

    # Gradient operator: direct matrix.
    D = TransformGradient(test_vol.shape)
    Dg = D.explicit()