#!/usr/bin/env python3
"""
Fitting routines.
Created on Tue May 17 12:11:58 2022
@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,
and ESRF - The European Synchrotron, Grenoble, France
"""
from typing import Optional, Union, Sequence
import numpy as np
from numpy.polynomial import Polynomial
import scipy.ndimage as spimg
import scipy.optimize as spopt
from numpy.typing import ArrayLike, NDArray
from skimage.transform import warp_polar
from skimage.filters import window
from scipy.optimize import minimize
import matplotlib.pyplot as plt
# Type alias for arrays of floating-point values, used in the annotations of this module.
NDArrayFloat = NDArray[np.floating]
# Machine epsilon of float32: used as a lower bound for denominators when normalizing
# the Fourier representation of the cross-correlations.
eps = np.finfo(np.float32).eps
[docs]
def fit_shifts_u_sad(
    data_wu: NDArrayFloat,
    proj_wu: NDArrayFloat,
    search_range: int = 16,
    pad_u: bool = False,
    error_norm: int = 1,
    decimals: int = 2,
) -> NDArrayFloat:
    """
    Find the U shifts between two sets of lines, by means of the sum-of-absolute-difference (SAD).

    Every integer shift in `[-search_range, search_range]` is applied to `proj_wu`
    (through a Fourier shift along the last axis), and the norm of the difference
    with `data_wu` is computed for each row. The best shift per row is then
    selected, and optionally refined to sub-pixel precision.

    Parameters
    ----------
    data_wu : NDArrayFloat
        The reference data.
    proj_wu : NDArrayFloat
        The other data.
    search_range : int, optional
        The range in pixels of the search, by default 16
    pad_u : bool, optional
        Whether to pad the last (U) axis by `search_range` pixels on each side
        (edge values for `data_wu`, zeros for `proj_wu`), by default False
    error_norm : int, optional
        The error norm to use, by default 1
    decimals : int, optional
        The precision of the result, by default 2

    Returns
    -------
    NDArrayFloat
        A list of one shift for each row.
    """
    # With pad_u disabled the pad widths are all zero, so np.pad simply returns
    # copies of the inputs. This keeps the padded names defined on both paths
    # (they were previously only assigned when pad_u was True).
    padding = np.zeros((len(data_wu.shape), 2), dtype=int)
    if pad_u:
        padding[-1, :] = (search_range, search_range)
    pad_data_wu = np.pad(data_wu, pad_width=padding, mode="edge")
    pad_proj_wu = np.pad(proj_wu, pad_width=padding, mode="constant")

    fft_proj_wu = np.fft.fft2(pad_proj_wu)

    # Integer shifts in FFT ordering: 0, 1, ..., search_range, -search_range, ..., -1
    num_shifts = search_range * 2 + 1
    shift_coords = np.fft.fftfreq(num_shifts, 1 / num_shifts)

    diffs = np.empty((data_wu.shape[-2], len(shift_coords)))
    for ii, s in enumerate(shift_coords):
        shifted_proj_wu = np.fft.ifft2(spimg.fourier_shift(fft_proj_wu, (0, s))).real
        diffs[:, ii] = np.linalg.norm(pad_data_wu - shifted_proj_wu, axis=-1, ord=error_norm)

    # The differences are negated, so that the best match becomes a maximum.
    f_vals, f_h = extract_peak_regions_1d(-diffs, axis=-1, cc_coords=shift_coords)
    shifts_vu = f_h[1, :]

    if decimals > 0:
        shifts_vu += refine_max_position_1d(f_vals, decimals=decimals)

    return shifts_vu
