#!/usr/bin/env python3
"""
Regularizers module.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from abc import ABC, abstractmethod
from typing import Optional, Union
from collections.abc import Sequence
import numpy as np
import scipy.ndimage as spimg
from numpy.typing import DTypeLike, NDArray
from . import data_terms as dt
from . import operators
def _parse_version(version: str) -> tuple:
    """Return the leading numeric components of a version string, as a tuple of ints.

    Parsing stops at the first component that is not purely numeric, e.g.:
    "1.4.1" -> (1, 4, 1), "1.0.10" -> (1, 0, 10), "1.2rc1" -> (1, 2), "1.5.0.dev0" -> (1, 5, 0).
    Comparing these tuples avoids lexicographic string comparison, where "1.0.10" < "1.0.2".
    """
    components = []
    for token in version.split("."):
        digits = ""
        for char in token:
            if not ("0" <= char <= "9"):
                break
            digits += char
        if not digits:
            break
        components.append(int(digits))
        if len(digits) < len(token):
            # Pre-release / local suffix (e.g. "2rc1"): keep the number, then stop parsing.
            break
    return tuple(components)


try:
    import pywt

    has_pywt = True
    # Stationary wavelet transforms (swtn) require pywavelets >= 1.0.2
    use_swtn = _parse_version(pywt.version.version) >= (1, 0, 2)  # type: ignore
except ImportError:
    has_pywt = False
    use_swtn = False
    print("WARNING - pywt was not found")

NDArrayInt = NDArray[np.integer]
# ---- Regularizers ----
[docs]
class BaseRegularizer(ABC):
    """
    Base regularizer class, defining the interface shared by all the Regularizer objects.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer. It is stored internally as a numpy array.
    norm : DataFidelityBase
        The norm of the regularizer minimization.
    upd_mask : Optional[NDArray], optional
        Optional mask multiplied onto the primal update. The default is None.
    dtype : DTypeLike, optional
        The data type of the dual. The default is np.float32.
    """

    __reg_name__ = ""

    weight: NDArray
    dtype: DTypeLike
    op: Union[operators.BaseTransform, None]
    sigma: Union[float, NDArray]
    upd_mask: Optional[NDArray]

    def __init__(
        self,
        weight: Union[float, NDArray],
        norm: dt.DataFidelityBase,
        upd_mask: Optional[NDArray] = None,
        dtype: DTypeLike = np.float32,
    ):
        self.weight = np.array(weight)
        self.dtype = dtype
        self.norm = norm
        self.upd_mask = upd_mask
        # The operator is only created by `initialize_sigma_tau`, once the primal shape is known
        self.op = None

    def _get_initialized_op(self):
        """Return the operator, raising a ValueError when `initialize_sigma_tau` has not been called yet."""
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")
        return self.op

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string, made of the regularizer name and its maximum weight.
        """
        max_weight = self.weight.max()
        return f"{self.__reg_name__}(w:{max_weight:.3e})"

    def upper(self) -> str:
        """
        Return the upper case name of the regularizer.

        Returns
        -------
        str
            Upper case string name of the regularizer.
        """
        reg_name = self.__reg_name__
        return reg_name.upper()

    def lower(self) -> str:
        """
        Return the lower case name of the regularizer.

        Returns
        -------
        str
            Lower case string name of the regularizer.
        """
        reg_name = self.__reg_name__
        return reg_name.lower()

    @abstractmethod
    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the internal state, operator, and sigma. It then returns the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau to be used in the SIRT or PDHG algorithm.
        """

    def initialize_dual(self) -> NDArray:
        """
        Return the initialized dual.

        Returns
        -------
        NDArray
            Zero-filled dual, with the operator's adjoint shape and the regularizer's dtype.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet.
        """
        op = self._get_initialized_op()
        return np.zeros(op.adj_shape, dtype=self.dtype)

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, by adding sigma times the operator applied to the primal.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet.
        """
        op = self._get_initialized_op()
        dual += self.sigma * op(primal)

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator of the norm to the dual in-place.

        l1-type norms receive the regularizer weight as an extra argument.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.
        """
        extra_args = (self.weight,) if isinstance(self.norm, dt.DataFidelity_l1) else ()
        self.norm.apply_proximal(dual, *extra_args)

    def compute_update_primal(self, dual: NDArray) -> NDArray:
        """
        Compute the partial update of a primal term, from this regularizer.

        Parameters
        ----------
        dual : NDArray
            The dual associated to this regularizer.

        Returns
        -------
        upd : NDArray
            The update to the primal: the adjoint operator applied to the dual, masked by `upd_mask`
            (when given) and scaled by the weight for non-l1 norms.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet.
        """
        op = self._get_initialized_op()
        upd = op.T(dual)
        if self.upd_mask is not None:
            upd *= self.upd_mask
        if not isinstance(self.norm, dt.DataFidelity_l1):
            # For l1-type norms the weight is applied in `apply_proximal` instead
            upd *= self.weight
        return upd

    def _check_primal(self, primal: NDArray) -> None:
        """Align the regularizer dtype to the primal dtype, printing a warning when they differ."""
        primal_dtype = primal.dtype
        if self.dtype != primal_dtype:
            print(f"WARNING: Regularizer dtype ({self.dtype}) and primal dtype ({primal.dtype}) are different!")
            self.dtype = primal_dtype
[docs]
class Regularizer_Grad(BaseRegularizer):
    """Gradient regularizer.

    When used with l1-norms, it promotes piece-wise constant reconstructions.
    When used with l2-norm, it promotes smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last `ndims` axes. The default is None.
    pad_mode: str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to the returned tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l12().
    """

    __reg_name__ = "grad"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # A fresh norm per regularizer: `initialize_sigma_tau` mutates it through `assign_data`,
            # so a default instance shared between regularizers would leak state across them.
            norm = dt.DataFidelity_l12()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.pad_mode = pad_mode.lower()

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the gradient operator and sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau: 2 * ndims, scaled by the weight for non-l1 norms, and multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformGradient(primal.shape, axes=self.axes, pad_mode=self.pad_mode)

        self.sigma = 0.5
        self.norm.assign_data(None, sigma=self.sigma)

        tau = 2 * self.ndims
        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
[docs]
class Regularizer_TV1D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 1D. It can be used to promote piece-wise constant reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axis over which it computes the gradient. If None, it uses the last one. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l12().
    """

    __reg_name__ = "TV1D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l12()
        super().__init__(weight=weight, ndims=1, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_TV2D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 2D. It can be used to promote piece-wise constant reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 2. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l12().
    """

    __reg_name__ = "TV2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l12()
        super().__init__(weight=weight, ndims=2, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_TV3D(Regularizer_Grad):
    """Total Variation (TV) regularizer in 3D. It can be used to promote piece-wise constant reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 3. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l12().
    """

    __reg_name__ = "TV3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l12()
        super().__init__(weight=weight, ndims=3, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_HubTV2D(Regularizer_Grad):
    """Huber Total Variation (TV) regularizer in 2D. It can be used to promote piece-wise constant reconstructions.

    The gradient is penalized with a DataFidelity_Huber norm (with `l2_axis=0`) instead of the l12-norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    huber_size : float
        The size parameter forwarded to DataFidelity_Huber.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 2. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "HubTV2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        huber_size: float,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        huber_norm = dt.DataFidelity_Huber(huber_size, l2_axis=0)
        super().__init__(weight=weight, ndims=2, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask, norm=huber_norm)
[docs]
class Regularizer_HubTV3D(Regularizer_Grad):
    """Huber Total Variation (TV) regularizer in 3D. It can be used to promote piece-wise constant reconstructions.

    The gradient is penalized with a DataFidelity_Huber norm (with `l2_axis=0`) instead of the l12-norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    huber_size : float
        The size parameter forwarded to DataFidelity_Huber.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 3. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "HubTV3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        huber_size: float,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        huber_norm = dt.DataFidelity_Huber(huber_size, l2_axis=0)
        super().__init__(weight=weight, ndims=3, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask, norm=huber_norm)
[docs]
class Regularizer_smooth1D(Regularizer_Grad):
    """Gradient regularizer in 1D, using by default an l2-norm. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axis over which it computes the gradient. If None, it uses the last one. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l2().
    """

    __reg_name__ = "smooth1D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l2()
        super().__init__(weight=weight, ndims=1, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_smooth2D(Regularizer_Grad):
    """Gradient regularizer in 2D, using by default an l2-norm. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 2. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l2().
    """

    __reg_name__ = "smooth2D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l2()
        super().__init__(weight=weight, ndims=2, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_smooth3D(Regularizer_Grad):
    """Gradient regularizer in 3D, using by default an l2-norm. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last 3. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l2().
    """

    __reg_name__ = "smooth3D"

    def __init__(
        self,
        weight: Union[float, NDArray],
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l2()
        super().__init__(weight=weight, ndims=3, axes=axes, pad_mode=pad_mode, norm=norm, upd_mask=upd_mask)
[docs]
class Regularizer_lap(BaseRegularizer):
    """Laplacian regularizer. It can be used to promote smooth reconstructions.

    It always uses a DataFidelity_l1 norm.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the Laplacian. If None, it uses the last `ndims` axes. The default is None.
    pad_mode : str, optional
        The padding mode forwarded to the Laplacian transform. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "lap"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, norm=dt.DataFidelity_l1(), upd_mask=upd_mask)

        if axes is None:
            resolved_axes = np.arange(-ndims, 0, dtype=int)
        else:
            resolved_axes = axes
            if len(resolved_axes) != ndims:
                print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
                ndims = len(resolved_axes)
        self.ndims = ndims
        self.axes = resolved_axes

        self.pad_mode = pad_mode.lower()

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the Laplacian operator and sigma (0.25), and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau: 4 * ndims, multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformLaplacian(primal.shape, axes=self.axes, pad_mode=self.pad_mode)

        self.sigma = 0.25
        self.norm.assign_data(None, sigma=self.sigma)

        tau = 4 * self.ndims
        return tau if self.upd_mask is None else tau * self.upd_mask
