Source code for corrct.regularizers

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Regularizers module.

@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

import numpy as np
import numpy.random

import scipy.ndimage as spimg

from . import operators
from . import data_terms

from abc import ABC, abstractmethod

def _version_tuple(version: str) -> tuple:
    """Return the leading numeric fields of a version string as a tuple of ints.

    Only the first three numeric fields are kept, e.g. ``"1.0.10"`` gives ``(1, 0, 10)``.

    Parameters
    ----------
    version : str
        The version string to parse.

    Returns
    -------
    tuple
        The (at most three) leading numeric fields, as integers.
    """
    import re  # Local import, so that this helper does not depend on the import order of the file

    return tuple(int(field) for field in re.findall(r"\d+", version)[:3])


try:
    import pywt

    has_pywt = True
    # Compare the version numerically: as plain strings, "1.0.10" would sort before "1.0.2"
    use_swtn = _version_tuple(pywt.version.version) >= (1, 0, 2)  # type: ignore
except ImportError:
    has_pywt = False
    use_swtn = False
    print("WARNING - pywt was not found")


from typing import Union, Sequence, Optional
from numpy.typing import NDArray, DTypeLike

# Type alias for numpy arrays with an integer dtype.
NDArrayInt = NDArray[np.integer]


# ---- Data Fidelity terms ----


# Module-level aliases of the data-fidelity classes defined in `data_terms`.
# The regularizers below use them as norms and in `isinstance` checks.
DataFidelityBase = data_terms.DataFidelityBase
DataFidelity_l2 = data_terms.DataFidelity_l2
DataFidelity_wl2 = data_terms.DataFidelity_wl2
DataFidelity_l2b = data_terms.DataFidelity_l2b
DataFidelity_l12 = data_terms.DataFidelity_l12
DataFidelity_l1 = data_terms.DataFidelity_l1
DataFidelity_l1b = data_terms.DataFidelity_l1b
DataFidelity_Huber = data_terms.DataFidelity_Huber
DataFidelity_KL = data_terms.DataFidelity_KL
DataFidelity_ln = data_terms.DataFidelity_ln


# ---- Regularizers ----


class BaseRegularizer(ABC):
    """
    Define the interface of the Regularizer objects.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    norm : DataFidelityBase
        The norm of the regularizer minimization.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update in `compute_update_primal`. The default is None.
    dtype : DTypeLike, optional
        Data type of the dual created by `initialize_dual`. The default is np.float32.
    """

    __reg_name__ = ""

    weight: NDArray
    dtype: DTypeLike
    op: Union[operators.BaseTransform, None]
    sigma: Union[float, NDArray]
    upd_mask: Optional[NDArray]

    def __init__(
        self,
        weight: Union[float, NDArray],
        norm: data_terms.DataFidelityBase,
        upd_mask: Optional[NDArray] = None,
        dtype: DTypeLike = np.float32,
    ):
        self.norm = norm
        self.weight = np.array(weight)
        self.upd_mask = upd_mask
        self.dtype = dtype
        # The operator is only created by `initialize_sigma_tau`
        self.op = None

    def info(self) -> str:
        """
        Describe the regularizer with its name and largest weight.

        Returns
        -------
        str
            Regularizer info string.
        """
        max_weight = self.weight.max()
        return self.__reg_name__ + "(w:%g" % max_weight + ")"

    def upper(self) -> str:
        """
        Give the regularizer name in upper case.

        Returns
        -------
        str
            Upper case string name of the regularizer.
        """
        reg_name = self.__reg_name__
        return reg_name.upper()

    def lower(self) -> str:
        """
        Give the regularizer name in lower case.

        Returns
        -------
        str
            Lower case string name of the regularizer.
        """
        reg_name = self.__reg_name__
        return reg_name.lower()

    @abstractmethod
    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Set up the internal state, the operator, and sigma, and then return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """

    def initialize_dual(self) -> NDArray:
        """
        Create a zero-filled dual, with the adjoint shape of the operator.

        Raises
        ------
        ValueError
            If the operator has not been created yet.

        Returns
        -------
        NDArray
            Initialized (zero) dual.
        """
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

        dual_shape = self.op.adj_shape
        return np.zeros(dual_shape, dtype=self.dtype)

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Add the sigma-scaled transform of the primal to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If the operator has not been created yet.
        """
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

        dual += self.sigma * self.op(primal)

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator of the norm to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.
        """
        # Only the l1-type norms receive the weight in their proximal
        if not isinstance(self.norm, DataFidelity_l1):
            self.norm.apply_proximal(dual)
            return

        self.norm.apply_proximal(dual, self.weight)

    def compute_update_primal(self, dual: NDArray) -> NDArray:
        """
        Compute the contribution of this regularizer to the update of the primal.

        Parameters
        ----------
        dual : NDArray
            The dual associated to this regularizer.

        Raises
        ------
        ValueError
            If the operator has not been created yet.

        Returns
        -------
        upd : NDArray
            The update to the primal.
        """
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

        upd = self.op.T(dual)

        has_mask = self.upd_mask is not None
        if has_mask:
            upd *= self.upd_mask

        # For l1-type norms, the weight is already used in the proximal
        weight_in_proximal = isinstance(self.norm, DataFidelity_l1)
        if not weight_in_proximal:
            upd *= self.weight

        return upd

    def _check_primal(self, primal: NDArray) -> None:
        # Adopt the dtype of the primal when it differs from the stored one, with a warning
        if self.dtype != primal.dtype:
            message = f"WARNING: Regularizer dtype ({self.dtype}) and primal dtype ({primal.dtype}) are different!"
            print(message)
            self.dtype = primal.dtype
class Regularizer_Grad(BaseRegularizer):
    """Gradient regularizer.

    With l1-norms, it promotes piece-wise constant reconstructions.
    With the l2-norm, it promotes smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last `ndims` axes.
        When given, its length overrides `ndims`. The default is None.
    pad_mode: str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is DataFidelity_l12().
    """

    __reg_name__ = "grad"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l12(),
    ):
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

        if axes is None:
            # Default to the trailing `ndims` axes
            axes = np.arange(-ndims, 0, dtype=int)
        elif len(axes) != ndims:
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)

        self.pad_mode = pad_mode.lower()
        self.axes = axes
        self.ndims = ndims

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the gradient operator, set sigma to 0.5, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)

        self.op = operators.TransformGradient(primal.shape, axes=self.axes, pad_mode=self.pad_mode)

        self.sigma = 0.5
        self.norm.assign_data(None, sigma=self.sigma)

        tau = 2 * self.ndims
        # For l1-type norms, the weight is used in the proximal instead
        weight_in_proximal = isinstance(self.norm, DataFidelity_l1)
        if not weight_in_proximal:
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
class Regularizer_TV1D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 1D.

    It can be used to promote piece-wise constant reconstructions.
    It is a gradient regularizer, whose number of dimensions is fixed to 1.
    """

    __reg_name__ = "TV1D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l12(),
    ):
        super().__init__(
            weight=weight,
            ndims=1,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_TV2D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 2D.

    It can be used to promote piece-wise constant reconstructions.
    It is a gradient regularizer, whose number of dimensions is fixed to 2.
    """

    __reg_name__ = "TV2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l12(),
    ):
        super().__init__(
            weight=weight,
            ndims=2,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_TV3D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 3D.

    It can be used to promote piece-wise constant reconstructions.
    It is a gradient regularizer, whose number of dimensions is fixed to 3.
    """

    __reg_name__ = "TV3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l12(),
    ):
        super().__init__(
            weight=weight,
            ndims=3,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_HubTV2D(Regularizer_Grad):
    """Huber-norm Total Variation (TV) regularizer in 2D.

    It can be used to promote piece-wise constant reconstructions.
    It is a 2D gradient regularizer that uses `DataFidelity_Huber(huber_size, l2_axis=0)` as norm.
    """

    __reg_name__ = "HubTV2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        huber_size: float,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        huber_norm = DataFidelity_Huber(huber_size, l2_axis=0)
        super().__init__(weight=weight, ndims=2, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask, norm=huber_norm)
class Regularizer_HubTV3D(Regularizer_Grad):
    """Huber-norm Total Variation (TV) regularizer in 3D.

    It can be used to promote piece-wise constant reconstructions.
    It is a 3D gradient regularizer that uses `DataFidelity_Huber(huber_size, l2_axis=0)` as norm.
    """

    __reg_name__ = "HubTV3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        huber_size: float,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        huber_norm = DataFidelity_Huber(huber_size, l2_axis=0)
        super().__init__(weight=weight, ndims=3, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask, norm=huber_norm)
class Regularizer_smooth1D(Regularizer_Grad):
    """Smoothness regularizer in 1D.

    It can be used to promote smooth reconstructions.
    It is a 1D gradient regularizer, whose default norm is the l2-norm.
    """

    __reg_name__ = "smooth1D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l2(),
    ):
        super().__init__(
            weight=weight,
            ndims=1,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_smooth2D(Regularizer_Grad):
    """Smoothness regularizer in 2D.

    It can be used to promote smooth reconstructions.
    It is a 2D gradient regularizer, whose default norm is the l2-norm.
    """

    __reg_name__ = "smooth2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l2(),
    ):
        super().__init__(
            weight=weight,
            ndims=2,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_smooth3D(Regularizer_Grad):
    """Smoothness regularizer in 3D.

    It can be used to promote smooth reconstructions.
    It is a 3D gradient regularizer, whose default norm is the l2-norm.
    """

    __reg_name__ = "smooth3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l2(),
    ):
        super().__init__(
            weight=weight,
            ndims=3,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
            norm=norm,
        )
