"""
Filtered back-projection filters.
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, DTypeLike, NDArray
from scipy.interpolate import interp1d
from skimage.transform.radon_transform import _get_fourier_filter
from .operators import BaseTransform
from .processing import circular_mask
try:
import pywt
has_pywt = True
except ImportError:
print("WARNING: You need to install PyWavelets to benefit from wavelet bases.")
has_pywt = False
class BasisOptions(ABC, Mapping):
    """Mapping view over the option attributes of a basis.

    The instance attributes (the content of ``__dict__``) are exposed through
    the `Mapping` protocol, so an options object can be unpacked with ``**``.
    """

    def __len__(self) -> int:
        """Count the stored options.

        Returns
        -------
        int
            How many options the instance holds.
        """
        return len(vars(self))

    def __getitem__(self, k: Any) -> Any:
        """Look up a single option by its key.

        Parameters
        ----------
        k : Any
            Name of the requested option.

        Returns
        -------
        Any
            Value stored under that name.
        """
        return vars(self)[k]

    def __iter__(self) -> Any:
        """Iterate over the option names.

        Returns
        -------
        Any
            Iterator over the keys of the stored options.
        """
        return iter(vars(self))
@dataclass
class BasisOptionsBlocks(BasisOptions):
    """Options for the block (binned) bases.

    The field names match the keyword arguments of `create_basis`.
    """

    # Window position after which the bin size starts to grow.
    binning_start: Optional[int] = 2
    # Rule used to grow the bin size (e.g. "exponential" or "incremental").
    binning_type: str = "exponential"
    # Order of the basis functions.
    order: int = 1
    # Whether the basis functions should be normalized.
    normalized: bool = True
@dataclass
class BasisOptionsWavelets(BasisOptions):
    """Settings used when building a wavelet filter basis.

    The field names match the keyword arguments of `create_basis_wavelet`.
    """

    # Name of the wavelet family member.
    wavelet: str = "bior2.2"
    # Decomposition level to reach.
    level: int = 5
    # Order of the norm used to normalize the basis functions.
    norm: float = 1.0
def create_basis(
    num_pixels: int,
    binning_start: Optional[int] = 2,
    binning_type: str = "exponential",
    normalized: bool = False,
    order: int = 1,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """Compute filter basis matrix.

    The basis functions are defined over the absolute pixel displacement
    ``|fftfreq|`` and therefore are symmetric around the first pixel. Each
    function covers a window ``[window_position, window_position + window_size)``
    and the windows are laid out one after the other until the largest
    displacement is covered.

    Parameters
    ----------
    num_pixels : int
        Number of filter pixels.
    binning_start : Optional[int], optional
        Starting displacement of the binning, by default 2. The window size is
        only grown once the window position exceeds this value. With None the
        window size stays equal to 1.
    binning_type : str, optional
        Type of pixel binning, by default "exponential". Supported values:
        "exponential" (window size doubles), "incremental" (window size grows
        by one) and "custom" (window size grows by the integer square root of
        the distance from `binning_start`).
    normalized : bool, optional
        Whether to normalize the basis functions, by default False. For order 0
        the bins are divided by the window size, for higher orders each
        function is divided by its L1-norm.
    order : int, optional
        Order of the basis functions. Only 0 and 1 supported, by default 1.
        Order 0 produces box functions, any other value produces triangular
        functions that are then orthogonalized against the previous ones.
    dtype : DTypeLike, optional
        Data type, by default np.float32.

    Returns
    -------
    NDArray
        The filter basis, one basis function per row.

    Raises
    ------
    ValueError
        If the window size needs to be grown and `binning_type` is unknown.
    """
    filter_positions = np.abs(np.fft.fftfreq(num_pixels, 1 / num_pixels))
    max_position = filter_positions.max()

    window_size = 1
    window_position = 0

    basis_rows = []

    while window_position < max_position:
        window_end = window_position + window_size

        if order == 0:
            # Box function: constant inside the window, zero elsewhere.
            in_window = (window_position <= filter_positions) & (filter_positions < window_end)
            basis_val = 1.0
            if normalized:
                basis_val /= window_size
            basis_row = np.zeros(filter_positions.shape, dtype=dtype)
            basis_row[in_window] = basis_val
        else:
            # Triangular function: decays linearly to zero at the window end.
            basis_row = np.fmax(1.0 - filter_positions / window_end, 0.0)

        basis_rows.append(basis_row)
        window_position += window_size

        if binning_start is not None and window_position > binning_start:
            if binning_type == "exponential":
                window_size = 2 * window_size
            elif binning_type == "incremental":
                window_size += 1
            elif binning_type == "custom":
                window_size += int(np.sqrt(window_position - binning_start))
            else:
                raise ValueError(f"Invalid 'binning_type' = {binning_type}.")

    basis_r = np.array(basis_rows, dtype=dtype)

    if order > 0:
        # Gram-Schmidt: remove from each function its projection on the
        # (already orthogonalized) functions that precede it.
        for ii in range(1, basis_r.shape[0]):
            for jj in range(0, ii):
                other_vec = basis_r[jj, ...]
                other_vec_norm_2 = other_vec.dot(other_vec)
                basis_r[ii, ...] -= other_vec * other_vec.dot(basis_r[ii, ...]) / other_vec_norm_2

        if normalized:
            basis_r /= np.linalg.norm(basis_r, ord=1, axis=-1, keepdims=True)

    return basis_r
def create_basis_wavelet(
    num_pixels: int,
    wavelet: str = "bior2.2",
    level: int = 5,
    norm: float = 1.0,
    dtype: DTypeLike = np.float32,
) -> NDArray:
    """Compute a wavelet filter basis matrix.

    The decomposition filters of the chosen wavelet are centered in a vector of
    `num_pixels` elements. At every further level they are stretched by a
    factor 2 (linear interpolation) and convolved with the low-pass function of
    the previous level.

    Parameters
    ----------
    num_pixels : int
        Number of filter pixels (length of each basis function).
    wavelet : str, optional
        The wavelet to use, by default "bior2.2".
    level : int, optional
        The decomposition level to reach, by default 5.
    norm : float, optional
        The order of the norm used to normalize each basis function (passed as
        `ord` to `numpy.linalg.norm`), by default 1.0.
    dtype : DTypeLike, optional
        Data type, by default np.float32.

    Returns
    -------
    NDArray
        The filter basis, one basis function per row: the high-pass function of
        each level, followed by the low-pass function of the last level.

    Raises
    ------
    ImportError
        If PyWavelets is not installed.
    """
    if not has_pywt:
        print("WARNING: You need to install PyWavelets to benefit from wavelet bases.")
        raise ImportError("PyWavelets (pywt) module not found.")

    def _center_pad(taps: Sequence[float]) -> NDArray:
        # Zero-pad the filter taps to `num_pixels`; the extra element (if any)
        # goes before the taps.
        pad_after = (num_pixels - len(taps)) // 2
        pad_before = num_pixels - len(taps) - pad_after
        return np.pad(taps, pad_width=(pad_before, pad_after))

    w = pywt.Wavelet(wavelet)
    basis_hi_tmp = _center_pad(np.trim_zeros(w.dec_hi))
    basis_lo_tmp = _center_pad(np.trim_zeros(w.dec_lo))

    coords = np.fft.fftshift(np.fft.fftfreq(num_pixels, 1 / num_pixels))
    # Sampling positions that stretch a function by a factor 2. The offset
    # keeps them inside the range of `coords` for even lengths.
    stretched_coords = (coords + 1 - len(coords) % 2) / 2

    basis_r = [basis_hi_tmp]
    for _ in range(level - 1):
        basis_lo_old = basis_lo_tmp.copy()

        basis_hi_tmp = interp1d(coords, basis_hi_tmp, kind="linear")(stretched_coords)
        basis_hi_tmp = np.convolve(basis_hi_tmp, basis_lo_old, mode="same")

        basis_lo_tmp = interp1d(coords, basis_lo_tmp, kind="linear")(stretched_coords)
        basis_lo_tmp = np.convolve(basis_lo_tmp, basis_lo_old, mode="same")

        basis_r.append(basis_hi_tmp)
    basis_r.append(basis_lo_tmp)

    basis_r = np.array(basis_r, dtype=dtype)
    basis_r /= np.linalg.norm(basis_r, axis=-1, ord=norm, keepdims=True)

    # Move the center of the functions to the first pixel (FFT layout).
    return np.fft.ifftshift(basis_r, axes=(-1,))
class Filter(ABC):
    """Base FBP filter.

    The filter is stored in its Fourier representation (`fbp_filter`) and it is
    applied along the last axis of the data. Sub-classes implement
    `compute_filter`, which is called by `__call__` before applying the filter.
    """

    fbp_filter: NDArray[np.floating]
    pad_mode: str
    use_rfft: bool
    dtype: DTypeLike

    def __init__(
        self,
        fbp_filter: Union[ArrayLike, NDArray[np.floating], None],
        pad_mode: str,
        use_rfft: bool,
        dtype: DTypeLike,
    ) -> None:
        """Initialize Base FBP filter.

        Parameters
        ----------
        fbp_filter : Union[ArrayLike, NDArray[np.floating], None]
            The filter. Only its real part is retained. With None, the filter
            is set to the single value 1.0.
        pad_mode : str
            The padding mode (stored in lower case, and passed to `numpy.pad`).
        use_rfft : bool
            Whether to use the `rfft` or complex `fft`.
        dtype : DTypeLike
            The data type of the filter.
        """
        self.dtype = dtype
        self.pad_mode = pad_mode.lower()
        self.use_rfft = use_rfft

        if fbp_filter is None:
            self.fbp_filter = np.array([1.0], dtype=dtype)
        else:
            self.fbp_filter = np.array(np.real(fbp_filter), dtype=dtype)

    def get_padding_size(self, data_wu_shape: Sequence[int]) -> int:
        """Compute the projection padding size for a linear convolution.

        Parameters
        ----------
        data_wu_shape : Sequence[int]
            The shape of the data

        Returns
        -------
        int
            The padding size of the last dimension: the smallest power of 2
            that is at least twice the size of the last dimension, and never
            less than 64.
        """
        return max(64, int(2 ** np.ceil(np.log2(2 * data_wu_shape[-1]))))

    def to_fourier(self, data_wu: NDArray) -> NDArray:
        """Transform the data to Fourier space, along the last axis.

        Parameters
        ----------
        data_wu : NDArray
            The real-space data.

        Returns
        -------
        NDArray
            The `rfft` of the data if `use_rfft` is set, the `fft` otherwise.
        """
        if self.use_rfft:
            return np.fft.rfft(data_wu, axis=-1)
        else:
            return np.fft.fft(data_wu, axis=-1)

    def to_real(self, data_wu: NDArray) -> NDArray:
        """Transform the data back to real space, along the last axis.

        Parameters
        ----------
        data_wu : NDArray
            The Fourier-space data.

        Returns
        -------
        NDArray
            The `irfft` of the data if `use_rfft` is set, the real part of the
            `ifft` otherwise. The `irfft` is computed without an explicit
            output length, so its output has an even number of elements.
        """
        if self.use_rfft:
            return np.fft.irfft(data_wu, axis=-1)
        else:
            return np.fft.ifft(data_wu, axis=-1).real

    @property
    def filter_fourier(self) -> NDArray[np.floating]:
        """Fourier representation of the filter.

        Returns
        -------
        NDArray[np.floating]
            The filter in Fourier (a copy of the stored array).
        """
        return self.fbp_filter.copy()

    @property
    def filter_real(self) -> NDArray[np.floating]:
        """Real-space representation of the filter.

        Returns
        -------
        NDArray[np.floating]
            The filter in real-space, shifted so that its first pixel ends up
            in the middle of the array.
        """
        fbp_filter_r = self.to_real(self.fbp_filter)
        return np.fft.fftshift(fbp_filter_r, axes=(-1,))

    @filter_real.setter
    def filter_real(self, fbp_filter_r: NDArray[np.floating]) -> None:
        # NOTE(review): the getter applies `fftshift`, while no inverse shift
        # is applied here - confirm that un-shifted input is intended.
        self.fbp_filter = self.to_fourier(fbp_filter_r)

    @filter_fourier.setter
    def filter_fourier(self, fbp_filter_f: NDArray[np.floating]) -> None:
        self.fbp_filter = fbp_filter_f.copy()

    @property
    def num_filters(self) -> int:
        """Number of stored filters.

        Returns
        -------
        int
            The size of the second-to-last dimension of the filter (1 for a
            one-dimensional filter).
        """
        return np.array(self.fbp_filter, ndmin=2).shape[-2]

    def apply_filter(self, data_wu: NDArray, fbp_filter: Optional[NDArray] = None) -> NDArray:
        """Apply the filter to the data_wu.

        The data is padded along the last axis to the size given by
        `get_padding_size`, multiplied by the filter in Fourier space, and
        cropped back to its original size.

        Parameters
        ----------
        data_wu : NDArray
            The sinogram.
        fbp_filter : NDArray, optional
            The filter to use. The default is None, which selects the stored
            filter.

        Returns
        -------
        NDArray
            The filtered sinogram. When more than one filter is given, the
            results are stacked along a new first axis (one per filter).
        """
        data_wu_shape = data_wu.shape
        local_filter = self.fbp_filter if fbp_filter is None else fbp_filter

        prj_size_pad = self.get_padding_size(data_wu_shape)
        pad_edge_u = (prj_size_pad - data_wu_shape[-1]) / 2

        # Only the last axis is padded.
        pad_width = np.zeros((len(data_wu_shape), 2), dtype=int)
        pad_width[-1, :] = (int(np.ceil(pad_edge_u)), int(np.floor(pad_edge_u)))

        prj_pad = np.pad(data_wu, pad_width=tuple(pad_width), mode=self.pad_mode)  # type: ignore
        # Bring the first data pixel back to position 0, so that the padding
        # wraps around the end of the array.
        prj_pad = np.roll(prj_pad, shift=-pad_width[-1][0], axis=-1)

        prj_f = self.to_fourier(prj_pad)

        local_filter = np.array(local_filter, ndmin=2)
        num_filters = local_filter.shape[-2]
        if num_filters > 1:
            # Each one-dimensional filter broadcasts along the last axis.
            prjs_f = [self.to_real(prj_f * f)[..., : data_wu_shape[-1]] for f in local_filter]
            return np.array(prjs_f)
        else:
            # Drop the leading axis introduced by `ndmin=2`: keeping it makes
            # the in-place product fail on one-dimensional data.
            single_filter = local_filter[0] if local_filter.ndim == 2 else local_filter
            prj_f *= single_filter
            return self.to_real(prj_f)[..., : data_wu_shape[-1]]

    @abstractmethod
    def compute_filter(self, data_wu: NDArray) -> None:
        """Compute the FBP filter for the given data.

        Parameters
        ----------
        data_wu : NDArray
            The reference sinogram / projection data.
        """

    def __call__(self, data_wu: NDArray) -> NDArray:
        """Filter the sinogram, by first computing the filter, and then applying it.

        Parameters
        ----------
        data_wu : NDArray
            The unfiltered sinogram.

        Returns
        -------
        NDArray
            The filtered sinogram.
        """
        self.compute_filter(data_wu)
        return self.apply_filter(data_wu, self.fbp_filter)

    def plot_filters(self, fourier_abs: bool = False) -> None:
        """Plot the filters, both in real-space and in Fourier-space.

        Parameters
        ----------
        fourier_abs : bool, optional
            Whether to plot the absolute value of the Fourier representation,
            by default False.
        """
        filters_r = np.array(self.filter_real, ndmin=2)
        filters_f = np.array(self.filter_fourier, ndmin=2)

        f, axes = plt.subplots(1, 2, figsize=(10, 4))
        for ii in range(self.num_filters):
            axes[0].plot(filters_r[ii, ...], label=f"Filter-{ii}")
        axes[0].set_title("Real-space")
        axes[0].set_xlabel("Pixel")

        for ii in range(self.num_filters):
            filt_f = filters_f[ii, ...]
            if fourier_abs:
                filt_f = np.abs(filt_f)
            axes[1].plot(filt_f, label=f"Filter-{ii}")
        axes[1].set_title("Fourier-space")
        axes[1].set_xlabel("Frequency")

        axes[0].grid()
        axes[1].grid()
        if self.num_filters > 1:
            axes[1].legend()
        f.tight_layout()
class FilterCustom(Filter):
    """FBP filter whose values are provided directly by the caller."""

    def __init__(
        self,
        fbp_filter: Union[ArrayLike, NDArray[np.floating], None],
        pad_mode: str = "constant",
        use_rfft: bool = True,
        dtype: DTypeLike = np.float32,
    ) -> None:
        """Store the user-provided filter.

        Parameters
        ----------
        fbp_filter : Union[ArrayLike, NDArray[np.floating], None]
            The filter values.
        pad_mode : str, optional
            How the data is padded, by default "constant".
        use_rfft : bool, optional
            Selects the `rfft` (True) or the complex `fft` (False), by default True.
        dtype : DTypeLike, optional
            Data type of the stored filter, by default np.float32.
        """
        super().__init__(fbp_filter=fbp_filter, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)

    def compute_filter(self, data_wu: NDArray) -> None:
        """Do nothing: the filter is given at construction, and not computed from the data."""
        return None
class FilterFBP(Filter):
    """Classic analytical FBP filters, selected by name."""

    filter_name: str

    FILTERS = ("ramp", "shepp-logan", "cosine", "hamming", "hann")

    def __init__(
        self, filter_name: str = "ramp", pad_mode: str = "constant", use_rfft: bool = True, dtype: DTypeLike = np.float32
    ) -> None:
        """Select one of the traditional FBP filters.

        Parameters
        ----------
        filter_name : str
            Name of the filter (case insensitive). It should be one of `FILTERS`.
        pad_mode : str, optional
            How the data is padded, by default "constant".
        use_rfft : bool, optional
            Selects the `rfft` (True) or the complex `fft` (False), by default True
        dtype : DTypeLike, optional
            Data type of the filter, by default np.float32

        Raises
        ------
        ValueError
            When the filter name is not among the available ones.
        """
        requested_name = filter_name.lower()
        if requested_name not in self.FILTERS:
            raise ValueError(f"Unknown filter {filter_name}. Available filters: {self.FILTERS}")

        super().__init__(fbp_filter=None, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)
        self.filter_name = requested_name

    def compute_filter(self, data_wu: NDArray) -> None:
        """Build the selected filter for the size of the given data.

        Parameters
        ----------
        data_wu : NDArray
            The reference sinogram / projection data. It needs at least two
            dimensions: the filter is scaled by the size of the second-to-last
            one.
        """
        padded_size = self.get_padding_size(data_wu.shape)
        full_filter = np.squeeze(_get_fourier_filter(padded_size, self.filter_name))
        full_filter = full_filter * np.pi / (2 * data_wu.shape[-2])

        if self.use_rfft:
            # The `rfft` only holds the non-negative frequencies.
            num_rfft_freqs = full_filter.shape[-1] // 2 + 1
            full_filter = full_filter[:num_rfft_freqs]

        self.fbp_filter = full_filter

    def get_available_filters(self) -> Sequence[str]:
        """List the names of the supported filters.

        Returns
        -------
        Sequence[str]
            The supported filter names.
        """
        return self.FILTERS
class FilterMR(Filter):
    """Data dependent FBP filter.

    This is a simplified implementation from:

    [1] Pelt, D. M., & Batenburg, K. J. (2014). Improving filtered backprojection
    reconstruction by data-dependent filtering. Image Processing, IEEE
    Transactions on, 23(11), 4750-4762.

    Code inspired by: https://github.com/dmpelt/pymrfbp
    """

    projector: BaseTransform
    binning_type: str
    binning_start: Union[int, None]
    lambda_smooth: Union[float, None]
    is_initialized: bool

    def __init__(
        self,
        projector: BaseTransform,
        binning_type: str = "exponential",
        binning_start: Union[int, None] = 2,
        lambda_smooth: Optional[float] = None,
        pad_mode: str = "constant",
        use_rfft: bool = True,
        dtype: DTypeLike = np.float32,
    ) -> None:
        """Initialize data-driven FBP filter.

        Parameters
        ----------
        projector : BaseTransform
            The projector to use for handling the data.
        binning_type : str, optional
            Type of binning of the filter basis: either "exponential" or
            "incremental" (case insensitive). The default is "exponential".
        binning_start : int, optional
            From which distance to start growing the bins. The default is 2.
        lambda_smooth : float, optional
            Smoothing parameter: weight of the image gradient terms added to
            the fit. The default is None, which disables them.
        pad_mode : str, optional
            The padding mode. The default is "constant".
        use_rfft : bool, optional
            Whether to use the `rfft` or complex `fft`. The default is True.
        dtype : DTypeLike, optional
            Filter data type. The default is np.float32.

        Raises
        ------
        ValueError
            If the binning type is not one of the supported ones.
        """
        super().__init__(fbp_filter=None, pad_mode=pad_mode, use_rfft=use_rfft, dtype=dtype)
        self.projector = projector
        self.binning_type = binning_type.lower()
        if self.binning_type not in ("exponential", "incremental"):
            raise ValueError("Binning type should be either 'exponential' or 'incremental'.")
        self.binning_start = binning_start
        self.lambda_smooth = lambda_smooth
        self.is_initialized = False

    def initialize(self, data_wu_shape: Sequence[int]) -> None:
        """Initialize filter.

        Computes the filter basis for the padded size of the data, both in
        real-space (`basis_r`) and in Fourier-space (`basis_f`).

        Parameters
        ----------
        data_wu_shape : Sequence[int]
            Shape of the data.
        """
        num_pad_pixels = self.get_padding_size(data_wu_shape)
        self.basis_r = create_basis(
            num_pad_pixels, binning_type=self.binning_type, binning_start=self.binning_start, dtype=self.dtype
        )
        self.basis_f = self.to_fourier(self.basis_r).real
        self.is_initialized = True

    def compute_filter(self, data_wu: NDArray) -> None:
        """Compute the filter.

        Each basis function is applied to the data as a filter. The filtered
        data is back-projected, masked, and forward-projected again. The
        filter coefficients are the least-squares solution that makes the
        linear combination of these re-projections match the data.

        Parameters
        ----------
        data_wu : NDArray
            The sinogram.
        """
        # The basis depends on the padded size of the data: it is re-computed
        # when data with a different padded size is passed.
        basis_r = getattr(self, "basis_r", None)
        is_basis_stale = basis_r is not None and basis_r.shape[-1] != self.get_padding_size(data_wu.shape)
        if not self.is_initialized or is_basis_stale:
            self.initialize(data_wu.shape)

        num_sino_pixels = data_wu.shape[-1]
        sino_size = data_wu.shape[-2] * num_sino_pixels
        nrows = sino_size
        ncols = self.basis_f.shape[-2]

        if self.lambda_smooth:
            # Size of the finite differences of a square image, along one axis.
            grad_vol_size = num_sino_pixels * (num_sino_pixels - 1)
            nrows += 2 * grad_vol_size

        A = np.zeros((nrows, ncols), dtype=self.dtype)
        vol_mask = circular_mask([num_sino_pixels] * 2)

        for ii, bas_f in enumerate(self.basis_f):
            data_wu_f = self.apply_filter(data_wu, bas_f)
            img = self.projector.T(data_wu_f)
            img *= vol_mask

            A[:sino_size, ii] = self.projector(img).flatten()
            if self.lambda_smooth:
                # Extra rows penalizing the gradient of the reconstruction.
                dx = np.diff(img, axis=-2)
                dy = np.diff(img, axis=-1)
                d = np.concatenate((dx.flatten(), dy.flatten()))
                A[sino_size:, ii] = self.lambda_smooth * d

        # The gradient rows (if any) are fitted against zero.
        b = np.zeros((nrows,), dtype=self.dtype)
        b[:sino_size] = data_wu.flatten()

        fitted_components = np.linalg.lstsq(A, b, rcond=None)[0].astype(self.dtype)

        self.fbp_filter = fitted_components.dot(self.basis_f)