[docs]
class Regularizer_lap1D(Regularizer_lap):
    """Laplacian regularizer in 1D. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axis of the Laplacian. If None, it uses the last one. The default is None.
    pad_mode : str, optional
        The padding mode forwarded to the Laplacian transform. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "lap1D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=1, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask)
[docs]
class Regularizer_lap2D(Regularizer_lap):
    """Laplacian regularizer in 2D. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes of the Laplacian. If None, it uses the last 2. The default is None.
    pad_mode : str, optional
        The padding mode forwarded to the Laplacian transform. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "lap2D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=2, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask)
[docs]
class Regularizer_lap3D(Regularizer_lap):
    """Laplacian regularizer in 3D. It can be used to promote smooth reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    axes : Sequence, optional
        The axes of the Laplacian. If None, it uses the last 3. The default is None.
    pad_mode : str, optional
        The padding mode forwarded to the Laplacian transform. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "lap3D"

    def __init__(
        self,
        weight,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=3, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask)
[docs]
class Regularizer_l1(BaseRegularizer):
    """l1-norm regularizer. It can be used to promote sparse reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l1().
    """

    __reg_name__ = "l1"

    def __init__(
        self,
        weight: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the identity operator and sigma (1), and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau: 1, scaled by the weight for non-l1 norms, and multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformIdentity(primal.shape)

        # Identity transform: the dual step is fixed to 1 (also consumed implicitly by `update_dual`)
        self.sigma = 1
        self.norm.assign_data(None, sigma=self.sigma)

        tau = 1.0
        if not isinstance(self.norm, dt.DataFidelity_l1):
            # `compute_update_primal` scales the update by the weight for non-l1 norms: tau must match,
            # as done by the other regularizers.
            tau = tau * self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, by adding the primal to it (sigma is 1).

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.
        """
        dual += primal