class Regularizer_lap(BaseRegularizer):
    """Laplacian regularizer.

    It can be used to promote smooth reconstructions. Its norm is always `DataFidelity_l1()`.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the Laplacian. If None, it uses the last `ndims` axes.
        When given, its length overrides `ndims`. The default is None.
    pad_mode: str, optional
        The padding mode passed to the Laplacian operator. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    """

    __reg_name__ = "lap"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, norm=DataFidelity_l1(), upd_mask=upd_mask)

        if axes is None:
            # Default to the trailing `ndims` axes
            axes = np.arange(-ndims, 0, dtype=int)
        elif len(axes) != ndims:
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)

        self.pad_mode = pad_mode.lower()
        self.axes = axes
        self.ndims = ndims

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the Laplacian operator, set sigma to 0.25, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)

        self.op = operators.TransformLaplacian(primal.shape, axes=self.axes, pad_mode=self.pad_mode)

        self.sigma = 0.25
        self.norm.assign_data(None, sigma=self.sigma)

        tau = 4 * self.ndims
        if self.upd_mask is None:
            return tau
        return tau * self.upd_mask
class Regularizer_lap1D(Regularizer_lap):
    """Laplacian regularizer in 1D.

    It can be used to promote smooth reconstructions.
    It is a Laplacian regularizer, whose number of dimensions is fixed to 1.
    """

    __reg_name__ = "lap1D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        Regularizer_lap.__init__(
            self,
            weight=weight,
            ndims=1,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
        )
class Regularizer_lap2D(Regularizer_lap):
    """Laplacian regularizer in 2D.

    It can be used to promote smooth reconstructions.
    It is a Laplacian regularizer, whose number of dimensions is fixed to 2.
    """

    __reg_name__ = "lap2D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        Regularizer_lap.__init__(
            self,
            weight=weight,
            ndims=2,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
        )
class Regularizer_lap3D(Regularizer_lap):
    """Laplacian regularizer in 3D.

    It can be used to promote smooth reconstructions.
    It is a Laplacian regularizer, whose number of dimensions is fixed to 3.
    """

    __reg_name__ = "lap3D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        Regularizer_lap.__init__(
            self,
            weight=weight,
            ndims=3,
            axes=axes,
            pad_mode=pad_mode,
            upd_mask=upd_mask,
        )
