Source code for corrct.operators
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Operators module.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
import numpy as np
from scipy.sparse.linalg import LinearOperator
import scipy.signal as spsig
import copy as cp
from numpy.typing import ArrayLike, NDArray
from typing import Callable, Optional, Sequence, Tuple, Union
from abc import abstractmethod
def _version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a dotted version string into a tuple of integers, for numeric comparison.

    Parsing keeps the leading digits of each dot-separated component and stops at the first
    component that is not purely numeric (e.g. "1.0.2rc1" -> (1, 0, 2)).

    :param version: Version string, like "1.4.1".
    :type version: str
    :returns: The numeric components of the version.
    :rtype: tuple of int
    """
    numbers = []
    for token in version.split("."):
        digits = ""
        for char in token:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
        if len(digits) < len(token):
            # Pre-release / local suffix: ignore everything after it.
            break
    return tuple(numbers)


try:
    import pywt

    has_pywt = True
    # Compare numerically: a plain string comparison mis-orders versions like "1.0.10" vs "1.0.2".
    use_swtn = _version_tuple(pywt.version.version) >= (1, 0, 2)  # type: ignore
    if not use_swtn:
        print("WARNING - pywavelets version is too old (<1.0.2)")
except ImportError:
    has_pywt = False
    use_swtn = False
    print("WARNING - pywt was not found")

# Alias for integer-valued numpy arrays (used for shapes and axes).
NDArrayInt = NDArray[np.integer]
class BaseTransform(LinearOperator):
    """Base operator class.

    It implements the linear operator behavior that can be used with the solvers in the `.solvers` module,
    and by the solvers in `scipy.sparse.linalg`.

    Derived classes set `dir_shape` and `adj_shape` before calling `BaseTransform.__init__`, and implement
    `_op_direct` and `_op_adjoint`.
    """

    # Shape of the direct (input) space.
    dir_shape: NDArrayInt
    # Shape of the adjoint (output) space.
    adj_shape: NDArrayInt

    def __init__(self):
        """Initialize the base operator class.

        It assumes that the fields `dir_shape` and `adj_shape` have been set during the initialization of the
        derived classes. The matrix shape is (number of adjoint-space elements, number of direct-space elements).
        """
        num_cols = np.prod(self.dir_shape)
        num_rows = np.prod(self.adj_shape)
        super().__init__(np.float32, [num_rows, num_cols])
        # True for the operator itself; transposed copies (see `_transpose`) flip it.
        self.is_dir_operator = True

    def _matvec(self, x: NDArray) -> NDArray:
        """Implement the direct operator for column vectors from the right.

        :param x: Either row from the left or column from the right.
        :type x: numpy.array_like
        :returns: The flattened result.
        :rtype: NDArray
        """
        if self.is_dir_operator:
            x = x.reshape(self.dir_shape)
            return self._op_direct(x).flatten()
        else:
            x = x.reshape(self.adj_shape)
            return self._op_adjoint(x).flatten()

    def rmatvec(self, x: NDArray) -> NDArray:
        """Implement the direct operator for row vectors from the left.

        :param x: Either row from the left or column from the right on transpose.
        :type x: numpy.array_like
        :returns: The flattened result.
        :rtype: NDArray
        """
        if self.is_dir_operator:
            x = x.reshape(self.adj_shape)
            return self._op_adjoint(x).flatten()
        else:
            x = x.reshape(self.dir_shape)
            return self._op_direct(x).flatten()

    def _transpose(self):
        """Create the transpose operator.

        The transpose is a shallow copy of this operator with swapped shape and flipped direction, so that
        transposing twice yields an operator that applies the direct transform again.

        :returns: The transpose operator
        :rtype: BaseTransform
        """
        Op_t = cp.copy(self)
        # Keep `shape` a tuple, like `LinearOperator.__init__` does: scipy compares operator shapes with `!=`
        # (e.g. when summing operators), and a list never compares equal to a tuple.
        Op_t.shape = (Op_t.shape[1], Op_t.shape[0])
        # Flip (instead of forcing False) so that `Op.T.T` applies the direct operator again.
        Op_t.is_dir_operator = not self.is_dir_operator
        return Op_t

    def _adjoint(self):
        """Create the adjoint operator, which coincides with the transpose for these real-valued operators."""
        return self._transpose()

    def absolute(self):
        """Return the projection operator using the absolute value of the projection coefficients.

        The base implementation returns the operator itself; derived classes override it when needed.

        :returns: The absolute value operator
        :rtype: ProjectorOperator
        """
        return self

    def explicit(self) -> NDArray:
        """Return the explicit transformation matrix associated to the operator.

        Column `ii` is the (flattened) response of the operator to the `ii`-th canonical basis vector.

        :returns: The explicit transformation matrix
        :rtype: NDArray
        """
        He = np.empty(self.shape, dtype=self.dtype)
        if self.is_dir_operator:
            dim_size = np.prod(self.dir_shape)
        else:
            dim_size = np.prod(self.adj_shape)
        for ii in range(dim_size):
            xii = np.zeros((dim_size,))
            xii[ii] = 1
            He[:, ii] = self.matvec(xii)
        return He

    def __call__(self, x: NDArray) -> NDArray:
        """Apply the operator to the input vector.

        Unlike `matvec`, the input is passed unflattened to the underlying transform.

        :param x: Input vector.
        :type x: NDArray
        :returns: The result of the application of the operator on the input vector.
        :rtype: NDArray
        """
        if self.is_dir_operator:
            return self._op_direct(x)
        else:
            return self._op_adjoint(x)

    @abstractmethod
    def _op_direct(self, x: NDArray) -> NDArray:
        """Apply the operator to the data.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """

    @abstractmethod
    def _op_adjoint(self, x: NDArray) -> NDArray:
        """Apply the adjoint operator to the data.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
