Source code for corrct.data_terms

#!/usr/bin/env python3
"""
Data fidelity classes.

@author: Nicola VIGANĂ’, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""

from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import deepcopy
from typing import Any

import numpy as np
from numpy.typing import NDArray

from . import operators

eps = np.finfo(np.float32).eps


NDArrayFloat = NDArray[np.floating]


[docs] def _soft_threshold(values: NDArrayFloat, threshold: float | NDArrayFloat) -> None: abs_values = np.abs(values) valid_values = abs_values > 0 if isinstance(threshold, (float, int)) or threshold.size == 1: local_threshold = threshold else: local_threshold = threshold[valid_values] values[valid_values] *= np.fmax((abs_values[valid_values] - local_threshold) / abs_values[valid_values], 0)
[docs] class DataFidelityBase(ABC): """Define the DataFidelity classes interface.""" data: NDArrayFloat | None sigma: float | NDArrayFloat background: NDArrayFloat | None sigma_data: NDArrayFloat | None __data_fidelity_name__ = "" def __init__(self, background: float | NDArrayFloat | None = None) -> None: """ Initialize the base data-fidelity class. Parameters ---------- background : float | NDArrayFloat | None, optional The data background. The default is None. """ self.background = np.array(background) if background is not None else None self.data = None self.sigma = 1.0 self.sigma_data = None
[docs] def _slice_attr(self, attr: str, ind: Any) -> None: attr_val = self.__getattribute__(attr) if attr_val is not None and isinstance(attr_val, np.ndarray) and attr_val.size > 1: self.__setattr__(attr, attr_val[ind])
[docs] def __getitem__(self, ind: Any) -> "DataFidelityBase": """ Slice the norm and all its attributes. Parameters ---------- ind : Any Slicing indices. Returns ------- DataFidelityBase The sliced norm. """ new_self = deepcopy(self) for attr in self.__dict__.keys(): new_self._slice_attr(attr, ind) return new_self
[docs] def info(self) -> str: """ Return the data-fidelity info. Returns ------- str Data fidelity info string. """ if self.background is not None: if np.array(self.background).size > 1: bckgrnd_str = "(B:<array>)" else: bckgrnd_str = "(B:%g)" % self.background else: bckgrnd_str = "" return self.__data_fidelity_name__ + bckgrnd_str
[docs] def upper(self) -> str: """ Return the upper case name of the data-fidelity. Returns ------- str Upper case string name of the data-fidelity. """ return self.info().upper()
[docs] def lower(self) -> str: """ Return the lower case name of the data-fidelity. Returns ------- str Lower case string name of the data-fidelity. """ return self.info().lower()
[docs] def assign_data(self, data: float | NDArrayFloat | None = None, sigma: float | NDArrayFloat = 1.0) -> None: """Initialize the data bias, and sigma of the data term. Parameters ---------- data : float | NDArrayFloat | None, optional The data bias, by default None sigma : float | NDArrayFloat, optional The sigma, by default 1.0 """ self.data = np.array(data) if data is not None else None self.sigma = sigma self.sigma_data = self._compute_sigma_data()
[docs] def compute_residual(self, proj_primal: NDArrayFloat, mask: NDArrayFloat | None = None) -> NDArrayFloat: """Compute the residual in the dual domain. Parameters ---------- proj_primal : NDArrayFloat Projection of the primal solution mask : NDArrayFloat | None, optional Mask of the dual domain, by default None Returns ------- NDArrayFloat The residual """ if self.background is not None: proj_primal = proj_primal + self.background if self.data is not None: residual = self.data - proj_primal else: residual = proj_primal.copy() if mask is not None: residual *= mask return residual
[docs] @abstractmethod def compute_residual_norm(self, dual: NDArrayFloat) -> float: """Compute the norm of the residual. Parameters ---------- dual : NDArrayFloat The residual in the dual domain. Returns ------- float The residual norm. """
[docs] def _compute_sigma_data(self): if self.data is None: return None else: return self.sigma * self.data
[docs] def compute_data_dual_dot(self, dual: NDArrayFloat, mask: NDArrayFloat | None = None) -> float: """Compute the dot product of the data bias and the dual solution. Parameters ---------- dual : NDArrayFloat The dual solution. mask : NDArrayFloat | None, optional Mask of the dual domain, by default None Returns ------- float The dot product between the data bias and the dual solution """ if self.data is not None: if mask is not None: dual = dual * mask return np.dot(dual.flatten(), self.data.flatten()) else: return 0.0
[docs] def initialize_dual(self) -> NDArrayFloat: """Initialize the dual domain solution. Returns ------- NDArrayFloat A zero array with the dimensions of the dual domain. """ return np.zeros_like(self.data)
[docs] def update_dual(self, dual: NDArrayFloat, proj_primal: NDArrayFloat) -> None: """Update the dual solution. Parameters ---------- dual : NDArrayFloat The current dual solution proj_primal : NDArrayFloat The projected primal solution """ if self.background is None: dual += proj_primal * self.sigma else: dual += (proj_primal + self.background) * self.sigma
[docs] @abstractmethod def apply_proximal(self, dual: NDArrayFloat) -> None: """Apply the proximal in the dual domain. Parameters ---------- dual : NDArrayFloat The dual solution """
[docs] @abstractmethod def compute_primal_dual_gap( self, proj_primal: NDArrayFloat, dual: NDArrayFloat, mask: NDArrayFloat | None = None ) -> float: """Compute the primal-dual gap of the current solution. Parameters ---------- proj_primal : NDArrayFloat The projected primal solution (in the dual domain) dual : NDArrayFloat The dual solution mask : NDArrayFloat | None, optional Mask in the dual domain, by default None Returns ------- float The primal-dual gap """
[docs] class DataFidelity_l2(DataFidelityBase): """l2-norm data-fidelity class.""" __data_fidelity_name__ = "l2" sigma1: float | NDArrayFloat def __init__(self, background: float | NDArrayFloat | None = None) -> None: super().__init__(background=background) self.sigma1 = 1.0
[docs] def assign_data(self, data: float | NDArrayFloat | None = None, sigma: float | NDArrayFloat = 1.0) -> None: super().assign_data(data=data, sigma=sigma) self.sigma1 = 1 / (1 + sigma)
[docs] def compute_residual_norm(self, dual: NDArrayFloat) -> float: return float(np.linalg.norm(dual.flatten(), ord=2) ** 2)
[docs] def apply_proximal(self, dual: NDArrayFloat) -> None: if self.data is not None and self.sigma_data is not None: dual -= self.sigma_data dual *= self.sigma1
[docs] def compute_primal_dual_gap( self, proj_primal: NDArrayFloat, dual: NDArrayFloat, mask: NDArrayFloat | None = None ) -> float: return float( np.linalg.norm(self.compute_residual(proj_primal, mask), ord=2) + np.linalg.norm(dual, ord=2) ) / 2 + self.compute_data_dual_dot(dual)
[docs] class DataFidelity_wl2(DataFidelity_l2): """Weighted l2-norm data-fidelity class.""" __data_fidelity_name__ = "wl2" sigma1: float | NDArrayFloat weights: NDArrayFloat def __init__(self, weights: float | NDArrayFloat, background: float | NDArrayFloat | None = None) -> None: super().__init__(background=background) self.sigma1 = 1.0 self.weights = np.array(weights)
[docs] def assign_data(self, data: float | NDArrayFloat | None, sigma: float | NDArrayFloat = 1.0): super().assign_data(data=data, sigma=sigma) if isinstance(self.sigma, np.ndarray): dtype = self.sigma.dtype else: dtype = type(self.sigma) invalid_weights = (self.weights == 0).astype(dtype) self.sigma1 = 1 / (1 + sigma / (self.weights + invalid_weights)) * (1 - invalid_weights)
[docs] def compute_residual(self, proj_primal, mask: float | NDArrayFloat | None = None): if self.background is not None: proj_primal = proj_primal + self.background if self.data is not None: residual = (self.data - proj_primal) * self.weights else: residual = proj_primal * self.weights if mask is not None: residual *= mask return residual
[docs] def compute_residual_norm(self, dual: float | NDArrayFloat) -> float: valid_weights = self.weights != 0 if isinstance(dual, np.ndarray): dual = dual[valid_weights] weights = self.weights[valid_weights] return float(np.linalg.norm((dual / np.sqrt(weights)).flatten(), ord=2) ** 2)
[docs] class DataFidelity_l2b(DataFidelity_l2): """l2-norm ball data-fidelity class.""" __data_fidelity_name__ = "l2b" sigma1: float | NDArrayFloat sigma_error: float | NDArrayFloat sigma_sqrt_error: float | NDArrayFloat def __init__(self, local_error: float | NDArrayFloat, background: float | NDArrayFloat | None = None): super().__init__(background=background) self.sigma1 = 1.0 self.local_error = local_error self.sigma_error = 1.0 * self.local_error self.sigma_sqrt_error = 1.0 * np.sqrt(self.local_error)
[docs] def assign_data(self, data: float | NDArrayFloat | None, sigma: float | NDArrayFloat = 1.0): self.sigma_error = sigma * self.local_error self.sigma_sqrt_error = sigma * np.sqrt(self.local_error) super().assign_data(data=data, sigma=sigma) self.sigma1 = 1 / (1 + self.sigma_error)
[docs] def compute_residual(self, proj_primal: NDArrayFloat, mask: NDArrayFloat | None = None) -> NDArrayFloat: residual = super().compute_residual(proj_primal, mask) _soft_threshold(residual, self.sigma_sqrt_error) return residual
[docs] def apply_proximal(self, dual: NDArrayFloat) -> None: if self.data is not None and self.sigma_data is not None: dual -= self.sigma_data _soft_threshold(dual, self.sigma_sqrt_error) dual *= self.sigma1
[docs] def compute_primal_dual_gap( self, proj_primal: NDArrayFloat, dual: NDArrayFloat, mask: NDArrayFloat | None = None ) -> float: return float( np.linalg.norm(self.compute_residual(proj_primal, mask), ord=2) + np.linalg.norm(np.sqrt(self.local_error) * dual, ord=2) ) / 2 + self.compute_data_dual_dot(dual)
[docs] class DataFidelity_Huber(DataFidelityBase): """Huber-norm data-fidelity class. Given a parameter a: l2-norm for x < a, and l1-norm for x > a.""" __data_fidelity_name__ = "Hub" one_sigma_error: float | NDArrayFloat def __init__(self, local_error, background=None, l2_axis=None): super().__init__(background=background) self.local_error = local_error self.l2_axis = l2_axis self.one_sigma_error = 1.0
[docs] def assign_data(self, data, sigma=1.0): self.one_sigma_error = 1.0 / (1.0 + sigma * self.local_error) super().assign_data(data=data, sigma=sigma)
[docs] def compute_residual_norm(self, dual): l2_points = dual <= self.local_error l1_points = 1 - l2_points return np.linalg.norm(dual[l2_points].flatten(), ord=2) ** 2 + np.linalg.norm(dual[l1_points].flatten(), ord=1)
[docs] def apply_proximal(self, dual): if self.data is not None and self.sigma_data is not None: dual -= self.sigma_data dual *= self.one_sigma_error if self.l2_axis is None: dual /= np.fmax(1, np.abs(dual)) else: dual_dir_norm_l2 = np.linalg.norm(dual, ord=2, axis=self.l2_axis, keepdims=True) dual /= np.fmax(1, dual_dir_norm_l2)
[docs] def compute_primal_dual_gap(self, proj_primal, dual, mask=None): if self.background is not None: proj_primal = proj_primal + self.background return ( np.linalg.norm(self.compute_residual(proj_primal, mask), ord=2) + self.compute_data_dual_dot(dual) + self.local_error * np.linalg.norm(dual, ord=2) )
[docs] class DataFidelity_l1(DataFidelityBase): """l1-norm data-fidelity class.""" __data_fidelity_name__ = "l1" def __init__(self, background=None): super().__init__(background=background)
[docs] def _get_inner_norm(self, dual): return np.abs(dual)
[docs] def _apply_threshold(self, dual): pass
[docs] def apply_proximal(self, dual, weight=1.0): if self.data is not None: dual -= self.sigma_data self._apply_threshold(dual) dual_inner_norm = self._get_inner_norm(dual) dual /= np.fmax(dual_inner_norm, weight) dual *= weight
[docs] def compute_residual_norm(self, dual): dual = dual.copy() self._apply_threshold(dual) dual_inner_norm = self._get_inner_norm(dual) return np.linalg.norm(dual_inner_norm, ord=1)
[docs] def compute_primal_dual_gap(self, proj_primal, dual, mask=None): if self.background is not None: proj_primal = proj_primal + self.background residual = self.compute_residual(proj_primal, mask) self._apply_threshold(residual) residual_inner_norm = self._get_inner_norm(residual) return np.linalg.norm(residual_inner_norm, ord=1) + self.compute_data_dual_dot(dual)
[docs] class DataFidelity_l12(DataFidelity_l1): """l12-norm data-fidelity class.""" __data_fidelity_name__ = "l12" def __init__(self, background=None, l2_axis=0): super().__init__(background=background) self.l2_axis = l2_axis
[docs] def _get_inner_norm(self, dual): return np.linalg.norm(dual, ord=2, axis=self.l2_axis, keepdims=True)
[docs] class DataFidelity_l1b(DataFidelity_l1): """l1-norm ball data-fidelity class.""" __data_fidelity_name__ = "l1b" sigma_error: float | NDArrayFloat def __init__(self, local_error, background=None): super().__init__(background=background) self.local_error = local_error self.sigma_error = 1.0 * self.local_error
[docs] def assign_data(self, data, sigma=1.0): self.sigma_error = sigma * self.local_error super().assign_data(data=data, sigma=sigma)
[docs] def _apply_threshold(self, dual): _soft_threshold(dual, self.local_error)
[docs] class DataFidelity_KL(DataFidelityBase): """Kullback-Leibler data-fidelity class.""" __data_fidelity_name__ = "KL"
[docs] def _compute_sigma_data(self): if self.data is None: return None else: return 4 * self.sigma * np.fmax(self.data, 0.0)
[docs] def apply_proximal(self, dual): if self.sigma_data is not None: dual[:] = (1 + dual[:] - np.sqrt((dual[:] - 1) ** 2 + self.sigma_data[:])) / 2 else: dual[:] = (1 + dual[:] - np.sqrt((dual[:] - 1) ** 2)) / 2
[docs] def compute_residual(self, proj_primal, mask=None): if self.background is not None: proj_primal = proj_primal + self.background # we take the Moreau envelope here, and apply the proximal to it residual = np.fmax(proj_primal, eps) * self.sigma self.apply_proximal(residual) if mask is not None: residual *= mask return -residual
[docs] def compute_residual_norm(self, dual): return np.linalg.norm(dual.flatten(), ord=1)
[docs] def compute_primal_dual_gap(self, proj_primal, dual, mask=None): if self.background is not None: proj_primal = proj_primal + self.background if self.data is not None: data_nn = np.fmax(self.data, eps) proj_primal_nn = np.fmax(proj_primal, eps) residual = proj_primal_nn - data_nn * (1 - np.log(data_nn) + np.log(proj_primal_nn)) else: residual = np.copy(proj_primal) if mask is not None: residual *= mask return np.linalg.norm(residual, ord=1)
[docs] class DataFidelity_ln(DataFidelityBase): """nuclear-norm data-fidelity class.""" __data_fidelity_name__ = "ln" def __init__(self, background=None, ln_axes: Sequence[int] = (1, -1), spectral_norm: DataFidelityBase = DataFidelity_l1()): super().__init__(background=background) self.ln_axes = ln_axes self.spectral_norm = spectral_norm self.use_fallback = False
[docs] def apply_proximal(self, dual): dual_tmp = dual.copy() if self.sigma_data is not None: # If we have a bias term, we interpret it as an addition to the rows in the SVD decomposition. # Performing this operation before the transpose is a waste of computation, but it simplifies the logic. dual_tmp = np.concatenate((dual_tmp, self.sigma_data), axis=self.ln_axes[0]) if self.use_fallback: t_range = [*range(len(dual_tmp.shape))] t_range.append(t_range.pop(self.ln_axes[0])) t_range.append(t_range.pop(self.ln_axes[1])) dual_tmp = np.transpose(dual_tmp, t_range) U, s_p, Vt = np.linalg.svd(dual_tmp, full_matrices=False) self.spectral_norm.apply_proximal(s_p) dual_tmp = np.matmul(U, s_p[..., None] * Vt) dual_tmp = np.transpose(dual_tmp, np.argsort(t_range)) else: op_svd = operators.TransformSVD(dual_tmp.shape, axes_rows=self.ln_axes[0], axes_cols=self.ln_axes[1]) s_p = op_svd(dual_tmp) self.spectral_norm.apply_proximal(s_p) dual_tmp = op_svd.T(s_p) if self.data is not None: # We now strip the bias data, to make sure that we don't change dimensionality. # dual_tmp = dual_tmp[..., : dual_tmp.shape[-2] - 1 :, :] dual_tmp = np.take(dual_tmp, np.arange(dual_tmp.shape[self.ln_axes[0]] - 1), axis=self.ln_axes[0]) dual[:] = dual_tmp[:]
[docs] def compute_residual_norm(self, dual): op_svd = operators.TransformSVD(dual.shape, axes_rows=self.ln_axes[0], axes_cols=self.ln_axes[1]) s_p = op_svd(dual) return np.linalg.norm(s_p, ord=1)
[docs] def compute_primal_dual_gap(self, proj_primal, dual, mask=None): if self.background is not None: proj_primal = proj_primal + self.background residual = self.compute_residual(proj_primal, mask) return self.compute_residual_norm(residual) + self.compute_data_dual_dot(dual)