[docs]
def fit_shifts_vu_xc(
    data_vwu: NDArrayFloat,
    proj_vwu: NDArrayFloat,
    pad_u: bool = False,
    normalize_fourier: bool = False,
    margin: int = 0,
    use_rfft: bool = True,
    stack_axis: int = -2,
    decimals: int = 2,
) -> NDArrayFloat:
    """
    Find the VU shifts of the projected data, through cross-correlation.

    Parameters
    ----------
    data_vwu : NDArrayFloat
        The collected projection data.
    proj_vwu : NDArrayFloat
        The forward-projected images from the reconstruction.
    pad_u : bool, optional
        Pad the u coordinate. The default is False.
    normalize_fourier : bool, optional
        Whether to normalize the Fourier representation of the cross-correlation. The default is False.
    margin : int, optional
        The margin of the region to compare, the default is 0.
    use_rfft : bool, optional
        Whether to use the `rfft` transform in place of the complex `fft` transform. The default is True.
    stack_axis : int, optional
        The axis along which the VU images are stacked. The default is -2.
    decimals : int, optional
        Decimals for the truncation of the sub-pixel precision. The default is 2.

    Returns
    -------
    NDArrayFloat
        The VU shifts.
    """
    num_angles = data_vwu.shape[stack_axis]

    if use_rfft:
        local_fftn = np.fft.rfftn
        local_ifftn = np.fft.irfftn
    else:
        local_fftn = np.fft.fftn
        local_ifftn = np.fft.ifftn

    # All the axes, except the stacking one, are transformed
    fft_dims = np.delete(np.arange(-len(data_vwu.shape), 0), stack_axis)
    u_axis = fft_dims[-1]

    new_fft_shapes = np.array(np.array(data_vwu.shape)[fft_dims], ndmin=1, dtype=int)
    if pad_u:
        new_fft_shapes[u_axis] *= 2
    fft_shapes = new_fft_shapes.tolist()

    cc_coords = [np.fft.fftfreq(s, 1 / s) for s in fft_shapes]

    if margin > 0:
        mask = np.zeros([proj_vwu.shape[d] for d in fft_dims], dtype=proj_vwu.dtype)
        inner_region = tuple(slice(margin, proj_vwu.shape[d] - margin) for d in fft_dims)
        mask[inner_region] = 1.0
        # The singleton axis is inserted at the stacking position, so that the
        # mask broadcasts correctly for any value of `stack_axis`.
        proj_vwu = proj_vwu * np.expand_dims(mask, axis=stack_axis)

    if len(fft_dims) == 2:
        img_axes = (-2, -1)
        shifts_vu = np.empty((len(data_vwu.shape) - 1, num_angles))
        img_slices = [slice(None)] * len(data_vwu.shape)
        for ii_a in range(num_angles):
            # For performance reasons, it is better to do the fft on each image
            img_slices[stack_axis] = slice(ii_a, ii_a + 1)

            data_vu = data_vwu[tuple(img_slices)].squeeze(axis=stack_axis)
            if proj_vwu.shape[stack_axis] == 1:
                # A single projection is compared against all the data images
                proj_vu = proj_vwu.squeeze(axis=stack_axis)
            else:
                proj_vu = proj_vwu[tuple(img_slices)].squeeze(axis=stack_axis)

            data_vu_f = local_fftn(data_vu, s=fft_shapes, axes=img_axes)
            proj_vu_f = local_fftn(proj_vu, s=fft_shapes, axes=img_axes)

            cc_f = data_vu_f * proj_vu_f.conj()
            if normalize_fourier:
                cc_f /= np.fmax(np.abs(cc_f), eps)

            # The output shape has to be given explicitly: without it, `irfftn`
            # assumes an even-sized last axis, and odd-sized images would produce
            # a correlation map that does not match `cc_coords`.
            cc_r: NDArrayFloat = local_ifftn(cc_f, s=fft_shapes, axes=img_axes).real

            f_vals, f_coords = extract_peak_region_nd(cc_r, cc_coords=cc_coords)
            shifts_vu[..., ii_a] = np.array([f_coords[0][1], f_coords[1][1]])

            if decimals > 0:
                # Sub-pixel refinement along the two lines crossing the peak
                f_vals_v = f_vals[:, 1]
                f_vals_u = f_vals[1, :]

                sub_pixel_v = refine_max_position_1d(f_vals_v, decimals=decimals)
                sub_pixel_u = refine_max_position_1d(f_vals_u, decimals=decimals)

                shifts_vu[..., ii_a] += [sub_pixel_v, sub_pixel_u]
    else:
        fft_axes = [int(d) for d in fft_dims]

        data_vwu_f = local_fftn(data_vwu, s=fft_shapes, axes=fft_axes)
        proj_vwu_f = local_fftn(proj_vwu, s=fft_shapes, axes=fft_axes)

        ccs_f = data_vwu_f * proj_vwu_f.conj()
        if normalize_fourier:
            ccs_f /= np.fmax(np.abs(ccs_f).max(axis=u_axis, keepdims=True), eps)

        # See the note above: the output shape is required for odd-sized axes.
        ccs = local_ifftn(ccs_f, s=fft_shapes, axes=fft_axes).real

        f_vals, f_h = extract_peak_regions_1d(ccs, axis=u_axis, cc_coords=cc_coords[u_axis])
        shifts_vu = f_h[1, :]

        if decimals > 0:
            shifts_vu += refine_max_position_1d(f_vals, decimals=decimals)

    return shifts_vu