[docs]
class Regularizer_swl(BaseRegularizer):
    """Base stationary wavelet regularizer. It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    normalized : bool, optional
        Whether the transform is normalized. When False, the dual update is rescaled by sigma. The default is False.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed at each update. The default is True.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l1().

    Raises
    ------
    ValueError
        If pywavelets is not installed, or if it is older than 1.0.2.
    """

    __reg_name__ = "swl"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string, with wavelet name, level, and maximum weight.
        """
        return self.__reg_name__ + "(t:" + self.wavelet + "-l:%d" % self.level + "-w:%g" % self.weight.max() + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if not has_pywt:
            raise ValueError("Cannot use wavelet regularizer because pywavelets is not installed.")
        if not use_swtn:
            raise ValueError("Cannot use stationary wavelet regularizer because pywavelets is too old (<1.0.2).")
        if norm is None:
            # Fresh instance: `initialize_sigma_tau` assigns a level/shape dependent sigma to the norm,
            # so a default instance shared between regularizers would be overwritten by each of them.
            norm = dt.DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.wavelet = wavelet
        self.level = level
        self.normalized = normalized
        self.min_approx = min_approx

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.pad_on_demand = pad_on_demand

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the stationary wavelet operator, the per-band multipliers and sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau, scaled by the weight for non-l1 norms, and multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformStationaryWavelet(
            primal.shape,
            wavelet=self.wavelet,
            level=self.level,
            axes=self.axes,
            pad_on_demand=self.pad_on_demand,
            normalized=self.normalized,
        )

        # l1-norms of the filter bank rows: indices 0 and 2 are used as the low-pass decomposition and
        # reconstruction filters respectively
        filt_bank_l1norm = np.linalg.norm(self.op.w.filter_bank, ord=1, axis=-1)
        lo_dec_mult = filt_bank_l1norm[0] ** self.ndims
        lo_rec_mult = filt_bank_l1norm[2] ** self.ndims

        # Multipliers layout: [approximation, details of level 1 (coarsest), ..., details of the finest level]
        self.dec_func_mult = (
            self.op.wlet_dec_filter_mult[None, :] * (lo_dec_mult ** np.arange(self.level - 1, -1, -1))[:, None]
        )
        self.dec_func_mult = np.concatenate(([lo_dec_mult**self.level], self.dec_func_mult.flatten()))

        self.rec_func_mult = (
            self.op.wlet_rec_filter_mult[None, :] * (lo_rec_mult ** np.arange(self.level - 1, -1, -1))[:, None]
        )
        self.rec_func_mult = np.concatenate(([lo_rec_mult**self.level], self.rec_func_mult.flatten()))

        if self.normalized:
            self.sigma = 1
            self.norm.assign_data(None, sigma=self.sigma)

            tau = self.dec_func_mult.size
        else:
            # One sigma per sub-band, broadcast over the spatial dimensions
            self.sigma = np.reshape(1 / self.dec_func_mult, [-1] + [1] * len(self.op.dir_shape))
            self.norm.assign_data(None, sigma=self.sigma)

            tau = np.ones_like(self.rec_func_mult) * ((2**self.ndims) - 1)
            tau[0] += 1
            tau = np.sum(tau / self.rec_func_mult)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, with the (sigma-scaled, when not normalized) wavelet transform of the primal.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet.
        """
        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")
        upd = self.op(primal)
        if not self.normalized:
            upd *= self.sigma
        dual += upd
        if not self.min_approx:
            dual[0, ...] = 0

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator to the dual in-place.

        For DataFidelity_l12 norms, the detail bands and the approximation band are processed separately.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.
        """
        if isinstance(self.norm, dt.DataFidelity_l12):
            # NOTE(review): relies on `dual[1:]` being contiguous, so that reshape returns a view and the
            # proximal acts in-place on `dual` — confirm the band ordering expected by the reshape.
            tmp_dual = dual[1:]
            tmp_dual = tmp_dual.reshape([-1, self.level, *dual.shape[1:]])
            self.norm.apply_proximal(tmp_dual, self.weight)

            tmp_dual = dual[0:1:]
            self.norm.apply_proximal(tmp_dual, self.weight)
        else:
            super().apply_proximal(dual)
[docs]
class Regularizer_l1swl(Regularizer_swl):
    """l1-norm stationary wavelet regularizer. It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    normalized : bool, optional
        Whether the transform is normalized. The default is False.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed. The default is True.
    """

    __reg_name__ = "l1swl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
    ):
        l1_norm = dt.DataFidelity_l1()
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=l1_norm,
        )
[docs]
class Regularizer_l12swl(Regularizer_swl):
    """l12-norm stationary wavelet regularizer. It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    normalized : bool, optional
        Whether the transform is normalized. The default is False.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed. The default is True.
    """

    __reg_name__ = "l12swl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
    ):
        l12_norm = dt.DataFidelity_l12()
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=l12_norm,
        )
[docs]
class Regularizer_Hub_swl(Regularizer_swl):
    """Huber-norm stationary wavelet regularizer. It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    normalized : bool, optional
        Whether the transform is normalized. The default is False.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed. The default is True.
    huber_size : Optional[int], optional
        The size parameter forwarded to DataFidelity_Huber. The default is None.
    """

    __reg_name__ = "Hubswl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        huber_size: Optional[int] = None,
    ):
        huber_norm = dt.DataFidelity_Huber(huber_size)
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
            norm=huber_norm,
        )
[docs]
class Regularizer_dwl(BaseRegularizer):
    """Base decimated wavelet regularizer. It can be used to promote sparse reconstructions in the wavelet domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed at each update. The default is True.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l1().

    Raises
    ------
    ValueError
        If pywavelets is not installed.
    """

    __reg_name__ = "dwl"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string, with wavelet name, level, and maximum weight.
        """
        return self.__reg_name__ + "(t:" + self.wavelet + "-l:%d" % self.level + "-w:%g" % self.weight.max() + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        min_approx: bool = True,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if not has_pywt:
            raise ValueError("Cannot use wavelet regularizer because pywavelets is not installed.")
        if norm is None:
            # Fresh instance: `initialize_sigma_tau` assigns a shape dependent sigma to the norm,
            # so a default instance shared between regularizers would be overwritten by each of them.
            norm = dt.DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.wavelet = wavelet
        self.level = level
        self.min_approx = min_approx

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.pad_on_demand = pad_on_demand

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the decimated wavelet operator, the per-band multipliers and sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau, scaled by the weight for non-l1 norms, and multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformDecimatedWavelet(
            primal.shape, wavelet=self.wavelet, level=self.level, axes=self.axes, pad_on_demand=self.pad_on_demand
        )

        # l1-norms of the filter bank rows: indices 0 and 2 are used as the low-pass decomposition and
        # reconstruction filters respectively
        filt_bank_l1norm = np.linalg.norm(self.op.w.filter_bank, ord=1, axis=-1)
        lo_dec_mult = filt_bank_l1norm[0] ** self.ndims
        lo_rec_mult = filt_bank_l1norm[2] ** self.ndims

        # Multipliers layout: [approximation, details of the coarsest level, ..., details of the finest level]
        self.dec_func_mult = (
            self.op.wlet_dec_filter_mult[None, :] * (lo_dec_mult ** np.arange(self.level - 1, -1, -1))[:, None]
        )
        self.dec_func_mult = np.concatenate(([lo_dec_mult**self.level], self.dec_func_mult.flatten()))

        self.rec_func_mult = (
            self.op.wlet_rec_filter_mult[None, :] * (lo_rec_mult ** np.arange(self.level - 1, -1, -1))[:, None]
        )
        self.rec_func_mult = np.concatenate(([lo_rec_mult**self.level], self.rec_func_mult.flatten()))

        tmp_sigma = [np.ones(self.op.sub_band_shapes[0], self.dtype) * self.dec_func_mult[0]]
        # Index 0 of `dec_func_mult` is the approximation multiplier (used just above): the detail multipliers
        # start at index 1. Starting at 0 would reuse the approximation multiplier for the first detail band
        # and shift all the others by one.
        # NOTE(review): assumes `sub_band_shapes[1:]` follows the same coarsest-first order as `dec_func_mult`.
        count = 1
        for ii_l in range(self.level):
            d = {}
            for label in self.op.sub_band_shapes[ii_l + 1].keys():
                d[label] = np.ones(self.op.sub_band_shapes[ii_l + 1][label], self.dtype) * self.dec_func_mult[count]
                count += 1
            tmp_sigma.append(d)
        self.sigma, _ = pywt.coeffs_to_array(tmp_sigma, axes=self.axes)
        self.norm.assign_data(None, sigma=self.sigma)

        tau = np.ones_like(self.rec_func_mult) * ((2**self.ndims) - 1)
        tau[0] += 1
        tau = np.sum(tau / self.rec_func_mult)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau *= self.weight
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, optionally zeroing the approximation coefficients (when not `min_approx`).

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet.
        """
        super().update_dual(dual, primal)
        if not self.min_approx:
            if self.op is None:
                raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")
            op_wl: operators.TransformDecimatedWavelet = self.op  # type: ignore
            slices = [slice(0, x) for x in op_wl.sub_band_shapes[0]]
            dual[tuple(slices)] = 0

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator to the dual in-place.

        For DataFidelity_l12 norms, the detail bands of each level are stacked and processed together, and the
        approximation band is processed separately.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.

        Raises
        ------
        ValueError
            If `initialize_sigma_tau` has not been called yet (DataFidelity_l12 norms only).
        """
        if not isinstance(self.norm, dt.DataFidelity_l12):
            super().apply_proximal(dual)
            return

        if self.op is None:
            raise ValueError("Regularizer not initialized! Please use method: `initialize_sigma_tau`.")
        op_wl: operators.TransformDecimatedWavelet = self.op  # type: ignore

        coeffs = pywt.array_to_coeffs(dual, op_wl.slicing_info)
        for ii_l in range(1, len(coeffs)):
            c_l = coeffs[ii_l]
            labels = list(c_l.keys())
            c_ll = np.stack([c_l[lab] for lab in labels], axis=0)
            self.norm.apply_proximal(c_ll, self.weight)
            for ii, lab in enumerate(labels):
                c_l[lab] = c_ll[ii]
        self.norm.apply_proximal(coeffs[0], self.weight)

        # Pack with the same axes used to build sigma, so the layout matches the transform's one
        dual[:] = pywt.coeffs_to_array(coeffs, axes=self.axes)[0]