class Regularizer_l1(BaseRegularizer):
    """l1-norm regularizer.

    It can be used to promote sparse reconstructions. It applies the norm directly to the
    primal, through an identity operator.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is DataFidelity_l1().
    """

    __reg_name__ = "l1"

    def __init__(
        self,
        weight: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l1(),
    ):
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the identity operator, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)

        self.op = operators.TransformIdentity(primal.shape)
        self.norm.assign_data(None, sigma=1)

        tau = 1.0
        if self.upd_mask is None:
            return tau
        return tau * self.upd_mask

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Add the primal to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.
        """
        # The operator is the identity, and no sigma scaling is applied here
        dual += primal
class Regularizer_swl(BaseRegularizer):
    """Base stationary wavelet regularizer.

    It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet, passed to the stationary wavelet transform.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the transform. If None, it uses the last `ndims` axes.
        When given, its length overrides `ndims`. The default is None.
    pad_on_demand : str, optional
        The padding mode passed to the transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    normalized : bool, optional
        Whether to use the normalized transform. When True, sigma is 1. The default is False.
    min_approx : bool, optional
        When False, the first entry of the dual is zeroed after each dual update. The default is True.
    norm : Optional[DataFidelityBase], optional
        The norm of the regularizer minimization. If None, a new `DataFidelity_l1()` is created
        for this instance. The default is None.

    Raises
    ------
    ValueError
        If pywavelets is not installed, or it is older than version 1.0.2.
    """

    __reg_name__ = "swl"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string.
        """
        return self.__reg_name__ + "(t:" + self.wavelet + "-l:%d" % self.level + "-w:%g" % self.weight.max() + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        norm: Optional[DataFidelityBase] = None,
    ):
        if not has_pywt:
            raise ValueError("Cannot use wavelet regularizer because pywavelets is not installed.")
        if not use_swtn:
            raise ValueError("Cannot use stationary wavelet regularizer because pywavelets is too old (<1.0.2).")

        if norm is None:
            # A default built in the signature would be shared by all the instances, while
            # `initialize_sigma_tau` assigns an instance-specific sigma to the norm
            norm = DataFidelity_l1()

        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.wavelet = wavelet
        self.level = level
        self.normalized = normalized
        self.min_approx = min_approx

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.pad_on_demand = pad_on_demand

    def _compute_func_mult(self, filter_mult: NDArray, lo_mult: float) -> NDArray:
        # One multiplier for the approximation, followed by the per-level detail multipliers
        level_scaling = lo_mult ** np.arange(self.level - 1, -1, -1)
        func_mult = filter_mult[None, :] * level_scaling[:, None]
        return np.concatenate(([lo_mult**self.level], func_mult.flatten()))

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the stationary wavelet operator, compute sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)
        self.op = operators.TransformStationaryWavelet(
            primal.shape,
            wavelet=self.wavelet,
            level=self.level,
            axes=self.axes,
            pad_on_demand=self.pad_on_demand,
            normalized=self.normalized,
        )

        # Entries 0 and 2 of the filter bank are used as low-pass decomposition and reconstruction
        # filters, respectively
        filt_bank_l1norm = np.linalg.norm(self.op.w.filter_bank, ord=1, axis=-1)
        lo_dec_mult = filt_bank_l1norm[0] ** self.ndims
        lo_rec_mult = filt_bank_l1norm[2] ** self.ndims

        self.dec_func_mult = self._compute_func_mult(self.op.wlet_dec_filter_mult, lo_dec_mult)
        self.rec_func_mult = self._compute_func_mult(self.op.wlet_rec_filter_mult, lo_rec_mult)

        if self.normalized:
            self.sigma = 1
            self.norm.assign_data(None, sigma=self.sigma)

            tau = self.dec_func_mult.size
        else:
            # One sigma per sub-band, broadcast over the dimensions of the primal
            self.sigma = np.reshape(1 / self.dec_func_mult, [-1] + [1] * len(self.op.dir_shape))
            self.norm.assign_data(None, sigma=self.sigma)

            tau = np.ones_like(self.rec_func_mult) * ((2**self.ndims) - 1)
            tau[0] += 1
            tau = np.sum(tau / self.rec_func_mult)

        if not isinstance(self.norm, DataFidelity_l1):
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Add the (sigma-scaled, when not normalized) wavelet transform of the primal to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If the operator has not been created yet.
        """
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

        upd = self.op(primal)
        if not self.normalized:
            upd *= self.sigma
        dual += upd
        if not self.min_approx:
            # Discard the first entry of the dual (presumably the approximation coefficients)
            dual[0, ...] = 0

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator of the norm to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.
        """
        if isinstance(self.norm, DataFidelity_l12):
            # The first entry of the dual is treated separately from the others.
            # NOTE(review): this relies on the slices and the reshape being views of `dual`,
            # so that the in-place proximal modifies `dual` itself - confirm.
            tmp_dual = dual[1:]
            tmp_dual = tmp_dual.reshape([-1, self.level, *dual.shape[1:]])
            self.norm.apply_proximal(tmp_dual, self.weight)
            tmp_dual = dual[0:1:]
            self.norm.apply_proximal(tmp_dual, self.weight)
        else:
            super().apply_proximal(dual)
class Regularizer_l1swl(Regularizer_swl):
    """l1-norm Wavelet regularizer.

    It can be used to promote sparse reconstructions in the wavelet domain.
    It is a stationary wavelet regularizer that always uses a new `DataFidelity_l1()` as norm.
    """

    __reg_name__ = "l1swl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
    ):
        l1_norm = DataFidelity_l1()
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=l1_norm,
        )
class Regularizer_l12swl(Regularizer_swl):
    """l12-norm Wavelet regularizer.

    It can be used to promote sparse reconstructions in the wavelet domain.
    It is a stationary wavelet regularizer that always uses a new `DataFidelity_l12()` as norm.
    """

    __reg_name__ = "l12swl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
    ):
        l12_norm = DataFidelity_l12()
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=l12_norm,
        )