class TransformFunctions(BaseTransform):
    """Transform class that uses callables."""

    def __init__(
        self,
        dir_shape: ArrayLike,
        adj_shape: ArrayLike,
        A: Callable[[NDArray], NDArray],
        At: Optional[Callable[[NDArray], NDArray]] = None,
    ) -> None:
        """Wrap a pair of functions as a linear operator.

        When no adjoint function is supplied, `A` is treated as self-adjoint.

        Parameters
        ----------
        dir_shape : ArrayLike
            Shape of the direct space.
        adj_shape : ArrayLike
            Shape of the adjoint space.
        A : Callable[[NDArray], NDArray]
            The transform function.
        At : Optional[Callable[[NDArray], NDArray]], optional
            The adjoint transform function, by default None
        """
        self.dir_shape, self.adj_shape = (np.array(shape, ndmin=1, dtype=int) for shape in (dir_shape, adj_shape))
        self.A = A
        self.At = At
        super().__init__()

    def _op_direct(self, x: NDArray) -> NDArray:
        """Evaluate the forward function `A` on the data.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
        return self.A(x)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        """Evaluate the adjoint function, falling back to `A` when no `At` was given.

        Parameters
        ----------
        x : NDArray
            Data to process.

        Returns
        -------
        NDArray
            The processed data.
        """
        adjoint_fn = self.A if self.At is None else self.At
        return adjoint_fn(x)

    def absolute(self):
        """Refuse to build an absolute-value operator, which cannot be derived from opaque callables.

        Raises
        ------
        AttributeError
            Not supported operation.
        """
        raise AttributeError("Callable transform class does not support computing its absolute value.")
class ProjectorOperator(BaseTransform):
    """Base projector class that fixes the projection interface.

    The direct space is the volume, and the adjoint space is the projection data.
    """

    @property
    def vol_shape(self) -> NDArrayInt:
        """Volume shape, i.e. the shape of the direct space.

        Returns
        -------
        NDArray
            The volume shape.
        """
        return self.dir_shape

    @vol_shape.setter
    def vol_shape(self, new_shape: Union[Sequence[int], NDArray]) -> None:
        self.dir_shape = np.array(new_shape, ndmin=1, dtype=int)

    @property
    def prj_shape(self) -> NDArrayInt:
        """Projection shape, i.e. the shape of the adjoint space.

        Returns
        -------
        NDArray
            The projection shape.
        """
        return self.adj_shape

    @prj_shape.setter
    def prj_shape(self, new_shape: Union[Sequence[int], NDArray]) -> None:
        self.adj_shape = np.array(new_shape, ndmin=1, dtype=int)

    @abstractmethod
    def fp(self, x: NDArray) -> NDArray:
        """Forward-project a volume (to be implemented by concrete projectors).

        :param x: Input volume.
        :type x: NDArray
        :returns: The projection data.
        :rtype: NDArray
        """

    @abstractmethod
    def bp(self, x: NDArray) -> NDArray:
        """Back-project projection data (to be implemented by concrete projectors).

        :param x: Input projection data.
        :type x: NDArray
        :returns: The back-projected volume.
        :rtype: NDArray
        """

    def _op_direct(self, x: NDArray) -> NDArray:
        # The direct operator of a projector is its forward-projection.
        return self.fp(x)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        # The adjoint operator of a projector is its back-projection.
        return self.bp(x)

    def get_pre_weights(self) -> Union[NDArray, None]:
        """Return the geometry pre-weights (notably for cone-beam geometries); none by default.

        Returns
        -------
        Union[NDArray, None]
            The computed detector weights
        """
        return None
class TransformIdentity(BaseTransform):
    """Identity operator."""

    def __init__(self, x_shape: ArrayLike):
        """Build an identity operator acting on data of the given shape.

        :param x_shape: Shape of the data.
        :type x_shape: ArrayLike
        """
        shape = np.array(x_shape, ndmin=1, dtype=int)
        # Direct and adjoint spaces coincide, but keep two independent arrays.
        self.dir_shape = shape
        self.adj_shape = shape.copy()
        super().__init__()

    def _op_direct(self, x: NDArray) -> NDArray:
        # Identity: hand back the input unchanged.
        return x

    def _op_adjoint(self, x: NDArray) -> NDArray:
        # The identity is self-adjoint.
        return x
class TransformDiagonalScaling(BaseTransform):
    """Diagonal scaling operator."""

    # Diagonal of the operator (scalar or array broadcastable against the data).
    scale: NDArray

    def __init__(self, x_shape: ArrayLike, scale: Union[float, ArrayLike]):
        """Build an operator that multiplies the data element-wise by `scale`.

        :param x_shape: Shape of the data.
        :type x_shape: ArrayLike
        :param scale: Operator diagonal.
        :type scale: float or ArrayLike
        """
        self.scale = np.array(scale)
        shape = np.array(x_shape, ndmin=1, dtype=int)
        self.dir_shape = shape
        self.adj_shape = shape.copy()
        super().__init__()

    def absolute(self):
        """Build the diagonal operator whose coefficients are the absolute values of this one's.

        :returns: The absolute value operator
        :rtype: Diagonal operator of the absolute values
        """
        abs_scale = np.abs(self.scale)
        return TransformDiagonalScaling(self.dir_shape, abs_scale)

    def _op_direct(self, x: NDArray) -> NDArray:
        # Element-wise product with the diagonal.
        return self.scale * x

    def _op_adjoint(self, x: NDArray) -> NDArray:
        # Same element-wise product as the direct operator.
        return self.scale * x
class TransformConvolution(BaseTransform):
    """
    Convolution operator.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the direct space.
    kernel : ArrayLike
        The convolution kernel.
    pad_mode: str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    is_symm : bool, optional
        Whether the operator is symmetric or not. The default is True.
    flip_adjoint : bool, optional
        Whether the adjoint kernel should be flipped. The default is False.
        This is useful when the kernel is not symmetric.
    """

    kernel: NDArray
    pad_mode: str
    is_symm: bool
    flip_adjoint: bool

    def __init__(
        self, x_shape: ArrayLike, kernel: ArrayLike, pad_mode: str = "edge", is_symm: bool = True, flip_adjoint: bool = False
    ):
        self.dir_shape = np.array(x_shape, ndmin=1, dtype=int)
        self.adj_shape = np.array(x_shape, ndmin=1, dtype=int)
        # Promote the kernel to at least the data dimensionality, so it can be convolved with it.
        self.kernel = np.array(kernel, ndmin=len(self.dir_shape))
        self.pad_mode = pad_mode.lower()
        self.is_symm = is_symm
        self.flip_adjoint = flip_adjoint
        super().__init__()

    def absolute(self) -> "TransformConvolution":
        """
        Return the convolution operator using the absolute value of the kernel coefficients.

        All the other settings (padding mode, symmetry, adjoint flipping) are carried over.

        Returns
        -------
        TransformConvolution
            The absolute value of the convolution operator.
        """
        return TransformConvolution(
            self.dir_shape,
            np.abs(self.kernel),
            pad_mode=self.pad_mode,
            is_symm=self.is_symm,
            flip_adjoint=self.flip_adjoint,
        )

    def _pad_valid(self, x: NDArray) -> Tuple[NDArray, NDArray]:
        """Pad `x` on both sides of each axis by `(kernel_size - 1) // 2`; also return the pad widths."""
        pad_width = (np.array(self.kernel.shape) - 1) // 2
        return np.pad(x, pad_width=pad_width[:, None], mode=self.pad_mode), pad_width  # type: ignore

    def _crop_valid(self, x: NDArray, pad_width: NDArray) -> NDArray:
        """Remove the padding added by `_pad_valid` (axes with zero padding are left untouched)."""
        slices = [slice(pw if pw else None, -pw if pw else None) for pw in pad_width]
        return x[tuple(slices)]

    def _op_direct(self, x: NDArray) -> NDArray:
        x, pw = self._pad_valid(x)
        x = spsig.convolve(x, self.kernel, mode="same")
        return self._crop_valid(x, pw)

    def _op_adjoint(self, x: NDArray) -> NDArray:
        if self.is_symm:
            x, pw = self._pad_valid(x)
            if self.flip_adjoint:
                adj_kernel = np.flip(self.kernel)
            else:
                adj_kernel = self.kernel
            x = spsig.convolve(x, adj_kernel, mode="same")
            return self._crop_valid(x, pw)
        else:
            # NOTE(review): when `is_symm` is False the adjoint is the identity (input returned unchanged);
            # confirm this is the intended semantics.
            return x
class BaseWaveletTransform(BaseTransform):
    """Base Wavelet transform."""

    # Axes along which the transform is applied.
    axes: NDArrayInt
    # Name of the pywavelets wavelet.
    wavelet: str

    def _initialize_filter_bank(self) -> None:
        """Build the sub-band labels, the pywavelets wavelet object and per-sub-band filter-norm multipliers.

        Labels are strings of 'a' (approximation) and 'd' (detail), one character per transformed axis,
        excluding the all-approximation band. Each multiplier is the product of the l1-norms of the
        corresponding low-pass ('a') and high-pass ('d') filters.
        """
        num_axes = len(self.axes)
        to_band_letters = str.maketrans("01", "ad")
        self.labels = [bin(code)[2:].zfill(num_axes).translate(to_band_letters) for code in range(1, 2**num_axes)]

        self.w = pywt.Wavelet(self.wavelet)  # type: ignore
        l1_norms = np.linalg.norm(self.w.filter_bank, ord=1, axis=-1)

        def band_multipliers(lo_norm, hi_norm):
            return np.array([(lo_norm ** lab.count("a")) * (hi_norm ** lab.count("d")) for lab in self.labels])

        # Filter bank rows 0-1 are the decomposition filters, rows 2-3 the reconstruction filters.
        self.wlet_dec_filter_mult = band_multipliers(l1_norms[0], l1_norms[1])
        self.wlet_rec_filter_mult = band_multipliers(l1_norms[2], l1_norms[3])
class TransformDecimatedWavelet(BaseWaveletTransform):
    """Decimated wavelet Transform operator."""

    # `numpy.pad` mode names whose pywavelets equivalent has a different name.
    # pywavelets' "constant" mode replicates the border values, like numpy's "edge".
    _NUMPY_TO_PYWT_MODES = {"edge": "constant"}

    def __init__(
        self, x_shape: ArrayLike, wavelet: str, level: int, axes: Optional[ArrayLike] = None, pad_on_demand: str = "edge"
    ):
        """Decimated wavelet Transform operator.

        :param x_shape: Shape of the data to be wavelet transformed.
        :type x_shape: ArrayLike
        :param wavelet: Wavelet type
        :type wavelet: string
        :param level: Number of wavelet decomposition levels
        :type level: int
        :param axes: Axes along which to do the transform, defaults to None (all axes)
        :type axes: int or tuple of int, optional
        :param pad_on_demand: Signal extension mode used by the transform, defaults to 'edge'
        :type pad_on_demand: string, optional. Options are the pywavelets signal extension modes;
            the `numpy.pad` name 'edge' is also accepted and mapped to pywavelets' 'constant'.
        :raises ValueError: In case the pywavelets package is not available or its version is not adequate.
        """
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if not has_pywt:
            raise ValueError("Cannot use Wavelet transform because pywavelets is not installed.")
        self.wavelet = wavelet
        self.level = level
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.pad_on_demand = pad_on_demand
        self._initialize_filter_bank()
        num_axes = len(self.axes)
        self.dir_shape = x_shape
        self.sub_band_shapes = pywt.wavedecn_shapes(
            self.dir_shape, self.wavelet, mode=self._get_pywt_mode(), level=self.level, axes=self.axes
        )
        self.adj_shape = self.dir_shape.copy()
        for ax in self.axes:
            # Along each transformed axis, the coefficient array holds the approximation band plus the
            # detail band of every decomposition level.
            self.adj_shape[ax] = self.sub_band_shapes[0][ax] + np.sum(
                [self.sub_band_shapes[x]["d" * num_axes][ax] for x in range(1, self.level + 1)]
            )
        # Layout of the coefficient array, filled by `_op_direct` and needed by `_op_adjoint`.
        self.slicing_info = None
        super().__init__()

    def _get_pywt_mode(self) -> str:
        """Translate `pad_on_demand` into a signal extension mode name understood by pywavelets."""
        return self._NUMPY_TO_PYWT_MODES.get(self.pad_on_demand, self.pad_on_demand)

    def direct_dwt(self, x: NDArray) -> list:
        """Perform the direct wavelet transform.

        :param x: Data to transform.
        :type x: NDArray
        :return: Transformed data.
        :rtype: list
        """
        return pywt.wavedecn(x, wavelet=self.wavelet, axes=self.axes, mode=self._get_pywt_mode(), level=self.level)

    def inverse_dwt(self, y: list) -> NDArray:
        """Perform the inverse wavelet transform.

        The reconstruction is cropped to `dir_shape` when pywavelets returns a larger array.

        :param y: Data to anti-transform.
        :type y: list
        :return: Anti-transformed data.
        :rtype: NDArray
        """
        rec = pywt.waverecn(y, wavelet=self.wavelet, axes=self.axes, mode=self._get_pywt_mode())
        if not np.all(rec.shape == self.dir_shape):
            slices = [slice(0, s) for s in self.dir_shape]
            rec = rec[tuple(slices)]
        return rec

    def _op_direct(self, x: NDArray) -> NDArray:
        c = self.direct_dwt(x)
        y, self.slicing_info = pywt.coeffs_to_array(c, axes=self.axes)
        return y

    def _op_adjoint(self, y: NDArray) -> NDArray:
        if self.slicing_info is None:
            # Run a dummy direct transform once, only to obtain the coefficient layout.
            _ = self._op_direct(np.zeros(self.dir_shape))
        c = pywt.array_to_coeffs(y, self.slicing_info)
        return self.inverse_dwt(c)
class TransformStationaryWavelet(BaseWaveletTransform):
    """Stationary wavelet Transform operator."""

    def __init__(
        self,
        x_shape: ArrayLike,
        wavelet: str,
        level: int,
        axes: Optional[ArrayLike] = None,
        pad_on_demand: str = "edge",
        normalized: bool = True,
    ):
        """Stationary wavelet Transform operator.

        :param x_shape: Shape of the data to be wavelet transformed.
        :type x_shape: ArrayLike
        :param wavelet: Wavelet type
        :type wavelet: string
        :param level: Number of wavelet decomposition levels
        :type level: int
        :param axes: Axes along which to do the transform, defaults to None (all axes)
        :type axes: int or tuple of int, optional
        :param pad_on_demand: Padding type to fit the `2 ** level` shape requirements, defaults to 'edge'.
            If None, no padding is applied.
        :type pad_on_demand: string, optional. Options are all the `numpy.pad` padding modes.
        :param normalized: Whether to use a normalized transform. Defaults to True.
        :type normalized: boolean, optional.
        :raises ValueError: In case the pywavelets package is not available or its version is not adequate.
        """
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if not has_pywt:
            raise ValueError("Cannot use Wavelet transform because pywavelets is not installed.")
        if not use_swtn:
            raise ValueError("Cannot use Wavelet transform because pywavelets is too old (<1.0.2).")
        self.wavelet = wavelet
        self.level = level
        self.normalized = normalized
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.pad_on_demand = pad_on_demand
        self._initialize_filter_bank()
        self.dir_shape = x_shape
        if self.pad_on_demand is not None:
            # Samples to add along each transformed axis to reach a multiple of 2 ** level.
            alignment = 2**self.level
            x_axes = np.array(self.dir_shape)[np.array(self.axes)]
            self.pad_axes = (alignment - x_axes % alignment) % alignment
            adj_x_shape = cp.deepcopy(self.dir_shape)
            adj_x_shape[np.array(self.axes)] += self.pad_axes
        else:
            adj_x_shape = self.dir_shape
        # Leading axis: 1 approximation band + (2 ** num_axes - 1) detail bands for each level.
        self.adj_shape = np.array((self.level * (2 ** len(self.axes) - 1) + 1, *adj_x_shape))
        super().__init__()

    def _pad_extents(self):
        """Yield `(axis, pad_low, pad_high)` for every transformed axis that needs padding.

        The padding is split with the larger half (ceil) before the data and the smaller half (floor) after.
        """
        for ii_ax in np.nonzero(self.pad_axes)[0]:
            pad_l = np.ceil(self.pad_axes[ii_ax] / 2).astype(int)
            pad_h = np.floor(self.pad_axes[ii_ax] / 2).astype(int)
            yield self.axes[ii_ax], pad_l, pad_h

    def direct_swt(self, x: NDArray) -> list:
        """Perform the direct wavelet transform.

        :param x: Data to transform.
        :type x: NDArray
        :return: Transformed data.
        :rtype: list
        """
        if self.pad_on_demand is not None and np.any(self.pad_axes):
            for ax, pad_l, pad_h in self._pad_extents():
                pad_width = [(0, 0)] * len(x.shape)
                pad_width[ax] = (pad_l, pad_h)
                x = np.pad(x, pad_width, mode=self.pad_on_demand)
        return pywt.swtn(x, wavelet=self.wavelet, axes=self.axes, norm=self.normalized, level=self.level, trim_approx=True)

    def inverse_swt(self, y: list) -> NDArray:
        """Perform the inverse wavelet transform, removing any padding added by `direct_swt`.

        :param y: Data to anti-transform.
        :type y: list
        :return: Anti-transformed data.
        :rtype: NDArray
        """
        x = pywt.iswtn(y, wavelet=self.wavelet, axes=self.axes, norm=self.normalized)
        if self.pad_on_demand is not None and np.any(self.pad_axes):
            for ax, pad_l, pad_h in self._pad_extents():
                slices = [slice(None)] * len(x.shape)
                slices[ax] = slice(pad_l, x.shape[ax] - pad_h, 1)
                x = x[tuple(slices)]
        return x

    def _op_direct(self, x: NDArray) -> NDArray:
        y = self.direct_swt(x)
        # y[0] is the approximation; y[1:] are per-level dicts of detail bands keyed by label.
        y = [y[0]] + [y[lvl][lab] for lvl in range(1, self.level + 1) for lab in self.labels]
        return np.array(y)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        def get_lvl_pos(lvl):
            # Index, along the leading axis of `y`, of the first detail band of level `lvl`.
            return (lvl - 1) * (2 ** len(self.axes) - 1) + 1

        x = [y[0]] + [
            {k: y[ii_lbl + get_lvl_pos(lvl), ...] for ii_lbl, k in enumerate(self.labels)}
            for lvl in range(1, self.level + 1)
        ]
        return self.inverse_swt(x)
class TransformGradient(BaseTransform):
    """
    Gradient operator.

    Computes forward finite differences along the selected axes, stacked along a new leading axis.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the data to be transformed.
    axes : Optional[ArrayLike], optional
        Axes along which to do the gradient. The default is None.
    pad_mode : str, optional
        Padding mode of the gradient. The default is "edge".
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None, pad_mode: str = "edge"):
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        default_axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(default_axes if axes is None else axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)
        self.pad_mode = pad_mode.lower()
        self.dir_shape = x_shape
        # One gradient component per transformed axis.
        self.adj_shape = np.array((len(self.axes), *self.dir_shape), ndmin=1, dtype=int)
        super().__init__()

    def _one_sample_padding(self, ax, side: int) -> NDArray:
        """Build a `np.pad` width spec adding one sample on `side` (0: before, 1: after) of axis `ax`."""
        widths = np.zeros((self.ndims, 2), dtype=int)
        widths[ax, side] = 1
        return widths

    def gradient(self, x: NDArray) -> NDArray:
        """Compute the gradient (padding after the data, then first-order difference).

        :param x: Input data.
        :type x: NDArray
        :return: Gradient of data.
        :rtype: NDArray
        """
        components = [
            np.diff(np.pad(x, self._one_sample_padding(ax, 1), mode=self.pad_mode), n=1, axis=ax)  # type: ignore
            for ax in self.axes
        ]
        return np.stack(components, axis=0)

    def divergence(self, x: NDArray) -> NDArray:
        """Compute the divergence - transpose of gradient (padding before the data, then difference).

        NOTE(review): with one-sided padding, `gradient` and `-divergence` may not be exact adjoints at the
        domain boundaries; check with `explicit()` if exactness matters.

        :param x: Input data.
        :type x: NDArray
        :return: Divergence of data.
        :rtype: NDArray
        """
        components = [
            np.diff(np.pad(x[ii, ...], self._one_sample_padding(ax, 0), mode=self.pad_mode), n=1, axis=ax)  # type: ignore
            for ii, ax in enumerate(self.axes)
        ]
        return np.sum(np.stack(components, axis=0), axis=0)

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.gradient(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return -self.divergence(y)
class TransformFourier(BaseTransform):
    """Fourier transform operator.

    The complex (orthonormal) FFT is represented as real data: the leading axis of the adjoint space holds the
    real part (index 0) and the imaginary part (index 1).
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None):
        """Fourier transform.

        :param x_shape: Shape of the data to be Fourier transformed.
        :type x_shape: ArrayLike
        :param axes: Axes along which to compute the FFT, defaults to None (all axes)
        :type axes: int or tuple of int, optional
        """
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)
        self.dir_shape = x_shape
        self.adj_shape = np.array((2, *self.dir_shape), ndmin=1, dtype=int)
        super().__init__()

    def fft(self, x: NDArray) -> NDArray:
        """Compute the fft.

        :param x: Input data.
        :type x: NDArray
        :return: FFT of data, as stacked real and imaginary parts.
        :rtype: NDArray
        """
        x_f = np.fft.fftn(x, axes=tuple(self.axes), norm="ortho")
        # Promote integer/boolean inputs to floating point: storing the coefficients in an integer buffer would
        # silently truncate them. Floating point and complex inputs keep their own dtype.
        d = np.empty(self.adj_shape, dtype=np.result_type(x.dtype, np.float32))
        d[0, ...] = x_f.real
        d[1, ...] = x_f.imag
        return d

    def ifft(self, x: NDArray) -> NDArray:
        """Compute the inverse of the fft.

        :param x: Input data, as stacked real and imaginary parts.
        :type x: NDArray
        :return: iFFT of data (real part).
        :rtype: NDArray
        """
        d = x[0, ...] + 1j * x[1, ...]
        return np.fft.ifftn(d, axes=tuple(self.axes), norm="ortho").real

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.fft(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return self.ifft(y)
class TransformLaplacian(BaseTransform):
    """
    Laplacian transform operator.

    Sums second-order finite differences along the selected axes. The same operation is used as the adjoint.

    Parameters
    ----------
    x_shape : ArrayLike
        Shape of the data to be transformed.
    axes : ArrayLike, optional
        Axes along which to do the Laplacian. The default is None.
    pad_mode : str, optional
        Padding mode of the Laplacian. The default is "edge".
    """

    def __init__(self, x_shape: ArrayLike, axes: Optional[ArrayLike] = None, pad_mode: str = "edge"):
        x_shape = np.array(x_shape, ndmin=1, dtype=int)
        if axes is None:
            axes = np.arange(-len(x_shape), 0, dtype=int)
        self.axes = np.array(axes, ndmin=1, dtype=int)
        self.ndims = len(x_shape)
        self.pad_mode = pad_mode.lower()
        self.dir_shape = x_shape
        # Independent copy: in-place changes to one shape must not silently alter the other.
        self.adj_shape = x_shape.copy()
        super().__init__()

    def laplacian(self, x: NDArray) -> NDArray:
        """Compute the laplacian.

        Each axis is padded by one sample on both sides before taking the second-order difference, so the
        output has the same shape as the input.

        :param x: Input data.
        :type x: NDArray
        :return: Laplacian of data.
        :rtype: NDArray
        """
        d = [np.array([])] * len(self.axes)
        for ii, ax in enumerate(self.axes):
            padding = np.zeros((self.ndims, 2), dtype=int)
            padding[ax, :] = 1
            temp_x = np.pad(x, padding, mode=self.pad_mode)  # type: ignore
            d[ii] = np.diff(temp_x, n=2, axis=ax)
        return np.sum(d, axis=0)

    def _op_direct(self, x: NDArray) -> NDArray:
        return self.laplacian(x)

    def _op_adjoint(self, y: NDArray) -> NDArray:
        return self.laplacian(y)
class TransformSVD(BaseTransform):
    """Singular value decomposition based decomposition operator."""

    # Left/right singular vectors from the latest direct application (None until then).
    U: Optional[NDArray]
    Vt: Optional[NDArray]

    def __init__(self, x_shape, axes_rows=(0,), axes_cols=(-1,), rescale: bool = False):
        """Singular value decomposition operator.

        The SVD decomposition will be done over the flattened rows vs flattened cols.
        This means that the channels should always be the rows (expected to be only
        one dimension, usually), while the volume dimensions should always be the
        columns (expected to be the last two or three dimensions).

        :param x_shape: Shape of the data to be decomposed.
        :type x_shape: `numpy.array_like`
        :param axes_rows: Axes expanded in rows of the SVD, defaults to (0, )
        :type axes_rows: tuple of int, optional
        :param axes_cols: Axes expanded in cols of the SVD, defaults to (-1, )
        :type axes_cols: tuple of int, optional
        :param rescale: Whether to rescale the singular values and vectors by the sqrt of the matrix sizes.
        :type rescale: bool, optional
        :raises IndexError: In case the axes are outside the range.
        """
        self.dir_shape = np.array(x_shape, ndmin=1, dtype=int)
        num_dims = len(self.dir_shape)
        self.axes_rows = self._normalize_axes(axes_rows, num_dims)
        self.axes_cols = self._normalize_axes(axes_cols, num_dims)
        # Dimensions to decompose
        self.append_dims = np.concatenate((self.axes_rows, self.axes_cols)).astype(int)
        # Dimensions to NOT decompose
        temp_dims = np.arange(num_dims, dtype=int)
        self.invariant_dims = np.delete(temp_dims, self.append_dims)
        self.invariant_dims_shape = self.dir_shape[self.invariant_dims]
        # Transpose operation to prepare data for decomposition
        self.fwd_transpose = np.concatenate((self.invariant_dims, self.append_dims))
        # Transpose operation to recover data after re-composition
        self.bwd_transpose = np.argsort(self.fwd_transpose)
        self.axes_rows_shape = self.dir_shape[self.axes_rows]
        self.axes_cols_shape = self.dir_shape[self.axes_cols]
        self.axes_rows_size = (np.prod(self.axes_rows_shape),)
        self.axes_cols_size = (np.prod(self.axes_cols_shape),)
        # Reshape operation to prepare data for decomposition
        self.fwd_shape = np.concatenate((self.invariant_dims_shape, self.axes_rows_size, self.axes_cols_size))
        # Reshape operation to recover data after re-composition
        self.bwd_shape = np.concatenate((self.invariant_dims_shape, self.axes_rows_shape, self.axes_cols_shape))
        # Output: the singular values, one set per combination of the invariant dimensions.
        self.adj_shape = np.concatenate((self.invariant_dims_shape, np.fmin(self.axes_rows_size, self.axes_cols_size)))
        self.U = None
        self.Vt = None
        self.rescale = rescale
        super().__init__()

    @staticmethod
    def _normalize_axes(axes, num_dims: int) -> NDArray:
        """Map (possibly negative) axis indices into [0, num_dims), raising IndexError when out of range."""
        axes = np.atleast_1d(axes)
        if np.any((axes < -num_dims) | (axes >= num_dims)):
            raise IndexError(f"Axes {axes.tolist()} are out of range for data with {num_dims} dimensions.")
        return axes % num_dims

    def direct_svd(self, x):
        """Performs the SVD decomposition.

        :param x: Data to transform.
        :type x: `numpy.array_like`
        :return: Transformed data.
        :rtype: tuple(U, s, Vt)
        """
        return np.linalg.svd(x, full_matrices=False)

    def inverse_svd(self, U, s, Vt):
        """Performs the inverse SVD decomposition.

        :param U: Rows of the SVD decomposition.
        :type U: `numpy.array_like`
        :param s: Singular values.
        :type s: `numpy.array_like`
        :param Vt: Columns of the SVD decomposition.
        :type Vt: `numpy.array_like`
        :return: Anti-transformed data.
        :rtype: `numpy.array_like`
        """
        return np.matmul(U, s[..., None] * Vt)

    def _op_direct(self, x):
        x = np.transpose(x, self.fwd_transpose)
        x = np.reshape(x, self.fwd_shape)
        # The singular vectors are stored: the adjoint re-composes data from them.
        (self.U, s, self.Vt) = self.direct_svd(x)
        if self.rescale:
            # U * s * Vt is unchanged: the three factors below cancel out.
            s /= np.sqrt(self.Vt.shape[-1] * self.U.shape[-2])
            self.Vt *= np.sqrt(self.Vt.shape[-1])
            self.U *= np.sqrt(self.U.shape[-2])
        return s

    def _op_adjoint(self, y):
        if self.U is None or self.Vt is None:
            raise ValueError("Operator not initialized!")
        x = self.inverse_svd(self.U, y, self.Vt)
        x = np.reshape(x, self.bwd_shape)
        return np.transpose(x, self.bwd_transpose)
if __name__ == "__main__":
    # Smoke test: build the explicit matrices of a stationary wavelet operator (and of its transpose)
    # and of a gradient operator, on a small 2D volume.
    test_vol = np.zeros((10, 10), dtype=np.float32)
    vol_shape = test_vol.shape

    H = TransformStationaryWavelet(vol_shape, "db1", 2)
    Htw = H.T.explicit()
    Hw = H.explicit()

    D = TransformGradient(vol_shape)
    Dg = D.explicit()