[docs]
class Regularizer_l1dwl(Regularizer_dwl):
    """l1-norm decimated wavelet regularizer. It can be used to promote sparse reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "l1dwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
    ):
        l1_norm = dt.DataFidelity_l1()
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=l1_norm,
        )
[docs]
class Regularizer_l12dwl(Regularizer_dwl):
    """l12-norm decimated wavelet regularizer. It can be used to promote sparse reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    """

    __reg_name__ = "l12dwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
    ):
        l12_norm = dt.DataFidelity_l12()
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=l12_norm,
        )
[docs]
class Regularizer_Hub_dwl(Regularizer_dwl):
    """Huber-norm decimated wavelet regularizer. It can be used to promote sparse reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    huber_size : Optional[int], optional
        The size parameter forwarded to DataFidelity_Huber. The default is None.
    """

    __reg_name__ = "Hubdwl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        huber_size: Optional[int] = None,
    ):
        huber_norm = dt.DataFidelity_Huber(huber_size)
        super().__init__(
            weight=weight,
            wavelet=wavelet,
            level=level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            norm=huber_norm,
        )
[docs]
class BaseRegularizer_med(BaseRegularizer):
    """Median filter regularizer base class. It can be used to promote filtered reconstructions.

    It penalizes the difference between the primal and its median-filtered version.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size of the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l1().
    """

    __reg_name__ = "med"

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string, with filter size and maximum weight.
        """
        return self.__reg_name__ + "(s:%s" % np.array(self.filt_size) + "-w:%g" % self.weight.max() + ")"

    def __init__(
        self,
        weight: Union[float, NDArray],
        filt_size: int = 3,
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: a shared default norm would be mutated by every regularizer using it
            norm = dt.DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        self.filt_size = filt_size

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the identity operator and the norm (sigma 1), and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau: the weight for non-l1 norms, 1 otherwise, multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformIdentity(primal.shape)

        self.norm.assign_data(None, sigma=1)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """
        Update the dual in-place, by adding the difference between the primal and its median-filtered version.

        Parameters
        ----------
        dual : NDArray
            Current state of the dual.
        primal : NDArray
            Primal or over-relaxation of the primal.
        """
        dual += primal - spimg.median_filter(primal, self.filt_size)
[docs]
class Regularizer_l1med(BaseRegularizer_med):
    """l1-norm median filter regularizer. It can be used to promote filtered reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size of the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau, as in the other regularizers. The default is None.
    """

    __reg_name__ = "l1med"

    def __init__(self, weight: Union[float, NDArray], filt_size: int = 3, upd_mask: Optional[NDArray] = None):
        super().__init__(weight, filt_size=filt_size, upd_mask=upd_mask, norm=dt.DataFidelity_l1())
[docs]
class Regularizer_l2med(BaseRegularizer_med):
    """l2-norm median filter regularizer. It can be used to promote filtered reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    filt_size : int, optional
        The size of the median filter. The default is 3.
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau, as in the other regularizers. The default is None.
    """

    __reg_name__ = "l2med"

    def __init__(self, weight: Union[float, NDArray], filt_size: int = 3, upd_mask: Optional[NDArray] = None):
        super().__init__(weight, filt_size=filt_size, upd_mask=upd_mask, norm=dt.DataFidelity_l2())
[docs]
class Regularizer_fft(BaseRegularizer):
    """Fourier regularizer. It can be used to promote sparse reconstructions in the Fourier domain.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the Fourier transform. If None, it uses the last `ndims` axes. The default is None.
    fft_filter : str, optional
        The Fourier-domain weighting used as sigma: "delta", "exp", or "exp2".
        Non-string values result in sigma = 1. The default is "exp".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    norm : DataFidelityBase, optional
        The norm of the regularizer minimization. The default is None, which creates a new DataFidelity_l12().
    """

    __reg_name__ = "fft"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        fft_filter: str = "exp",
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        if norm is None:
            # Fresh instance: `initialize_sigma_tau` assigns a shape dependent sigma to the norm,
            # so a default instance shared between regularizers would be overwritten by each of them.
            norm = dt.DataFidelity_l12()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)

        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes

        self.fft_filter = fft_filter

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the Fourier operator and the frequency-dependent sigma, and return the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau: the weight for non-l1 norms, 1 otherwise, multiplied by `upd_mask` when given.

        Raises
        ------
        ValueError
            If `fft_filter` is an unknown string.
        """
        self._check_primal(primal)
        self.op = operators.TransformFourier(primal.shape, axes=self.axes)

        if isinstance(self.fft_filter, str):
            coords = [np.fft.fftfreq(s) for s in self.op.adj_shape[self.axes]]
            coords = np.array(np.meshgrid(*coords, indexing="ij"))

            if self.fft_filter.lower() == "delta":
                # Zero weight on the DC component only
                self.sigma = 1 - np.all(coords == 0, axis=0)
            elif self.fft_filter.lower() == "exp":
                self.sigma = 1 - np.exp(-np.sqrt(np.sum(coords**2, axis=0)) * 12)
            elif self.fft_filter.lower() == "exp2":
                self.sigma = 1 - np.exp(-np.sum(coords**2, axis=0) * 36)
            else:
                raise ValueError('Unknown FFT mask: %s. Options are: "delta", "exp". and "exp2".' % self.fft_filter)

            # Singleton dimensions on the non-transformed axes, so sigma broadcasts over them
            new_shape = np.ones_like(self.op.adj_shape)
            new_shape[self.axes] = self.op.adj_shape[self.axes]
            self.sigma = np.reshape(self.sigma, new_shape)
        else:
            self.sigma = 1

        self.norm.assign_data(None, sigma=self.sigma)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
# Multi-channel regularizers
[docs]
class Regularizer_TNV(Regularizer_Grad):
    """Total Nuclear Variation (TNV) regularizer.

    It can be used to promote piece-wise constant reconstructions, for multi-channel volumes.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes over which it computes the gradient. If None, it uses the last `ndims` axes. The default is None.
    pad_mode : str, optional
        The padding mode to use for the linear convolution. The default is "edge".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    spectral_norm : DataFidelityBase, optional
        The norm applied to the singular values, forwarded to DataFidelity_ln.
        The default is None, which creates a new DataFidelity_l1().
    x_ref : Optional[NDArray], optional
        Reference volume, whose gradient is stored in `q_ref` at initialization. The default is None.
    """

    __reg_name__ = "TNV"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_mode: str = "edge",
        upd_mask: Optional[NDArray] = None,
        spectral_norm: Optional[dt.DataFidelityBase] = None,
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=ndims, axes=axes, pad_mode=pad_mode, upd_mask=upd_mask)
        if spectral_norm is None:
            # Fresh instance: avoid sharing a default norm object between regularizers
            spectral_norm = dt.DataFidelity_l1()

        # Here we assume that the channels will be the rows and the derivatives the columns
        self.norm = dt.DataFidelity_ln(ln_axes=(1, 0), spectral_norm=spectral_norm)
        self.x_ref = x_ref
        # Always define the attribute; it is only filled by `initialize_sigma_tau` when `x_ref` is given
        self.q_ref = None

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the gradient operator and, when `x_ref` is given, the reference gradient `q_ref`.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau computed by Regularizer_Grad.
        """
        tau = super().initialize_sigma_tau(primal)

        if self.x_ref is not None:
            if self.op is None:
                raise ValueError("Regularizer should have been initialized... this is a bug!")
            self.q_ref = self.op(self.x_ref)
            self.q_ref = np.expand_dims(self.q_ref, axis=1)

        return tau