[docs]
def fit_shifts_zyx_xc(
    ref_vol_zyx: NDArrayFloat,
    rec_vol_zyx: NDArrayFloat,
    pad_zyx: bool = False,
    normalize_fourier: bool = True,
    use_rfft: bool = True,
    decimals: int = 2,
) -> NDArrayFloat:
    """
    Find the ZYX shifts of the volume, through cross-correlation.

    Parameters
    ----------
    ref_vol_zyx : NDArrayFloat
        The reference volume.
    rec_vol_zyx : NDArrayFloat
        The reconstructed volume to register.
    pad_zyx : bool, optional
        Pad the ZYX coordinates. The default is False.
    normalize_fourier : bool, optional
        Whether to normalize the Fourier representation of the cross-correlation. The default is True.
    use_rfft : bool, optional
        Whether to use the `rfft` transform in place of the complex `fft` transform. The default is True.
    decimals : int, optional
        Decimals for the truncation of the sub-pixel precision. The default is 2.

    Returns
    -------
    NDArrayFloat
        The ZYX shifts.
    """
    if use_rfft:
        local_fftn = np.fft.rfftn
        local_ifftn = np.fft.irfftn
    else:
        local_fftn = np.fft.fftn
        local_ifftn = np.fft.ifftn

    # Only the last (up to) three axes are transformed
    fft_dims = list(range(-min(ref_vol_zyx.ndim, 3), 0))

    pad_factor = 2 if pad_zyx else 1
    fft_shapes = [int(ref_vol_zyx.shape[d]) * pad_factor for d in fft_dims]

    cc_coords = [np.fft.fftfreq(s, 1 / s) for s in fft_shapes]

    ref_vol_zyx_f = local_fftn(ref_vol_zyx, s=fft_shapes, axes=fft_dims)
    rec_vol_zyx_f = local_fftn(rec_vol_zyx, s=fft_shapes, axes=fft_dims)

    cc_f = ref_vol_zyx_f * rec_vol_zyx_f.conj()
    if normalize_fourier:
        cc_f /= np.fmax(np.abs(cc_f), eps)

    # The inverse transform has to use the same axes and output shape as the
    # forward one: without `s`, `irfftn` assumes an even-sized last axis, and
    # without `axes` every axis of the array would be transformed.
    cc: NDArrayFloat = local_ifftn(cc_f, s=fft_shapes, axes=fft_dims).real

    f_vals, f_coords = extract_peak_region_nd(cc, cc_coords=cc_coords)
    shifts_zyx = np.array([coords[1] for coords in f_coords])

    if decimals > 0:
        for ii, dim in enumerate(fft_dims):
            # Line of the peak region passing through its central element, along `dim`
            slices = [slice(1, 2)] * ref_vol_zyx.ndim
            slices[dim] = slice(None)

            f_vals_slice = f_vals[tuple(slices)].flatten()

            sub_pixel_pos = refine_max_position_1d(f_vals_slice, decimals=decimals)

            shifts_zyx[ii] += sub_pixel_pos

    return shifts_zyx