class Regularizer_Hub_swl(Regularizer_swl):
    """Huber-norm Wavelet regularizer.

    It can be used to promote sparse reconstructions in the wavelet domain.
    It is a stationary wavelet regularizer that always uses `DataFidelity_Huber(huber_size)` as norm.
    """

    __reg_name__ = "Hubswl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        huber_size: Optional[int] = None,
    ):
        huber_norm = DataFidelity_Huber(huber_size)
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=huber_norm,
        )
class Regularizer_dwl(BaseRegularizer):
    """Base decimated wavelet regularizer.

    It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet, passed to the decimated wavelet transform.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the transform. If None, it uses the last `ndims` axes.
        When given, its length overrides `ndims`. The default is None.
    pad_on_demand : str, optional
        The padding mode passed to the transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    min_approx : bool, optional
        When False, the first sub-band of the dual is zeroed after each dual update. The default is True.
    norm : Optional[DataFidelityBase], optional
        The norm of the regularizer minimization. If None, a new `DataFidelity_l1()` is created
        for this instance. The default is None.

    Raises
    ------
    ValueError
        If pywavelets is not installed.
    """

    __reg_name__ = "dwl"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string.
        """
        return self.__reg_name__ + "(t:" + self.wavelet + "-l:%d" % self.level + "-w:%g" % self.weight.max() + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        min_approx: bool = True,
        norm: Optional[DataFidelityBase] = None,
    ):
        if not has_pywt:
            raise ValueError("Cannot use wavelet regularizer because pywavelets is not installed.")

        if norm is None:
            # A default built in the signature would be shared by all the instances, while
            # `initialize_sigma_tau` assigns an instance-specific sigma to the norm
            norm = DataFidelity_l1()

        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.wavelet = wavelet
        self.level = level
        self.min_approx = min_approx

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.pad_on_demand = pad_on_demand

    def _compute_func_mult(self, filter_mult: NDArray, lo_mult: float) -> NDArray:
        # One multiplier for the approximation, followed by the per-level detail multipliers
        level_scaling = lo_mult ** np.arange(self.level - 1, -1, -1)
        func_mult = filter_mult[None, :] * level_scaling[:, None]
        return np.concatenate(([lo_mult**self.level], func_mult.flatten()))

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the decimated wavelet operator, compute sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)
        self.op = operators.TransformDecimatedWavelet(
            primal.shape, wavelet=self.wavelet, level=self.level, axes=self.axes, pad_on_demand=self.pad_on_demand
        )

        # Entries 0 and 2 of the filter bank are used as low-pass decomposition and reconstruction
        # filters, respectively
        filt_bank_l1norm = np.linalg.norm(self.op.w.filter_bank, ord=1, axis=-1)
        lo_dec_mult = filt_bank_l1norm[0] ** self.ndims
        lo_rec_mult = filt_bank_l1norm[2] ** self.ndims

        self.dec_func_mult = self._compute_func_mult(self.op.wlet_dec_filter_mult, lo_dec_mult)
        self.rec_func_mult = self._compute_func_mult(self.op.wlet_rec_filter_mult, lo_rec_mult)

        # Build sigma with the same coefficients layout as the dual: one constant block per sub-band.
        # NOTE(review): `count` starts at 0, so the first detail band re-uses the multiplier of the
        # approximation, and the last multiplier is never used. It looks like a possible off-by-one:
        # confirm the intended indexing.
        tmp_sigma = [np.ones(self.op.sub_band_shapes[0], self.dtype) * self.dec_func_mult[0]]
        count = 0
        for ii_l in range(self.level):
            d = {}
            for label, band_shape in self.op.sub_band_shapes[ii_l + 1].items():
                d[label] = np.ones(band_shape, self.dtype) * self.dec_func_mult[count]
                count += 1
            tmp_sigma.append(d)
        self.sigma, _ = pywt.coeffs_to_array(tmp_sigma, axes=self.axes)
        self.norm.assign_data(None, sigma=self.sigma)

        tau = np.ones_like(self.rec_func_mult) * ((2**self.ndims) - 1)
        tau[0] += 1
        tau = np.sum(tau / self.rec_func_mult)

        if not isinstance(self.norm, DataFidelity_l1):
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, and zero its first sub-band when `min_approx` is False.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If the operator has not been created yet.
        """
        super().update_dual(dual, primal)
        if not self.min_approx:
            if self.op is None:
                raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

            op_wl: operators.TransformDecimatedWavelet = self.op  # type: ignore
            # The first sub-band occupies the leading corner of the coefficients array
            slices = [slice(0, x) for x in op_wl.sub_band_shapes[0]]
            dual[tuple(slices)] = 0

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator of the norm to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.

        Raises
        ------
        ValueError
            If the norm is a `DataFidelity_l12` and the operator has not been created yet.
        """
        if isinstance(self.norm, DataFidelity_l12):
            if self.op is None:
                raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")

            op_wl: operators.TransformDecimatedWavelet = self.op  # type: ignore
            coeffs = pywt.array_to_coeffs(dual, op_wl.slicing_info)

            # The detail bands of each level are stacked, and passed together to the proximal
            for ii_l in range(1, len(coeffs)):
                c_l = coeffs[ii_l]
                labels = list(c_l.keys())
                details = [c_l[lab] for lab in labels]
                c_ll = np.stack(details, axis=0)
                self.norm.apply_proximal(c_ll, self.weight)
                for ii, lab in enumerate(labels):
                    c_l[lab] = c_ll[ii]
                coeffs[ii_l] = c_l
            self.norm.apply_proximal(coeffs[0], self.weight)

            # NOTE(review): `initialize_sigma_tau` passes `axes=self.axes` to `coeffs_to_array`, while
            # this call does not - confirm whether that is intended.
            dual[:] = pywt.coeffs_to_array(coeffs)[0]
        else:
            super().apply_proximal(dual)
class Regularizer_l1dwl(Regularizer_dwl):
    """l1-norm decimated wavelet regularizer.

    It can be used to promote sparse reconstructions.
    It is a decimated wavelet regularizer that always uses a new `DataFidelity_l1()` as norm.
    """

    __reg_name__ = "l1dwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
    ):
        l1_norm = DataFidelity_l1()
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=l1_norm,
        )
class Regularizer_l12dwl(Regularizer_dwl):
    """l12-norm decimated wavelet regularizer.

    It can be used to promote sparse reconstructions.
    It is a decimated wavelet regularizer that always uses a new `DataFidelity_l12()` as norm.
    """

    __reg_name__ = "l12dwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
    ):
        l12_norm = DataFidelity_l12()
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=l12_norm,
        )
class Regularizer_Hub_dwl(Regularizer_dwl):
    """Huber-norm decimated wavelet regularizer.

    It can be used to promote sparse reconstructions.
    It is a decimated wavelet regularizer that always uses `DataFidelity_Huber(huber_size)` as norm.
    """

    __reg_name__ = "Hubdwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        huber_size: Optional[int] = None,
    ):
        huber_norm = DataFidelity_Huber(huber_size)
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=huber_norm,
        )
class BaseRegularizer_med(BaseRegularizer):
    """Median filter regularizer base class.

    It can be used to promote filtered reconstructions. The dual is updated with the
    difference between the primal and its median-filtered version.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size passed to the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is DataFidelity_l1().
    """

    __reg_name__ = "med"

    def info(self) -> str:
        """
        Describe the regularizer with its name, filter size, and largest weight.

        Returns
        -------
        str
            Regularizer info string.
        """
        size_txt = "(s:%s" % np.array(self.filt_size)
        weight_txt = "-w:%g" % self.weight.max()
        return self.__reg_name__ + size_txt + weight_txt + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        filt_size: int = 3,
        upd_mask: Optional[NDArray] = None,
        norm: DataFidelityBase = DataFidelity_l1(),
    ):
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.filt_size = filt_size

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the identity operator, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)

        self.op = operators.TransformIdentity(primal.shape)
        self.norm.assign_data(None, sigma=1)

        # For l1-type norms, the weight is used in the proximal instead
        tau = 1.0 if isinstance(self.norm, DataFidelity_l1) else self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Add the difference between the primal and its median-filtered version to the dual, in-place.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.
        """
        filtered_primal = spimg.median_filter(primal, self.filt_size)
        dual += primal - filtered_primal