[docs]
class Regularizer_VTV(Regularizer_Grad):
    """Vectorial Total Variation (VTV) regularizer.

    It can be used to promote piece-wise constant reconstructions, for multi-channel volumes.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        The number of dimensions. The default is 2.
    pwise_der_norm : Union[int, float], optional
        Point-wise norm exponent across the derivatives: 1, 2, or np.inf. The default is 2.
    pwise_chan_norm : Union[int, float], optional
        Point-wise norm exponent across the channels: 1, 2, or np.inf. The default is np.inf.
    x_ref : Optional[NDArray], optional
        Reference volume, whose (sigma-scaled) gradient is appended as an extra channel in the proximal.
        The default is None.
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.

    Raises
    ------
    ValueError
        If either point-wise norm exponent is not one of 1, 2, or np.inf.
    """

    __reg_name__ = "VTV"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        pwise_der_norm: Union[int, float] = 2,
        pwise_chan_norm: Union[int, float] = np.inf,
        x_ref: Optional[NDArray] = None,
        upd_mask: Optional[NDArray] = None,
    ):
        super().__init__(weight=weight, ndims=ndims, upd_mask=upd_mask)

        # Assign both exponents before validating either: the error message reports both of them, and
        # validating `pwise_der_norm` first would otherwise raise an AttributeError instead of a ValueError.
        self.pwise_der_norm = pwise_der_norm
        self.pwise_chan_norm = pwise_chan_norm
        if self.pwise_der_norm not in [1, 2, np.inf] or self.pwise_chan_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()

        if x_ref is not None:
            self.initialize_sigma_tau(x_ref)
            q_ref = self.initialize_dual()
            self.update_dual(q_ref, x_ref)
            self.q_ref = np.expand_dims(q_ref, axis=1)
        else:
            self.q_ref = None

    def _raise_pwise_norm_error(self):
        """Raise a ValueError reporting the unsupported point-wise norm exponents."""
        raise ValueError(
            "The only supported point-wise norm exponents are: 1, 2, and Inf."
            + f" Provided the following instead: derivatives={self.pwise_der_norm}, channel={self.pwise_chan_norm}"
        )

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator to the dual in-place.

        The dual is divided by max(point-wise norm, weight) and multiplied by the weight. The point-wise norm is
        computed across the derivatives (axis 0) and then across the channels (axis 1), using the dual exponents
        of `pwise_der_norm` and `pwise_chan_norm`.

        Parameters
        ----------
        dual : NDArray
            The dual to be applied the proximal on.
        """
        # Following assignments will detach the local array from the original one
        dual_tmp = dual.copy()

        # Single-channel duals only have the derivatives axis in front of the spatial axes: add a channel axis
        dual_is_scalar = len(dual_tmp.shape) == (self.ndims + 1)
        if dual_is_scalar:
            dual_tmp = np.expand_dims(dual_tmp, axis=1)

        if self.q_ref is not None:
            dual_tmp = np.concatenate((dual_tmp, self.q_ref), axis=1)

        if self.pwise_der_norm == 1:
            grad_norm = np.abs(dual_tmp)
        elif self.pwise_der_norm == 2:
            grad_norm = np.linalg.norm(dual_tmp, axis=0, ord=2, keepdims=True)
        elif self.pwise_der_norm == np.inf:
            grad_norm = np.linalg.norm(dual_tmp, axis=0, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        if self.pwise_chan_norm == 1:
            dual_norm = grad_norm
        elif self.pwise_chan_norm == 2:
            dual_norm = np.linalg.norm(grad_norm, axis=1, ord=2, keepdims=True)
        elif self.pwise_chan_norm == np.inf:
            dual_norm = np.linalg.norm(grad_norm, axis=1, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        dual_tmp /= np.fmax(dual_norm, self.weight)
        dual_tmp *= self.weight

        if self.q_ref is not None:
            # Drop the reference channel appended above
            dual_tmp = dual_tmp[:, : dual_tmp.shape[1] - 1 :, ...]
        if dual_is_scalar:
            dual_tmp = np.squeeze(dual_tmp, axis=1)

        dual[:] = dual_tmp[:]  # Replacing values
[docs]
class Regularizer_lnswl(Regularizer_l1swl):
    """Nuclear-norm Wavelet regularizer.

    It can be used to promote compressed multi-channel reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name.
    level : int
        The number of decomposition levels.
    ndims : int, optional
        The number of dimensions. The default is 2.
    axes : Sequence, optional
        The axes of the transform. If None, it uses the last `ndims` axes. The default is None.
    pad_on_demand : str, optional
        The padding mode forwarded to the wavelet transform. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Mask applied to the primal update and to tau. The default is None.
    normalized : bool, optional
        Whether the transform is normalized. The default is False.
    min_approx : bool, optional
        When False, the approximation coefficients of the dual are zeroed. The default is True.
    spectral_norm : DataFidelityBase, optional
        The norm applied to the singular values, forwarded to DataFidelity_ln.
        The default is None, which creates a new DataFidelity_l1().
    x_ref : Optional[NDArray], optional
        Reference volume, whose wavelet coefficients are stored in `q_ref` at initialization. The default is None.
    """

    __reg_name__ = "lnswl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        spectral_norm: Optional[dt.DataFidelityBase] = None,
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
        )
        if spectral_norm is None:
            # Fresh instance: avoid sharing a default norm object between regularizers
            spectral_norm = dt.DataFidelity_l1()
        self.norm = dt.DataFidelity_ln(ln_axes=(1, 0), spectral_norm=spectral_norm)
        self.x_ref = x_ref
        # Always define the attribute; it is only filled by `initialize_sigma_tau` when `x_ref` is given
        self.q_ref = None

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the wavelet operator and, when `x_ref` is given, the reference coefficients `q_ref`.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau computed by the stationary wavelet regularizer.
        """
        tau = super().initialize_sigma_tau(primal)

        if self.x_ref is not None:
            if self.op is None:
                raise ValueError("Regularizer should have been initialized... this is a bug!")
            self.q_ref = self.op(self.x_ref)
            self.q_ref = np.expand_dims(self.q_ref, axis=1)

        return tau