[docs]
def fit_image_rotation_and_scale(
    img_1_vu: NDArray, img_2_vu: NDArray, pad_mode: Union[str, None] = None, window_type: str = "hann", verbose: bool = False
) -> tuple[float, float]:
    """Estimate the rotation and the scaling of an image with respect to a reference image.

    The magnitudes of the windowed Fourier transforms of the two images are
    resampled on a log-polar grid, and then registered against each other by
    cross-correlation. This works best for larger rotation angles.

    Parameters
    ----------
    img_1_vu : NDArray
        Reference image
    img_2_vu : NDArray
        Rotated and scaled image
    pad_mode : Union[str, None], optional
        Padding mode, by default None
    window_type : str, optional
        Windowing type (to cut the high frequency aliasing), by default "hann"
    verbose : bool, optional
        Whether to give verbose output, by default False

    Returns
    -------
    tuple[float, float]
        The rotation (in degrees) and scale of the second image with respect to the first

    Raises
    ------
    ValueError
        In case of mismatching shape of the two images.
    """
    if img_1_vu.shape != img_2_vu.shape:
        raise ValueError(
            f"Image shapes should be identical, but instead got image #1: {img_1_vu.shape}, and image #2: {img_2_vu.shape}"
        )

    fft_axes = (-2, -1)

    if pad_mode is not None:
        # Each axis is padded by half of its size, on both sides
        half_sizes = [(size // 2,) for size in img_2_vu.shape]
        img_1_vu = np.pad(img_1_vu, pad_width=half_sizes, mode=pad_mode)
        img_2_vu = np.pad(img_2_vu, pad_width=half_sizes, mode=pad_mode)

    img_shape = img_2_vu.shape
    taper = window(window_type=window_type, shape=img_shape)

    polar_center = [size - size // 2 for size in img_shape]
    radius = min(size // 2 for size in img_shape)

    def _half_log_polar_spectrum(img: NDArray) -> NDArray:
        """Compute the log-polar resampling of the Fourier magnitude of the windowed image."""
        spectrum = np.fft.fft2(img * taper, axes=fft_axes)
        # Taking the magnitude removes the translation component
        magnitude = np.abs(np.fft.fftshift(spectrum, axes=fft_axes))
        polar = warp_polar(magnitude, center=polar_center, scaling="log", radius=radius)
        # Only half of the angular range of the FFT is used
        return polar[..., : polar.shape[0] // 2, :]

    img_fft_1_p = _half_log_polar_spectrum(img_1_vu)
    img_fft_2_p = _half_log_polar_spectrum(img_2_vu)

    fft_polar_shifts_rs = fit_shifts_vu_xc(img_fft_1_p[:, None, 1:], img_fft_2_p[:, None, 1:], normalize_fourier=True)

    # First component: shift along the angular axis
    tilt_pix = np.squeeze(fft_polar_shifts_rs[0])
    tilt_deg = (180 / img_fft_2_p.shape[0]) * tilt_pix

    # Second component: shift along the log-radius axis
    klog = img_fft_2_p.shape[1] / np.log(radius)
    scale = np.exp(np.squeeze(fft_polar_shifts_rs[1]) / klog)

    if verbose:
        print(f"Fitted image rotation: {tilt_deg:.6} (degrees) or {tilt_pix} (pixels), with scale factor: {scale:.6}")

    return tilt_deg, scale
[docs]
def fit_camera_tilt_angle(img_1: NDArray, img_2: NDArray, pad_u: bool = False, fit_l1: bool = True, verbose: bool = False):
    """
    Estimate the camera tilt angle from the per-line correlation peaks of two images.

    A straight line is fitted through the per-line shifts (halved) returned by
    `fit_shifts_vu_xc`, as a function of the vertical line coordinate. The
    intercept is returned as center of rotation offset, and the slope is
    converted into the tilt angle.

    Parameters
    ----------
    img_1: NDArray
        The first image.
    img_2: NDArray
        The second image.
    pad_u: bool, optional
        Enable zero padding. Default is False.
    fit_l1: bool, optional
        Perform L1 norm fitting if True. Default is True.
    verbose: bool, optional
        Enable verbose output. Default is False.

    Returns
    -------
    tuple[float, float]
        Tuple containing the estimated center of rotation offset (pixels) and camera tilt angle (degrees).
    """
    fitted_cors = fit_shifts_vu_xc(img_1, img_2, pad_u=pad_u) / 2

    # Vertical coordinates of the lines, centered on the middle of the image
    num_lines = img_2.shape[-2]
    half_extent = (num_lines - 1) / 2
    cc_v_coords = np.linspace(-half_extent, half_extent, num_lines)

    # Least-squares line: also used as starting point of the optional l1 fit
    b, a = Polynomial.fit(cc_v_coords, fitted_cors, deg=1).convert().coef

    if fit_l1:

        def l1_cost(line_coeffs: NDArray) -> float:
            intercept, slope = line_coeffs[0], line_coeffs[1]
            residuals = cc_v_coords * slope + intercept - fitted_cors
            return float(np.linalg.norm(residuals, ord=1))

        b, a = minimize(l1_cost, np.array([b, a])).x

    cor_offset_pix = b
    tilt_deg = np.rad2deg(-a / 2)

    if not verbose:
        return cor_offset_pix, tilt_deg

    cor_trend = Polynomial([b, a])
    print(f"Fitted center of rotation (pixels): {cor_offset_pix}, and camera tilt (degrees): {tilt_deg}")

    fig, axs = plt.subplots(1, 1)
    axs.scatter(cc_v_coords, fitted_cors, label="Line CoRs")
    axs.plot(cc_v_coords, cor_trend(cc_v_coords), "-C1", label="Line CoRs trend")
    axs.axhline(cor_offset_pix, color="C2", linestyle="--", label=f"Image CoR ({cor_offset_pix:.3})")
    axs.set_title("Correlation peaks")
    axs.grid()
    axs.legend(fontsize=13)
    fig.tight_layout()
    plt.show(block=False)

    return cor_offset_pix, tilt_deg
[docs]
def sinusoid(
    x: Union[NDArrayFloat, float], a: Union[NDArrayFloat, float], p: Union[NDArrayFloat, float], b: Union[NDArrayFloat, float]
) -> NDArrayFloat:
    """Evaluate the sine function `a * sin(x + p) + b`.

    Parameters
    ----------
    x : NDArrayFloat | float
        The independent variable.
    a : NDArrayFloat | float
        The amplitude of the sine.
    p : NDArrayFloat | float
        The phase of the sine.
    b : NDArrayFloat | float
        The bias of the sine.

    Returns
    -------
    NDArrayFloat
        The computed values.
    """
    unit_wave = np.sin(x + p)
    return a * unit_wave + b
[docs]
def fit_sinusoid(angles: NDArrayFloat, values: NDArrayFloat, fit_l1: bool = False) -> tuple[float, float, float]:
    """Fit the amplitude, phase and bias of a sinusoid to the given values.

    A least-squares fit is always performed. When `fit_l1` is enabled, its
    result is used as the starting point of a minimization of the l1 norm of
    the residual.

    Parameters
    ----------
    angles : NDArrayFloat
        Angles where to evaluate the sinusoid.
    values : NDArrayFloat
        Values of the sinusoid.
    fit_l1 : bool, optional
        Whether to use l1 fit instead of the l2 fit, by default False

    Returns
    -------
    Tuple[float, float, float]
        The amplitude, phase and bias of the sinusoid.
    """
    # Initial guess: half of the value range as amplitude, its middle point as bias
    val_max = values.max()
    val_min = values.min()
    amplitude_guess = (val_max - val_min) / 2
    bias_guess = (val_max + val_min) / 2

    l2_params, _ = spopt.curve_fit(sinusoid, angles, values, p0=[amplitude_guess, 0, bias_guess])
    a, p, b = l2_params

    if fit_l1:

        def l1_residual(params: NDArrayFloat) -> float:
            predicted = sinusoid(angles, params[0], params[1], params[2])
            return float(np.linalg.norm(predicted - values, ord=1))

        a, p, b = spopt.minimize(l1_residual, np.array([a, p, b])).x

    return a, p, b
[docs]
def refine_max_position_1d(
    f_vals: NDArrayFloat, f_x: Union[ArrayLike, NDArray, None] = None, return_vertex_val: bool = False, decimals: int = 2
) -> Union[NDArrayFloat, tuple[NDArrayFloat, NDArrayFloat]]:
    """Compute the sub-pixel position of the vertex of a parabola fitted to the given samples.

    The input is either a single 1-dimensional array of samples, or a
    2-dimensional array where each column is an independent set of samples.

    Parameters
    ----------
    f_vals: NDArrayFloat
        Function values of the sampled points
    f_x: ArrayLike, optional
        Coordinates of the sampled points. When omitted, unit-spaced
        coordinates centered in zero are used.
    return_vertex_val: boolean, optional
        Enables returning the vertex values. Defaults to False.
    decimals: int, optional
        Number of decimals the results are rounded to. Defaults to 2.

    Raises
    ------
    ValueError
        In case position and values do not have the same size, or in case
        the fitted maximum is outside the fitting region.

    Returns
    -------
    float
        Estimated function max, according to the coordinates in f_x. When
        `return_vertex_val` is enabled, the tuple (position, value) is returned.
    """
    vals_ndim = f_vals.ndim
    if vals_ndim not in (1, 2):
        raise ValueError(
            f"The fitted values should be either one or a collection of 1-dimensional arrays."
            f" Array of shape: {f_vals.shape} was given."
        )
    single_curve = vals_ndim == 1
    num_vals = f_vals.shape[0]

    if f_x is not None:
        f_x = np.squeeze(f_x)
        if f_x.ndim != 1 or f_x.size != num_vals:
            raise ValueError(
                f"Base coordinates should have the same length as values array. Sizes of fx: {f_x.size}, f_vals: {num_vals}"
            )
    else:
        half_span = (num_vals - 1) / 2
        f_x = np.linspace(-half_span, half_span, num_vals)

    if single_curve:
        # Polynomial.fit is supposed to be more numerically stable than the
        # older fitting functions (according to numpy).
        coeffs = Polynomial.fit(f_x, f_vals, deg=2).convert().coef
    else:
        # All the curves are fitted at once, through a single least-squares problem
        powers_x = np.vstack((np.ones(num_vals), f_x, f_x**2))
        coeffs = np.linalg.lstsq(powers_x.T, f_vals, rcond=None)[0]

    # The vertex of the parabola `f(x) = c + bx + ax^2` is in: x_v = -b / 2a
    lin_coeff = coeffs[1, ...]
    quad_coeff = coeffs[2, ...]
    vertex_x = np.around(-lin_coeff / (2 * quad_coeff), decimals=decimals)

    vertex_min_x = np.min(f_x)
    vertex_max_x = np.max(f_x)
    lower_bound_ok = vertex_min_x < vertex_x
    upper_bound_ok = vertex_x < vertex_max_x

    if not np.all(np.logical_and(lower_bound_ok, upper_bound_ok)):
        if single_curve:
            message = (
                f"Fitted position {vertex_x} is outside the input margins [{vertex_min_x}, {vertex_max_x}]."
                f" Input values: {f_vals}"
            )
        else:
            message = (
                f"Fitted positions outside the input margins [{vertex_min_x}, {vertex_max_x}]:"
                f" {np.sum(1 - lower_bound_ok)} below and {np.sum(1 - upper_bound_ok)} above"
            )
        raise ValueError(message)

    if not return_vertex_val:
        return vertex_x

    # Parabola value in the vertex: f(x_v) = c + b * x_v / 2
    vertex_val = np.around(coeffs[0, ...] + vertex_x * lin_coeff / 2, decimals=decimals)
    return vertex_x, vertex_val
[docs]
def refine_max_position_2d(
    f_vals: NDArrayFloat, fy: Union[ArrayLike, NDArray, None] = None, fx: Union[ArrayLike, NDArray, None] = None
) -> NDArray:
    """Compute the sub-pixel max position of the given function sampling.

    A 2D quadratic surface is fitted to the samples, and the position of its
    stationary point is returned.

    Parameters
    ----------
    f_vals: NDArrayFloat
        Function values of the sampled points
    fy: ArrayLike, optional
        Vertical coordinates of the sampled points. When omitted, unit-spaced
        coordinates centered in zero are assumed.
    fx: ArrayLike, optional
        Horizontal coordinates of the sampled points. When omitted, unit-spaced
        coordinates centered in zero are assumed.

    Raises
    ------
    ValueError
        In case position and values do not have the same size, or in case
        the fitted maximum is outside the fitting region.

    Returns
    -------
    NDArray
        Estimated (vertical, horizontal) function max, according to the
        coordinates in fy and fx.
    """
    if f_vals.ndim != 2:
        raise ValueError(f"The fitted values should form a 2-dimensional array. Array of shape: {f_vals.shape} was given.")

    # When the coordinates are not given, the fit is performed on a normalized
    # [-1, 1] grid, and the scaling converts the results back to pixel units.
    if fy is None:
        fy = np.linspace(-1, 1, f_vals.shape[0])
        y_scaling = (f_vals.shape[0] - 1) / 2
    else:
        fy = np.array(fy, ndmin=1)
        y_scaling = 1.0
    if not (fy.ndim == 1 and fy.size == f_vals.shape[0]):
        raise ValueError(
            f"Vertical coordinates should have the same length as values matrix."
            f" Sizes of fy: {fy.size}, f_vals: {f_vals.shape}"
        )

    if fx is None:
        fx = np.linspace(-1, 1, f_vals.shape[1])
        x_scaling = (f_vals.shape[1] - 1) / 2
    else:
        fx = np.array(fx, ndmin=1)
        x_scaling = 1.0
    if not (fx.ndim == 1 and fx.size == f_vals.shape[1]):
        raise ValueError(
            f"Horizontal coordinates should have the same length as values matrix."
            f" Sizes of fx: {fx.size}, f_vals: {f_vals.shape}"
        )

    grid_y, grid_x = np.meshgrid(fy, fx, indexing="ij")
    grid_y = grid_y.flatten()
    grid_x = grid_x.flatten()

    # Least-squares fit of: f(y, x) = c0 + c1 y + c2 x + c3 yx + c4 y^2 + c5 x^2
    coords = np.array([np.ones(f_vals.size), grid_y, grid_x, grid_y * grid_x, grid_y**2, grid_x**2])
    coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

    # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
    # x_v = -b / 2a. For a 2D parabola, the vertex position is:
    # (y, x)_v = - b / A, where:
    A = np.array([[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]])
    b = coeffs[1:3]

    # The vertex is computed in the fit coordinates, and then converted to the
    # output coordinates. Scaling the position (instead of the coefficients)
    # avoids applying the scaling in the wrong direction.
    scaling = np.array([y_scaling, x_scaling])
    vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0] * scaling

    vertex_min_yx = np.array([np.min(fy), np.min(fx)]) * scaling
    vertex_max_yx = np.array([np.max(fy), np.max(fx)]) * scaling
    if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
        raise ValueError(
            f"Fitted (yx: {vertex_yx}) positions are outside the input margins"
            + f" y: [{vertex_min_yx[0]}, {vertex_max_yx[0]}], and x: [{vertex_min_yx[1]}, {vertex_max_yx[1]}]."
            + f" Input values: {f_vals}"
        )
    return vertex_yx