class Regularizer_l1med(BaseRegularizer_med):
    """l1-norm median filter regularizer.

    It can be used to promote filtered reconstructions.
    It is a median filter regularizer that always uses a new `DataFidelity_l1()` as norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size passed to the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    """

    __reg_name__ = "l1med"

    def __init__(self, weight: Union[float, NDArray], filt_size: int = 3, upd_mask: Optional[NDArray] = None):
        # The update mask is forwarded, like in the other regularizers of this module
        BaseRegularizer_med.__init__(self, weight, filt_size=filt_size, upd_mask=upd_mask, norm=DataFidelity_l1())
class Regularizer_l2med(BaseRegularizer_med):
    """l2-norm median filter regularizer.

    It can be used to promote filtered reconstructions.
    It is a median filter regularizer that always uses a new `DataFidelity_l2()` as norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size passed to the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    """

    __reg_name__ = "l2med"

    def __init__(self, weight: Union[float, NDArray], filt_size: int = 3, upd_mask: Optional[NDArray] = None):
        # The update mask is forwarded, like in the other regularizers of this module
        BaseRegularizer_med.__init__(self, weight, filt_size=filt_size, upd_mask=upd_mask, norm=DataFidelity_l2())
class Regularizer_fft(BaseRegularizer):
    """Fourier regularizer.

    It can be used to promote sparse reconstructions in the Fourier domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the transform. If None, it uses the last `ndims` axes.
        When given, its length overrides `ndims`. The default is None.
    fft_filter : str, optional
        The frequency mask used as sigma. The options are: "delta", "exp", and "exp2".
        If it is not a string, sigma is set to 1. The default is "exp".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    norm : Optional[DataFidelityBase], optional
        The norm of the regularizer minimization. If None, a new `DataFidelity_l12()` is created
        for this instance. The default is None.
    """

    __reg_name__ = "fft"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        fft_filter: str = "exp",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[DataFidelityBase] = None,
    ):
        if norm is None:
            # A default built in the signature would be shared by all the instances, while
            # `initialize_sigma_tau` assigns an instance-specific sigma to the norm
            norm = DataFidelity_l12()

        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.fft_filter = fft_filter

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Create the Fourier operator, compute the frequency mask used as sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Raises
        ------
        ValueError
            If `fft_filter` is a string, but not one of the known filters.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        self._check_primal(primal)
        self.op = operators.TransformFourier(primal.shape, axes=self.axes)

        if isinstance(self.fft_filter, str):
            # Index with integer arrays: a tuple of axes would be read as a multi-dimensional index
            axes = np.array(self.axes, ndmin=1, dtype=int)
            adj_shape = np.array(self.op.adj_shape)

            coords = [np.fft.fftfreq(s) for s in adj_shape[axes]]
            coords = np.array(np.meshgrid(*coords, indexing="ij"))

            filter_name = self.fft_filter.lower()
            if filter_name == "delta":
                # Only the zero frequency is masked out
                self.sigma = 1 - np.all(coords == 0, axis=0)
            elif filter_name == "exp":
                self.sigma = 1 - np.exp(-np.sqrt(np.sum(coords**2, axis=0)) * 12)
            elif filter_name == "exp2":
                self.sigma = 1 - np.exp(-np.sum(coords**2, axis=0) * 36)
            else:
                raise ValueError('Unknown FFT mask: %s. Options are: "delta", "exp". and "exp2".' % self.fft_filter)

            # Singleton dimensions on the other axes, to broadcast against the dual
            new_shape = np.ones_like(adj_shape)
            new_shape[axes] = adj_shape[axes]
            self.sigma = np.reshape(self.sigma, new_shape)
        else:
            self.sigma = 1
        self.norm.assign_data(None, sigma=self.sigma)

        if not isinstance(self.norm, DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
# Multi-channel regularizers
class Regularizer_TNV(Regularizer_Grad):
    """Total Nuclear Variation (TNV) regularizer.

    It can be used to promote piece-wise constant reconstructions, for multi-channel volumes.
    It is a gradient regularizer, whose norm is replaced by a `DataFidelity_ln` norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the gradient. The default is None.
    pad_mode: str, optional
        The padding mode of the gradient operator. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    spectral_norm : DataFidelityBase, optional
        The norm passed to `DataFidelity_ln` as spectral norm. The default is DataFidelity_l1().
    x_ref : Optional[NDArray], optional
        Reference volume, whose gradient is stored in `q_ref` on initialization. The default is None.
    """

    __reg_name__ = "TNV"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        spectral_norm: DataFidelityBase = DataFidelity_l1(),
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=ndims, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask)

        # Here we assume that the channels will be the rows and the derivatives the columns
        nuclear_norm = DataFidelity_ln(ln_axes=(1, 0), spectral_norm=spectral_norm)
        self.norm = nuclear_norm
        self.x_ref = x_ref

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the gradient regularizer, and compute the gradient of the reference, if any.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        tau = super().initialize_sigma_tau(primal)

        if self.x_ref is None:
            return tau

        if self.op is None:
            raise ValueError("Regularizer should have been initialized... this is a bug!")

        # The reference gradient gets an extra axis in position 1
        self.q_ref = self.op(self.x_ref)
        self.q_ref = np.expand_dims(self.q_ref, axis=1)
        return tau