class Regularizer_vl1wl(Regularizer_l1swl):
    """l1-norm vectorial Wavelet regularizer. It can be used to promote compressed reconstructions.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    wavelet : str
        The wavelet name, forwarded to `Regularizer_l1swl`.
    level : int
        The decomposition level, forwarded to `Regularizer_l1swl`.
    ndims : int, optional
        Number of dimensions, forwarded to `Regularizer_l1swl`. The default is 2.
    axes : Union[Sequence[int], NDArray, None], optional
        Axes of the transform, forwarded to `Regularizer_l1swl`. The default is None.
    pad_on_demand : str, optional
        Padding mode, forwarded to `Regularizer_l1swl`. The default is "constant".
    upd_mask : Optional[NDArray], optional
        Update mask, forwarded to `Regularizer_l1swl`. The default is None.
    normalized : bool, optional
        Forwarded to `Regularizer_l1swl`. The default is False.
    min_approx : bool, optional
        Forwarded to `Regularizer_l1swl`. The default is True.
    pwise_lvl_norm : Union[int, float], optional
        Point-wise norm exponent along axis 0 of the dual (1, 2, or np.inf). The default is 1.
    pwise_chan_norm : Union[int, float], optional
        Point-wise norm exponent along axis 1 of the dual (1, 2, or np.inf). The default is np.inf.
    x_ref : Optional[NDArray], optional
        Reference volume. When given, its transform coefficients are appended along axis 1
        of the dual during the proximal step. The default is None.

    Raises
    ------
    ValueError
        If either point-wise norm exponent is not one of 1, 2, or np.inf.
    """

    __reg_name__ = "vl1wl"

    def __init__(
        self,
        weight: Union[float, NDArray],
        wavelet: str,
        level: int,
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        pad_on_demand: str = "constant",
        upd_mask: Optional[NDArray] = None,
        normalized: bool = False,
        min_approx: bool = True,
        pwise_lvl_norm: Union[int, float] = 1,
        pwise_chan_norm: Union[int, float] = np.inf,
        x_ref: Optional[NDArray] = None,
    ):
        super().__init__(
            weight,
            wavelet,
            level,
            ndims=ndims,
            axes=axes,
            pad_on_demand=pad_on_demand,
            upd_mask=upd_mask,
            normalized=normalized,
            min_approx=min_approx,
        )
        # Assign both exponents before validating: the error message reports both, so
        # validating the first one early would raise AttributeError instead of ValueError.
        self.pwise_lvl_norm = pwise_lvl_norm
        self.pwise_chan_norm = pwise_chan_norm
        if self.pwise_lvl_norm not in [1, 2, np.inf] or self.pwise_chan_norm not in [1, 2, np.inf]:
            self._raise_pwise_norm_error()
        self.x_ref = x_ref
        self.q_ref = None

    def _raise_pwise_norm_error(self):
        """Raise a ValueError reporting the unsupported point-wise norm exponents."""
        raise ValueError(
            "The only supported point-wise norm exponents are: 1, 2, and Inf."
            + f" Provided the following instead: level={self.pwise_lvl_norm}, channel={self.pwise_chan_norm}"
        )

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the regularizer through the parent class, and the reference coefficients if any.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            The tau computed by the parent class.

        Raises
        ------
        ValueError
            If the parent initialization did not set the operator.
        """
        tau = super().initialize_sigma_tau(primal)
        if self.x_ref is not None:
            if self.op is None:
                raise ValueError("Regularizer should have been initialized... this is a bug!")
            self.q_ref = self.op(self.x_ref)
        return tau

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Apply the proximal operator to the dual, in place.

        Each point is rescaled as `dual * weight / max(norm, weight)`. Here `norm` is computed
        first along axis 0 (exponent `pwise_lvl_norm`), then along axis 1 (exponent
        `pwise_chan_norm`). The exponents map to: 1 -> element-wise abs, 2 -> l2, inf -> l1.

        Parameters
        ----------
        dual : NDArray
            The dual to be updated in place.
        """
        if self.q_ref is not None:
            # The reference coefficients take part in the norm as one extra entry along axis 1.
            # np.concatenate already allocates a new array, so no extra copy is needed.
            dual_tmp = np.concatenate((dual, self.q_ref), axis=1)
        else:
            dual_tmp = dual.copy()

        if self.pwise_lvl_norm == 1:
            lvl_norm = np.abs(dual_tmp)
        elif self.pwise_lvl_norm == 2:
            lvl_norm = np.linalg.norm(dual_tmp, axis=0, ord=2, keepdims=True)
        elif self.pwise_lvl_norm == np.inf:
            lvl_norm = np.linalg.norm(dual_tmp, axis=0, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        if self.pwise_chan_norm == 1:
            dual_norm = lvl_norm
        elif self.pwise_chan_norm == 2:
            dual_norm = np.linalg.norm(lvl_norm, axis=1, ord=2, keepdims=True)
        elif self.pwise_chan_norm == np.inf:
            dual_norm = np.linalg.norm(lvl_norm, axis=1, ord=1, keepdims=True)
        else:
            self._raise_pwise_norm_error()

        dual_tmp /= np.fmax(dual_norm, self.weight)
        dual_tmp *= self.weight

        if self.q_ref is not None:
            # Drop the reference entry that was appended along axis 1.
            dual_tmp = dual_tmp[:, :-1, ...]

        dual[:] = dual_tmp[:]
class Regularizer_vSVD(BaseRegularizer):
    """Regularizer based on the Singular Value Decomposition.

    It can be used to promote similar reconstructions across different channels.

    Parameters
    ----------
    weight : Union[float, NDArray]
        The weight of the regularizer.
    ndims : int, optional
        Number of trailing axes used as SVD columns when `axes` is None. The default is 2.
    axes : Union[Sequence[int], NDArray, None], optional
        Axes used as columns of the SVD. When None, the last `ndims` axes are used. When its
        length differs from `ndims`, a warning is printed and `ndims` is set to `len(axes)`.
    axis_channels : Sequence[int], optional
        Axes used as rows of the SVD. The default is (0,).
    upd_mask : Optional[NDArray], optional
        Update mask, multiplied into the returned tau. The default is None.
    norm : Optional[DataFidelityBase], optional
        Norm of the regularizer. When None (default), a new `DataFidelity_l1` instance is
        created for each regularizer.
    """

    __reg_name__ = "vsvd"

    def __init__(
        self,
        weight: Union[float, NDArray],
        ndims: int = 2,
        axes: Union[Sequence[int], NDArray, None] = None,
        axis_channels: Sequence[int] = (0,),
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        # Build the default here: `initialize_sigma_tau` calls `norm.assign_data`, so a single
        # default object in the signature would be shared and mutated by every instance.
        if norm is None:
            norm = dt.DataFidelity_l1()
        super().__init__(weight=weight, norm=norm, upd_mask=upd_mask)
        if axes is None:
            axes = np.arange(-ndims, 0, dtype=int)
        elif not ndims == len(axes):
            print("WARNING - Number of axes different from number of dimensions. Updating dimensions accordingly.")
            ndims = len(axes)
        self.ndims = ndims
        self.axes = axes
        self.axis_channels = axis_channels

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the SVD operator, sigma and the norm. It then returns the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            `weight` for non-l1 norms, 1.0 for l1 norms, multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformSVD(primal.shape, axes_rows=self.axis_channels, axes_cols=self.axes, rescale=True)

        self.sigma = 1
        self.norm.assign_data(None, sigma=self.sigma)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau
# ---- Constraints ----
class Constraint_LowerLimit(BaseRegularizer):
    """Lower limit constraint. It can be used to promote reconstructions in certain regions of solution space.

    Parameters
    ----------
    limit : Union[float, NDArray]
        The limit. A scalar or an array broadcastable against the dual.
    upd_mask : Optional[NDArray], optional
        Update mask, multiplied into the returned tau. The default is None.
    norm : Optional[DataFidelityBase], optional
        Norm of the constraint; `limit` is assigned to it as data. When None (default), a new
        `DataFidelity_l2` instance is created for each constraint.
    """

    __reg_name__ = "lowlim"

    def __init__(
        self,
        limit: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        # Build the default here: `initialize_sigma_tau` assigns `limit` to the norm, so a
        # single default object would let constraints overwrite each other's limit.
        if norm is None:
            norm = dt.DataFidelity_l2()
        super().__init__(weight=1, norm=norm, upd_mask=upd_mask)
        self.limit = limit

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string. Array limits are reported as "[min, max]".
        """
        limit = np.asarray(self.limit)
        if limit.size == 1:
            limit_str = f"{limit.item():g}"
        else:
            # "%g" only accepts scalars, so summarize array limits by their range.
            limit_str = f"[{float(limit.min()):g}, {float(limit.max()):g}]"
        return self.__reg_name__ + "(l:" + limit_str + ")"

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the identity operator and assign the limit to the norm. It then returns the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            `weight` for non-l1 norms, 1.0 for l1 norms, multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformIdentity(primal.shape)

        self.norm.assign_data(self.limit, sigma=1)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """Add the primal to the dual, in place."""
        dual += primal

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Clip the dual from above at `limit`, then apply the norm proximal, in place.

        Parameters
        ----------
        dual : NDArray
            The dual to be updated in place.
        """
        # Element-wise min: same as `dual[dual > limit] = limit` for scalar limits,
        # and also valid for array limits (where mask assignment would fail on shape).
        dual[...] = np.minimum(dual, self.limit)
        self.norm.apply_proximal(dual)
class Constraint_UpperLimit(BaseRegularizer):
    """Upper limit constraint. It can be used to promote reconstructions in certain regions of solution space.

    Parameters
    ----------
    limit : Union[float, NDArray]
        The limit. A scalar or an array broadcastable against the dual.
    upd_mask : Optional[NDArray], optional
        Update mask, multiplied into the returned tau. The default is None.
    norm : Optional[DataFidelityBase], optional
        Norm of the constraint; `limit` is assigned to it as data. When None (default), a new
        `DataFidelity_l2` instance is created for each constraint.
    """

    __reg_name__ = "uplim"

    def __init__(
        self,
        limit: Union[float, NDArray],
        upd_mask: Optional[NDArray] = None,
        norm: Optional[dt.DataFidelityBase] = None,
    ):
        # Build the default here: `initialize_sigma_tau` assigns `limit` to the norm, so a
        # single default object would let constraints overwrite each other's limit.
        if norm is None:
            norm = dt.DataFidelity_l2()
        super().__init__(weight=1, norm=norm, upd_mask=upd_mask)
        self.limit = limit

    def info(self) -> str:
        """
        Return the regularizer info.

        Returns
        -------
        str
            Regularizer info string. Array limits are reported as "[min, max]".
        """
        limit = np.asarray(self.limit)
        if limit.size == 1:
            limit_str = f"{limit.item():g}"
        else:
            # "%g" only accepts scalars, so summarize array limits by their range.
            limit_str = f"[{float(limit.min()):g}, {float(limit.max()):g}]"
        return self.__reg_name__ + "(l:" + limit_str + ")"

    def initialize_sigma_tau(self, primal: NDArray) -> Union[float, NDArray]:
        """
        Initialize the identity operator and assign the limit to the norm. It then returns the tau.

        Parameters
        ----------
        primal : NDArray
            The primal vector.

        Returns
        -------
        Union[float, NDArray]
            `weight` for non-l1 norms, 1.0 for l1 norms, multiplied by `upd_mask` when given.
        """
        self._check_primal(primal)
        self.op = operators.TransformIdentity(primal.shape)

        self.norm.assign_data(self.limit, sigma=1)

        if not isinstance(self.norm, dt.DataFidelity_l1):
            tau = self.weight
        else:
            tau = 1.0
        if self.upd_mask is not None:
            tau = tau * self.upd_mask
        return tau

    def update_dual(self, dual: NDArray, primal: NDArray) -> None:
        """Add the primal to the dual, in place."""
        dual += primal

    def apply_proximal(self, dual: NDArray) -> None:
        """
        Clip the dual from below at `limit`, then apply the norm proximal, in place.

        Parameters
        ----------
        dual : NDArray
            The dual to be updated in place.
        """
        # Element-wise max: same as `dual[dual < limit] = limit` for scalar limits,
        # and also valid for array limits (where mask assignment would fail on shape).
        dual[...] = np.maximum(dual, self.limit)
        self.norm.apply_proximal(dual)