class Regularizer_VTV(Regularizer_Grad):
    """Vectorial Total Variation (VTV) regularizer.

    It can be used to promote piece-wise constant reconstructions, for multi-channel volumes.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    pwise_der_norm : Union[int, float], optional
        The point-wise norm exponent over the derivatives. It can be: 1, 2, or Inf. The default is 2.
    pwise_chan_norm : Union[int, float], optional
        The point-wise norm exponent over the channels. It can be: 1, 2, or Inf. The default is Inf.
    x_ref : Optional[NDArray], optional
        Reference volume. Its dual is appended as an additional channel in the proximal. The default is None.
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.

    Raises
    ------
    ValueError
        If one of the point-wise norm exponents is not supported.
    """

    __reg_name__ = "VTV"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        pwise_der_norm: Union[int, float] = 2,
        pwise_chan_norm: Union[int, float] = np.inf,
        x_ref: Optional[NDArray] = None,
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=ndims, upd_mask=upd_mask)

        # Both exponents are stored before the validation, because the error message reports both.
        # `np.inf` is used, because the `np.Inf` alias was removed in NumPy 2.0.
        self.pwise_der_norm = pwise_der_norm
        self.pwise_chan_norm = pwise_chan_norm
        if self.pwise_der_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()
        if self.pwise_chan_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()

        if x_ref is not None:
            # The dual of the reference is computed once, with the regular dual update
            self.initialize_sigma_tau(x_ref)
            q_ref = self.initialize_dual()
            self.update_dual(q_ref, x_ref)
            self.q_ref = np.expand_dims(q_ref, axis=1)
        else:
            self.q_ref = None

    def _raise_pwise_norm_error(self):
        raise ValueError(
            "The only supported point-wise norm exponents are: 1, 2, and Inf."
            + f" Provided the following instead: derivatives={self.pwise_der_norm}, channel={self.pwise_chan_norm}"
        )

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator to the dual, in-place.

        The dual is divided by the largest between its point-wise norm and the weight, and then
        multiplied by the weight.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.

        Raises
        ------
        ValueError
            If one of the point-wise norm exponents is not supported.
        """
        # Following assignments will detach the local array from the original one
        dual_tmp = dual.copy()

        # A single-channel dual only has the derivatives axis in front of the `ndims` axes
        dual_is_scalar = len(dual_tmp.shape) == (self.ndims + 1)
        if dual_is_scalar:
            dual_tmp = np.expand_dims(dual_tmp, axis=1)

        if self.q_ref is not None:
            dual_tmp = np.concatenate((dual_tmp, self.q_ref), axis=1)

        # Norm over the derivatives (axis 0)
        if self.pwise_der_norm == 1:
            grad_norm = np.abs(dual_tmp)
        elif self.pwise_der_norm == 2:
            grad_norm = np.linalg.norm(dual_tmp, axis=0, ord=2, keepdims=True)
        elif self.pwise_der_norm == np.inf:
            grad_norm = np.linalg.norm(dual_tmp, axis=0, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        # Norm over the channels (axis 1)
        if self.pwise_chan_norm == 1:
            dual_norm = grad_norm
        elif self.pwise_chan_norm == 2:
            dual_norm = np.linalg.norm(grad_norm, axis=1, ord=2, keepdims=True)
        elif self.pwise_chan_norm == np.inf:
            dual_norm = np.linalg.norm(grad_norm, axis=1, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        dual_tmp /= np.fmax(dual_norm, self.weight)
        dual_tmp *= self.weight

        if self.q_ref is not None:
            # Remove the reference channel, which was appended at the end of axis 1
            dual_tmp = dual_tmp[:, : dual_tmp.shape[1] - 1 :, ...]

        if dual_is_scalar:
            dual_tmp = np.squeeze(dual_tmp, axis=1)

        dual[:] = dual_tmp[:]  # Replacing values
class Regularizer_lnswl(Regularizer_l1swl):
    """Nuclear-norm Wavelet regularizer.

    It can be used to promote compressed multi-channel reconstructions.
    It is a stationary wavelet regularizer, whose norm is replaced by a `DataFidelity_ln` norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet, passed to the stationary wavelet transform.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the transform. The default is None.
    pad_on_demand : str, optional
        The padding mode passed to the transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask that multiplies the primal update and the tau. The default is None.
    normalized : bool, optional
        Whether to use the normalized transform. The default is False.
    min_approx : bool, optional
        When False, the first entry of the dual is zeroed after each dual update. The default is True.
    spectral_norm : DataFidelityBase, optional
        The norm passed to `DataFidelity_ln` as spectral norm. The default is DataFidelity_l1().
    x_ref : Optional[NDArray], optional
        Reference volume, whose transform is stored in `q_ref` on initialization. The default is None.
    """

    __reg_name__ = "lnswl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        spectral_norm: DataFidelityBase = DataFidelity_l1(),
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
        )

        nuclear_norm = DataFidelity_ln(ln_axes=(1, 0), spectral_norm=spectral_norm)
        self.norm = nuclear_norm
        self.x_ref = x_ref

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the wavelet regularizer, and compute the transform of the reference, if any.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """
        tau = super().initialize_sigma_tau(primal)

        if self.x_ref is None:
            return tau

        if self.op is None:
            raise ValueError("Regularizer should have been initialized... this is a bug!")

        # The reference transform gets an extra axis in position 1
        self.q_ref = self.op(self.x_ref)
        self.q_ref = np.expand_dims(self.q_ref, axis=1)
        return tau
class Regularizer_vl1wl(Regularizer_l1swl):
    """l1-norm vectorial Wavelet regularizer. It can be used to promote compressed reconstructions.

    Parameters
    ----------
    weight, wavelet, level, ndims, axes, pad_on_demand, upd_mask, normalized, min_approx
        Forwarded unchanged to the parent wavelet regularizer.
    pwise_lvl_norm : Union[int, float]
        Exponent of the point-wise norm computed along axis 0 of the dual.
        Must be one of: 1, 2, Inf.
    pwise_chan_norm : Union[int, float]
        Exponent of the point-wise norm computed along axis 1 of the dual.
        Must be one of: 1, 2, Inf.
    x_ref : Optional[NDArray]
        Optional reference; when given, its transform is stored in `q_ref`
        by `initialize_sigma_tau`, and joins the dual in the norm computation.

    Raises
    ------
    ValueError
        If one of the two norm exponents is not supported.
    """

    __reg_name__ = "vl1wl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        pwise_lvl_norm: Union[int, float] = 1,
        pwise_chan_norm: Union[int, float] = np.inf,
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
        )

        # Both exponents are stored before validating either of them: the
        # error message reads the two attributes.
        self.pwise_lvl_norm = pwise_lvl_norm
        self.pwise_chan_norm = pwise_chan_norm
        if self.pwise_lvl_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()
        if self.pwise_chan_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()

        self.x_ref = x_ref
        self.q_ref = None

    def _raise_pwise_norm_error(self):
        """Raise the error signalling an unsupported point-wise norm exponent."""
        raise ValueError(
            "The only supported point-wise norm exponents are: 1, 2, and Inf."
            + f" Provided the following instead: level={self.pwise_lvl_norm}, channel={self.pwise_chan_norm}"
        )

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """Initialize the regularizer, and transform the reference (if any).

        Parameters
        ----------
        primal : NDArray
            The primal variable, forwarded to the parent implementation.

        Returns
        -------
        Union[float, NDArray]
            The value returned by the parent implementation.

        Raises
        ------
        ValueError
            If a reference was given, but the operator is still unset after
            the parent initialization.
        """
        tau = super().initialize_sigma_tau(primal)

        if self.x_ref is not None:
            if self.op is None:
                raise ValueError("Regularizer should have been initialized... this is a bug!")

            self.q_ref = self.op(self.x_ref)

        return tau

    def apply_proximal(self, dual: NDArray) -> None:
        """Apply the proximal step to the dual variable, in place.

        The dual is rescaled point-wise so that its combined norm (first
        along axis 0, then along axis 1) does not exceed `self.weight`.

        Parameters
        ----------
        dual : NDArray
            The dual variable. It is modified in place.

        Raises
        ------
        ValueError
            If one of the two norm exponents is not supported.
        """
        # Work on a copy: the re-assignments below detach the local array
        # from the caller's one, and the result is copied back at the end.
        dual_tmp = dual.copy()

        # Number of entries along axis 1 that belong to the dual itself:
        # whatever the reference appends beyond them is stripped afterwards.
        num_dual_chans = dual_tmp.shape[1]
        if self.q_ref is not None:
            dual_tmp = np.concatenate((dual_tmp, self.q_ref), axis=1)

        # NOTE(review): exponent Inf maps to `ord=1` (and 1 to a plain
        # absolute value) - this looks like the dual norm; confirm.
        if self.pwise_lvl_norm == 1:
            lvl_norm = np.abs(dual_tmp)
        elif self.pwise_lvl_norm == 2:
            lvl_norm = np.linalg.norm(dual_tmp, axis=0, ord=2, keepdims=True)
        elif self.pwise_lvl_norm == np.inf:
            lvl_norm = np.linalg.norm(dual_tmp, axis=0, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        if self.pwise_chan_norm == 1:
            dual_norm = lvl_norm
        elif self.pwise_chan_norm == 2:
            dual_norm = np.linalg.norm(lvl_norm, axis=1, ord=2, keepdims=True)
        elif self.pwise_chan_norm == np.inf:
            dual_norm = np.linalg.norm(lvl_norm, axis=1, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        dual_tmp /= np.fmax(dual_norm, self.weight)
        dual_tmp *= self.weight

        if self.q_ref is not None:
            # The bound has to come from axis 1 (the concatenation axis).
            dual_tmp = dual_tmp[:, :num_dual_chans, ...]

        dual[:] = dual_tmp[:]
class Regularizer_vSVD(BaseRegularizer):
    """Regularizer based on the Singular Value Decomposition.

    It can be used to promote similar reconstructions across different channels.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int
        Number of dimensions. It is overridden by `len(axes)` when they differ.
    axes : Union[Sequence[int], NDArray, None]
        Axes passed to the SVD transform as `axes_cols`. When `None`, the last
        `ndims` axes are used.
    axis_channels : Sequence[int]
        Axes passed to the SVD transform as `axes_rows`.
    upd_mask : Optional[NDArray]
        Optional mask that multiplies the returned `tau`.
    norm : Optional[DataFidelityBase]
        The norm of the regularizer minimization. When `None` (the default),
        a new `DataFidelity_l1` is created for this instance.
    """

    __reg_name__ = "vsvd"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        axis_channels: Sequence[int] = (0,),
        upd_mask: Optional[NDArray] = None,
        norm: Optional[DataFidelityBase] = None,
    ):
        # `initialize_sigma_tau` stores data in the norm: a default built in
        # the signature would be shared (and overwritten) across instances.
        if norm is None:
            norm = DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif ndims != len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes
        self.axis_channels = axis_channels

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """Create the SVD operator for the given primal, and compute `tau`.

        Parameters
        ----------
        primal : NDArray
            The primal variable, whose shape defines the operator.

        Returns
        -------
        Union[float, NDArray]
            The regularizer weight when the norm is not a `DataFidelity_l1`,
            1.0 otherwise; multiplied by `upd_mask` when that is set.
        """
        self._check_primal(primal)

        self.op = operators.TransformSVD(primal.shape, axes_rows=self.axis_channels, axes_cols=self.axes, rescale=True)

        self.sigma = 1
        self.norm.assign_data(None, sigma=self.sigma)

        if not isinstance(self.norm, DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0

        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
# ---- Constraints ----
class Constraint_LowerLimit(BaseRegularizer):
    """Lower limit constraint. It can be used to promote reconstructions in certain regions of solution space.

    Parameters
    ----------
    limit : Union[float, NDArray]
        The limit value(s). An array needs to broadcast against the dual.
    upd_mask : Optional[NDArray]
        Optional mask that multiplies the returned `tau`.
    norm : Optional[DataFidelityBase]
        The norm of the constraint. When `None` (the default), a new
        `DataFidelity_l2` is created for this instance.
    """

    __reg_name__ = "lowlim"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string. Array limits are summarized by their
            minimum and maximum.
        """
        limit = np.asarray(self.limit)
        if limit.size == 1:
            limit_str = "%g" % limit.item()
        else:
            # "%g" cannot format an array with more than one element.
            limit_str = "%g..%g" % (limit.min(), limit.max())
        return self.__reg_name__ + "(l:" + limit_str + ")"

    def __init__(
        self,
        limit: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: Optional[DataFidelityBase] = None,
    ):
        # `initialize_sigma_tau` stores the limit in the norm: a default built
        # in the signature would be one object shared by every constraint,
        # and each of them would overwrite the data of the others.
        if norm is None:
            norm = DataFidelity_l2()
        super().__init__(weight=1, norm=norm, upd_mask=upd_mask)
        self.limit = limit

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """Create the identity operator for the given primal, and compute `tau`.

        Parameters
        ----------
        primal : NDArray
            The primal variable, whose shape defines the operator.

        Returns
        -------
        Union[float, NDArray]
            The regularizer weight when the norm is not a `DataFidelity_l1`,
            1.0 otherwise; multiplied by `upd_mask` when that is set.
        """
        self._check_primal(primal)

        self.op = operators.TransformIdentity(primal.shape)

        self.norm.assign_data(self.limit, sigma=1)

        if not isinstance(self.norm, DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0

        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """Add the primal to the dual, in place."""
        dual += primal

    def apply_proximal(self, dual: NDArray) -> None:
        """Cap the dual at the limit, then apply the proximal of the norm.

        Parameters
        ----------
        dual : NDArray
            The dual variable. It is modified in place.
        """
        # `copyto` broadcasts the limit, so that both scalar and array limits
        # are supported; only the entries above the limit are overwritten.
        np.copyto(dual, self.limit, where=dual > self.limit)
        self.norm.apply_proximal(dual)
class Constraint_UpperLimit(BaseRegularizer):
    """Upper limit constraint. It can be used to promote reconstructions in certain regions of solution space.

    Parameters
    ----------
    limit : Union[float, NDArray]
        The limit value(s). An array needs to broadcast against the dual.
    upd_mask : Optional[NDArray]
        Optional mask that multiplies the returned `tau`.
    norm : Optional[DataFidelityBase]
        The norm of the constraint. When `None` (the default), a new
        `DataFidelity_l2` is created for this instance.
    """

    __reg_name__ = "uplim"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string. Array limits are summarized by their
            minimum and maximum.
        """
        limit = np.asarray(self.limit)
        if limit.size == 1:
            limit_str = "%g" % limit.item()
        else:
            # "%g" cannot format an array with more than one element.
            limit_str = "%g..%g" % (limit.min(), limit.max())
        return self.__reg_name__ + "(l:" + limit_str + ")"

    def __init__(
        self,
        limit: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: Optional[DataFidelityBase] = None,
    ):
        # `initialize_sigma_tau` stores the limit in the norm: a default built
        # in the signature would be one object shared by every constraint,
        # and each of them would overwrite the data of the others.
        if norm is None:
            norm = DataFidelity_l2()
        super().__init__(weight=1, norm=norm, upd_mask=upd_mask)
        self.limit = limit

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """Create the identity operator for the given primal, and compute `tau`.

        Parameters
        ----------
        primal : NDArray
            The primal variable, whose shape defines the operator.

        Returns
        -------
        Union[float, NDArray]
            The regularizer weight when the norm is not a `DataFidelity_l1`,
            1.0 otherwise; multiplied by `upd_mask` when that is set.
        """
        self._check_primal(primal)

        self.op = operators.TransformIdentity(primal.shape)

        self.norm.assign_data(self.limit, sigma=1)

        if not isinstance(self.norm, DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0

        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """Add the primal to the dual, in place."""
        dual += primal

    def apply_proximal(self, dual: NDArray) -> None:
        """Raise the dual to the limit where below it, then apply the proximal of the norm.

        Parameters
        ----------
        dual : NDArray
            The dual variable. It is modified in place.
        """
        # `copyto` broadcasts the limit, so that both scalar and array limits
        # are supported; only the entries below the limit are overwritten.
        np.copyto(dual, self.limit, where=dual < self.limit)
        self.norm.apply_proximal(dual)