#!/usr/bin/env python3"""Solvers for the tomographic reconstruction problem.@author: Nicola VIGANÒ, Computational Imaging group, CWI, The Netherlands,and ESRF - The European Synchrotron, Grenoble, France"""importcopyascpfromabcimportABC,abstractmethodfromtypingimportAny,Callable,Optional,Unionfromcollections.abcimportSequenceimportnumpyasnpfromnumpy.typingimportDTypeLike,NDArrayfromtqdm.autoimporttqdmfrom.importdata_terms,filters,operators,projectors,regularizerseps=np.finfo(np.float32).epsNDArrayFloat=NDArray[np.floating]
class Solver(ABC):
    """
    Initialize the base solver class.

    Parameters
    ----------
    verbose : bool, optional
        Turn on verbose output. The default is False.
    leave_progress : bool, optional
        Leave the progress bar after the computation is finished. The default is True.
    relaxation : float, optional
        The relaxation length. The default is 1.0.
    tolerance : Optional[float], optional
        Tolerance on the data residual for computing when to stop iterations.
        The default is None.
    data_term : Union[str, data_terms.DataFidelityBase], optional
        Data fidelity term for computing the data residual. The default is "l2".
    data_term_test : Optional[data_terms.DataFidelityBase], optional
        The data fidelity to be used for the test set.
        If None, it will use the same as for the rest of the data.
        The default is None.
    """

    def __init__(
        self,
        verbose: bool = False,
        leave_progress: bool = True,
        relaxation: float = 1.0,
        tolerance: Optional[float] = None,
        data_term: Union[str, data_terms.DataFidelityBase] = "l2",
        data_term_test: Union[str, data_terms.DataFidelityBase, None] = None,
    ):
        self.verbose = verbose
        self.leave_progress = leave_progress
        self.relaxation = relaxation
        self.tolerance = tolerance

        self.data_term = self._initialize_data_fidelity_function(data_term)
        if data_term_test is None:
            data_term_test = self.data_term
        else:
            data_term_test = self._initialize_data_fidelity_function(data_term_test)
        # Deep copy: when `data_term_test` is None it would otherwise be the very
        # same object as `self.data_term`, and assigning test data would clobber it.
        self.data_term_test = cp.deepcopy(data_term_test)

    def info(self) -> str:
        """
        Return the solver info.

        Returns
        -------
        str
            Solver info string.
        """
        return type(self).__name__

    def upper(self) -> str:
        """
        Return the upper case name of the solver.

        Returns
        -------
        str
            Upper case string name of the solver.
        """
        return type(self).__name__.upper()

    def lower(self) -> str:
        """
        Return the lower case name of the solver.

        Returns
        -------
        str
            Lower case string name of the solver.
        """
        return type(self).__name__.lower()

    @abstractmethod
    def __call__(
        self, A: operators.BaseTransform, b: NDArrayFloat, *args: Any, **kwds: Any
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """Execute the reconstruction of the data.

        Parameters
        ----------
        A : operators.BaseTransform
            The projection operator.
        b : NDArrayFloat
            The data to be reconstructed.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction and related information.
        """

    @staticmethod
    def _initialize_data_fidelity_function(
        data_term: Union[str, data_terms.DataFidelityBase]
    ) -> data_terms.DataFidelityBase:
        """
        Build (or copy) the data fidelity term.

        Parameters
        ----------
        data_term : Union[str, data_terms.DataFidelityBase]
            Either "l2" / "kl" (case-insensitive), or an instance of an l2- or KL-based data term.

        Returns
        -------
        data_terms.DataFidelityBase
            A fresh data term; user-provided instances are deep-copied.

        Raises
        ------
        ValueError
            For unknown names or unsupported data term types.
        """
        if isinstance(data_term, str):
            if data_term.lower() == "l2":
                return data_terms.DataFidelity_l2()
            elif data_term.lower() == "kl":
                return data_terms.DataFidelity_KL()
            else:
                raise ValueError(f"Unknown data term: '{data_term}', only accepted terms are: 'l2' | 'kl'.")
        elif isinstance(data_term, (data_terms.DataFidelity_l2, data_terms.DataFidelity_KL)):
            return cp.deepcopy(data_term)
        else:
            raise ValueError(f"Unsupported data term: '{data_term.info()}', only accepted terms are 'kl' and 'l2'-based.")

    @staticmethod
    def _initialize_regularizer(
        regularizer: Union[regularizers.BaseRegularizer, None, Sequence[regularizers.BaseRegularizer]]
    ) -> Sequence[regularizers.BaseRegularizer]:
        """
        Normalize the regularizer argument to a list of regularizers.

        Parameters
        ----------
        regularizer : Union[regularizers.BaseRegularizer, None, Sequence[regularizers.BaseRegularizer]]
            None, a single regularizer, or a list/tuple of regularizers.

        Returns
        -------
        Sequence[regularizers.BaseRegularizer]
            The regularizers as a list (empty when None).

        Raises
        ------
        ValueError
            If any element is not a `regularizers.BaseRegularizer`, or the type is not supported.
        """
        if regularizer is None:
            return []
        elif isinstance(regularizer, regularizers.BaseRegularizer):
            return [regularizer]
        elif isinstance(regularizer, (list, tuple)):
            check_regs_ok = np.array(
                [isinstance(r, regularizers.BaseRegularizer) for r in regularizer], dtype=bool
            )
            if not np.all(check_regs_ok):
                # Report the positions of the offending entries, not of the valid ones.
                bad_indices = np.flatnonzero(~check_regs_ok)
                raise ValueError(
                    "The following regularizers are not derived from the regularizers.BaseRegularizer class: "
                    f"{bad_indices}"
                )
            return list(regularizer)
        else:
            raise ValueError("Unknown regularizer type.")

    @staticmethod
    def _initialize_b_masks(
        b: NDArrayFloat, b_mask: Optional[NDArrayFloat], b_test_mask: Optional[NDArrayFloat]
    ) -> tuple[Optional[NDArrayFloat], Optional[NDArrayFloat]]:
        """
        Combine the data mask and the test (cross-validation) mask.

        Parameters
        ----------
        b : NDArrayFloat
            The data, used only as a template when `b_mask` must be created.
        b_mask : Optional[NDArrayFloat]
            Data mask.
        b_test_mask : Optional[NDArrayFloat]
            Test data mask.

        Returns
        -------
        tuple[Optional[NDArrayFloat], Optional[NDArrayFloat]]
            The reconstruction mask (with test pixels removed) and the test mask
            (with pixels masked out by `b_mask` removed). Unchanged if `b_test_mask` is None.
        """
        if b_test_mask is not None:
            if b_mask is None:
                b_mask = np.ones_like(b)
            # As we are being passed a test residual pixel mask, we need
            # to make sure to mask those pixels out from the reconstruction.
            # At the same time, we need to remove any masked pixel from the test count.
            b_mask, b_test_mask = b_mask * (1 - b_test_mask), b_test_mask * b_mask

        return (b_mask, b_test_mask)
class FBP(Solver):
    """Implementation of the Filtered Back-Projection (FBP) algorithm."""

    def __init__(
        self,
        verbose: bool = False,
        leave_progress: bool = False,
        regularizer: Union[Sequence[regularizers.BaseRegularizer], regularizers.BaseRegularizer, None] = None,
        data_term: Union[str, data_terms.DataFidelityBase] = "l2",
        fbp_filter: Union[str, NDArrayFloat, filters.Filter] = "ramp",
        pad_mode: str = "constant",
    ):
        """Initialize the Filtered Back-Projection (FBP) algorithm.

        Parameters
        ----------
        verbose : bool, optional
            Turn on verbose output. The default is False.
        leave_progress: bool, optional
            Leave the progress bar after the computation is finished. The default is False.
        regularizer : Sequence[regularizers.BaseRegularizer] | regularizers.BaseRegularizer | None, optional
            NOT USED, only exposed for compatibility reasons.
        data_term : Union[str, data_terms.DataFidelityBase], optional
            NOT USED, only exposed for compatibility reasons.
        fbp_filter : Union[str, NDArrayFloat, filters.Filter], optional
            FBP filter to use. Either a string from scikit-image's list of `iradon` filters
            ("mr" and "data" select the data-dependent MR filter), an array (used as a
            custom filter), or a filter instance. The default is "ramp".
        pad_mode: str, optional
            The padding mode to use for the linear convolution. The default is "constant".
        """
        # Forward `leave_progress` too, so that the attribute reflects the caller's choice.
        super().__init__(verbose=verbose, leave_progress=leave_progress)
        if isinstance(fbp_filter, str):
            fbp_filter = fbp_filter.lower()
        self.fbp_filter = fbp_filter
        self.pad_mode = pad_mode

    def info(self) -> str:
        """
        Return the solver info.

        Returns
        -------
        str
            Solver info string, including the filter name.
        """
        if isinstance(self.fbp_filter, str):
            return super().info() + "(F:" + self.fbp_filter.upper() + ")"
        elif isinstance(self.fbp_filter, np.ndarray):
            return super().info() + "(F:" + filters.FilterCustom.__name__.upper() + ")"
        else:
            return super().info() + "(F:" + type(self.fbp_filter).__name__.upper() + ")"

    def __call__(  # noqa: C901
        self,
        A: operators.BaseTransform,
        b: NDArrayFloat,
        iterations: int = 0,
        x0: Optional[NDArrayFloat] = None,
        lower_limit: Union[float, NDArrayFloat, None] = None,
        upper_limit: Union[float, NDArrayFloat, None] = None,
        x_mask: Optional[NDArrayFloat] = None,
        b_mask: Optional[NDArrayFloat] = None,
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """
        Reconstruct the data, using the FBP algorithm.

        Parameters
        ----------
        A : BaseTransform
            Projection operator.
        b : NDArrayFloat
            Data to reconstruct.
        iterations : int
            NOT USED, only exposed for compatibility reasons.
        x0 : Optional[NDArrayFloat], optional
            NOT USED, only exposed for compatibility reasons.
        lower_limit : Union[float, NDArrayFloat], optional
            Lower clipping value. The default is None.
        upper_limit : Union[float, NDArrayFloat], optional
            Upper clipping value. The default is None.
        x_mask : Optional[NDArrayFloat], optional
            Solution mask. The default is None.
        b_mask : Optional[NDArrayFloat], optional
            NOT USED, only exposed for compatibility reasons.

        Raises
        ------
        ValueError
            In case the data is 1D.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction, and the solution info (no iterations are performed).
        """
        if len(b.shape) < 2:
            raise ValueError(f"Data should be at least 2-dimensional (b.shape = {b.shape})")

        info = SolutionInfo(self.info(), max_iterations=0, tolerance=0.0)

        if isinstance(self.fbp_filter, str):
            if self.fbp_filter in ("mr", "data"):
                local_filter = filters.FilterMR(projector=A)
            else:
                local_filter = filters.FilterFBP(filter_name=self.fbp_filter)
        elif isinstance(self.fbp_filter, np.ndarray):
            local_filter = filters.FilterCustom(self.fbp_filter)
        else:
            # NOTE: a user-provided filter instance is used as-is, so the
            # assignment below also changes its `pad_mode`.
            local_filter = self.fbp_filter
        local_filter.pad_mode = self.pad_mode

        if isinstance(A, operators.ProjectorOperator):
            pre_weights = A.get_pre_weights()
            if pre_weights is not None:
                b = b * pre_weights

        b_f = local_filter(b)
        x = A.T(b_f)

        if lower_limit is not None or upper_limit is not None:
            x = x.clip(lower_limit, upper_limit)
        if x_mask is not None:
            x *= x_mask

        return x, info
class SART(Solver):
    """Solver class implementing the Simultaneous Algebraic Reconstruction Technique (SART) algorithm."""

    def compute_residual(
        self,
        A: Callable,
        b: NDArrayFloat,
        x: NDArrayFloat,
        A_num_rows: int,
        b_mask: Optional[NDArrayFloat],
    ) -> NDArrayFloat:
        """Compute the solution residual.

        Parameters
        ----------
        A : Callable
            The forward projector, called as `A(x, row_index)` for one row at a time.
        b : NDArrayFloat
            The detector data, with the projection rows on the second-to-last axis.
        x : NDArrayFloat
            The current solution
        A_num_rows : int
            The number of projections.
        b_mask : Optional[NDArrayFloat]
            The mask to apply

        Returns
        -------
        NDArrayFloat
            The residual `A(x) - b`, masked by `b_mask` when given.
        """
        # Rows are indexed as `b[..., ii, :]` everywhere in this class, so the
        # per-row forward projections must be stacked on axis -2 to align with `b`.
        fp = np.stack([A(x, ii) for ii in range(A_num_rows)], axis=-2)
        fp = np.ascontiguousarray(fp, dtype=b.dtype)

        res = fp - b
        if b_mask is not None:
            res *= b_mask

        return res

    def __call__(  # noqa: C901
        self,
        A: Union[Callable[[NDArray, int], NDArray], projectors.ProjectorUncorrected],
        b: NDArrayFloat,
        iterations: int,
        A_num_rows: Optional[int] = None,
        At: Optional[Callable] = None,
        x0: Optional[NDArrayFloat] = None,
        lower_limit: Union[float, NDArrayFloat, None] = None,
        upper_limit: Union[float, NDArrayFloat, None] = None,
        x_mask: Optional[NDArrayFloat] = None,
        b_mask: Optional[NDArrayFloat] = None,
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """
        Reconstruct the data, using the SART algorithm.

        Parameters
        ----------
        A : Union[Callable, BaseTransform]
            Projection operator: either a projector supporting single projections,
            or a callable `A(x, row_index)` returning one projection row.
        b : NDArrayFloat
            Data to reconstruct, with the projection rows on the second-to-last axis.
        iterations : int
            Number of iterations.
        A_num_rows : Optional[int], optional
            Number of projections. Required if `A` is not a projector (otherwise it is
            taken from the projector's angles). The default is None.
        At : Callable, optional
            The back-projection operator `At(y, row_index)`. Required if `A` is not a
            projector (otherwise it is derived from the projector). The default is None.
        x0 : Optional[NDArrayFloat], optional
            Initial solution. The default is None.
        lower_limit : Union[float, NDArrayFloat], optional
            Lower clipping value. The default is None.
        upper_limit : Union[float, NDArrayFloat], optional
            Upper clipping value. The default is None.
        x_mask : Optional[NDArrayFloat], optional
            Solution mask. The default is None.
        b_mask : Optional[NDArrayFloat], optional
            Data mask. The default is None.

        Raises
        ------
        ValueError
            If the projector does not support single projections, or if `At` /
            `A_num_rows` are missing when `A` is not a projector.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction, and the solution info (residuals are only recorded
            when a tolerance is set).
        """
        if isinstance(A, projectors.ProjectorUncorrected):
            p = A
            if not p.projector_backend.has_individual_projs:
                raise ValueError("The projector needs to have enabled single projections.")
            A = lambda x, ii: p.fp_angle(x, ii)  # noqa: E731
            if isinstance(p, projectors.ProjectorAttenuationXRF):
                At = lambda y, ii: p.bp_angle(y, ii, single_line=True)  # noqa: E731
            else:
                At = lambda y, ii: p.bp_angle(y, ii)  # noqa: E731
            A_num_rows = len(p.angles_rot_rad)
        elif At is None:
            raise ValueError("Parameter `At` is required, if `A` is not a projector.")
        elif A_num_rows is None:
            raise ValueError("Parameter `A_num_rows` is required, if `A` is not a projector.")

        # Back-projection diagonal re-scaling (one volume per row, stacked on axis -2)
        b_ones = np.ones_like(b)
        if b_mask is not None:
            b_ones *= b_mask
        tau = [At(b_ones[..., ii, :], ii) for ii in range(A_num_rows)]
        tau = np.abs(np.stack(tau, axis=-2))
        tau[(tau / np.max(tau)) < 1e-5] = 1
        tau = self.relaxation / tau

        # Forward-projection diagonal re-scaling
        # The volume shape is recovered by dropping the row axis from `tau`.
        x_ones = np.ones([*tau.shape[:-2], tau.shape[-1]], dtype=tau.dtype)
        if x_mask is not None:
            x_ones *= x_mask
        sigma = [A(x_ones, ii) for ii in range(A_num_rows)]
        sigma = np.abs(np.stack(sigma, axis=-2))
        sigma[(sigma / np.max(sigma)) < 1e-5] = 1
        sigma = 1 / sigma

        if x0 is None:
            x0 = np.zeros_like(x_ones)
        else:
            x0 = np.array(x0).copy()
        x = x0

        info = SolutionInfo(self.info(), max_iterations=iterations, tolerance=self.tolerance)

        if self.tolerance is not None:
            res = self.compute_residual(A, b, x, A_num_rows=A_num_rows, b_mask=b_mask)
            info.residual0 = np.linalg.norm(res.flatten())

        # The row ordering is drawn once and reused for every iteration.
        rows_sequence = np.random.permutation(A_num_rows)

        algo_info = f"- Performing {self.upper()} iterations: "

        for ii in tqdm(range(iterations), desc=algo_info, disable=(not self.verbose), leave=self.leave_progress):
            info.iterations += 1

            for ii_a in rows_sequence:
                res = A(x, ii_a) - b[..., ii_a, :]
                if b_mask is not None:
                    res *= b_mask[..., ii_a, :]

                x -= At(res * sigma[..., ii_a, :], ii_a) * tau[..., ii_a, :]

                # Constraints are enforced after every single-row update.
                if lower_limit is not None:
                    x = np.fmax(x, lower_limit)
                if upper_limit is not None:
                    x = np.fmin(x, upper_limit)
                if x_mask is not None:
                    x *= x_mask

            if self.tolerance is not None:
                res = self.compute_residual(A, b, x, A_num_rows=A_num_rows, b_mask=b_mask)
                info.residuals[ii] = np.linalg.norm(res)
                if self.tolerance > info.residuals[ii]:
                    break

        return x, info
class MLEM(Solver):
    """
    Initialize the MLEM solver class.

    This class implements the Maximum Likelihood Expectation Maximization (MLEM) algorithm.

    Parameters
    ----------
    verbose : bool, optional
        Turn on verbose output. The default is False.
    leave_progress: bool, optional
        Leave the progress bar after the computation is finished. The default is True.
    tolerance : Optional[float], optional
        Tolerance on the data residual for computing when to stop iterations.
        The default is None.
    regularizer : Sequence[regularizers.BaseRegularizer] | regularizers.BaseRegularizer | None, optional
        Regularizer to be used. The default is None.
    data_term : Union[str, data_terms.DataFidelityBase], optional
        Data fidelity term for computing the data residual. The default is "kl".
    data_term_test : Optional[data_terms.DataFidelityBase], optional
        The data fidelity to be used for the test set.
        If None, it will use the same as for the rest of the data.
        The default is None.
    """

    def __init__(
        self,
        verbose: bool = False,
        leave_progress: bool = True,
        tolerance: Optional[float] = None,
        regularizer: Union[Sequence[regularizers.BaseRegularizer], regularizers.BaseRegularizer, None] = None,
        data_term: Union[str, data_terms.DataFidelityBase] = "kl",
        data_term_test: Union[str, data_terms.DataFidelityBase, None] = None,
    ):
        super().__init__(
            verbose=verbose,
            leave_progress=leave_progress,
            tolerance=tolerance,
            data_term=data_term,
            data_term_test=data_term_test,
        )
        self.regularizer = self._initialize_regularizer(regularizer)

    def info(self) -> str:
        """
        Return the MLEM info.

        Returns
        -------
        str
            info string: the solver name, followed by "(B:<background>)" when the
            data term has a background.
        """
        background = self.data_term.background
        # Parenthesize the conditional: without it, the whole string collapsed to ""
        # whenever there was no background.
        bg_info = f"(B:{background:g})" if background is not None else ""
        return Solver.info(self) + bg_info

    def __call__(  # noqa: C901
        self,
        A: operators.BaseTransform,
        b: NDArrayFloat,
        iterations: int,
        x0: Optional[NDArrayFloat] = None,
        lower_limit: Union[float, NDArrayFloat, None] = None,
        upper_limit: Union[float, NDArrayFloat, None] = None,
        x_mask: Optional[NDArrayFloat] = None,
        b_mask: Optional[NDArrayFloat] = None,
        b_test_mask: Optional[NDArrayFloat] = None,
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """
        Reconstruct the data, using the MLEM algorithm.

        Parameters
        ----------
        A : BaseTransform
            Projection operator.
        b : NDArrayFloat
            Data to reconstruct.
        iterations : int
            Number of iterations.
        x0 : Optional[NDArrayFloat], optional
            Initial solution. The default is None (all ones).
        lower_limit : Union[float, NDArrayFloat], optional
            Lower clipping value. The default is None.
        upper_limit : Union[float, NDArrayFloat], optional
            Upper clipping value. The default is None.
        x_mask : Optional[NDArrayFloat], optional
            Solution mask. The default is None.
        b_mask : Optional[NDArrayFloat], optional
            Data mask. Masked data does not contribute to the reconstruction.
            The default is None.
        b_test_mask : Optional[NDArrayFloat], optional
            Test data mask. Test data is excluded from the reconstruction and only
            used for the cross-validation residuals. The default is None.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction, and the residuals.
        """
        b = np.array(b)

        (b_mask, b_test_mask) = self._initialize_b_masks(b, b_mask, b_test_mask)

        # Back-projection diagonal re-scaling (sensitivity image)
        b_ones = np.ones_like(b)
        if b_mask is not None:
            b_ones *= b_mask
        tau = A.T(b_ones)
        # Pixels not reached by any unmasked ray have zero sensitivity; assuming a
        # non-negative system matrix their back-projected ratio is zero too, so
        # avoid producing 0/0 = NaN in the multiplicative update.
        tau[tau == 0] = 1

        if x0 is None:
            x = np.ones_like(tau)
        else:
            x = np.array(x0).copy()
        if x_mask is not None:
            x *= x_mask

        self.data_term.assign_data(b)

        info = SolutionInfo(self.info(), max_iterations=iterations, tolerance=self.tolerance)

        if b_test_mask is not None or self.tolerance is not None:
            Ax = A(x)

            if b_test_mask is not None:
                if self.data_term_test.background != self.data_term.background:
                    print("WARNING - the data_term and and data_term_test should have the same background. Making them equal.")
                    self.data_term_test.background = self.data_term.background
                self.data_term_test.assign_data(b)

                res_test_0 = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residual0_cv = self.data_term_test.compute_residual_norm(res_test_0)

            if self.tolerance is not None:
                res_0 = self.data_term.compute_residual(Ax, mask=b_mask)
                info.residual0 = self.data_term.compute_residual_norm(res_0)

        reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
        algo_info = f"- Performing {self.upper()}-{self.data_term.upper()}{reg_info} iterations: "

        for ii in tqdm(range(iterations), desc=algo_info, disable=(not self.verbose), leave=self.leave_progress):
            info.iterations += 1

            # The MLEM update
            Ax = A(x)

            if b_test_mask is not None:
                res_test = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residuals_cv[ii] = self.data_term_test.compute_residual_norm(res_test)

            if self.tolerance is not None:
                res = self.data_term.compute_residual(Ax, mask=b_mask)
                info.residuals[ii] = self.data_term.compute_residual_norm(res)
                if self.tolerance > info.residuals[ii]:
                    break

            if self.data_term.background is not None:
                Ax = Ax + self.data_term.background
            # Floor the forward projection to avoid dividing by zero below.
            Ax = Ax.clip(eps, None)

            ratio = b / Ax
            if b_mask is not None:
                # Masked (and held-out test) data must not drive the update,
                # consistently with their exclusion from the sensitivity `tau`.
                ratio *= b_mask
            upd = A.T(ratio)
            x *= upd / tau

            if lower_limit is not None or upper_limit is not None:
                x = x.clip(lower_limit, upper_limit)
            if x_mask is not None:
                x *= x_mask

        return x, info
class SIRT(Solver):
    """
    Initialize the SIRT solver class.

    This class implements the Simultaneous Iterative Reconstruction Technique (SIRT) algorithm.

    Parameters
    ----------
    verbose : bool, optional
        Turn on verbose output. The default is False.
    leave_progress: bool, optional
        Leave the progress bar after the computation is finished. The default is True.
    tolerance : Optional[float], optional
        Tolerance on the data residual for computing when to stop iterations.
        The default is None.
    relaxation : float, optional
        The relaxation length. The default is 1.95.
    regularizer : Sequence[regularizers.BaseRegularizer] | regularizers.BaseRegularizer | None, optional
        Regularizer to be used. The default is None.
    data_term : Union[str, data_terms.DataFidelityBase], optional
        Data fidelity term for computing the data residual. The default is "l2".
    data_term_test : Optional[data_terms.DataFidelityBase], optional
        The data fidelity to be used for the test set.
        If None, it will use the same as for the rest of the data.
        The default is None.
    """

    def __init__(
        self,
        verbose: bool = False,
        leave_progress: bool = True,
        relaxation: float = 1.95,
        tolerance: Optional[float] = None,
        regularizer: Union[Sequence[regularizers.BaseRegularizer], regularizers.BaseRegularizer, None] = None,
        data_term: Union[str, data_terms.DataFidelityBase] = "l2",
        data_term_test: Union[str, data_terms.DataFidelityBase, None] = None,
    ):
        super().__init__(
            verbose=verbose,
            leave_progress=leave_progress,
            relaxation=relaxation,
            tolerance=tolerance,
            data_term=data_term,
            data_term_test=data_term_test,
        )
        self.regularizer = self._initialize_regularizer(regularizer)

    def info(self) -> str:
        """
        Return the SIRT info.

        Returns
        -------
        str
            SIRT info string: solver name, data term and regularizers.
        """
        reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
        return Solver.info(self) + "-" + self.data_term.info() + reg_info

    def __call__(  # noqa: C901
        self,
        A: operators.BaseTransform,
        b: NDArrayFloat,
        iterations: int,
        x0: Optional[NDArrayFloat] = None,
        lower_limit: Union[float, NDArrayFloat, None] = None,
        upper_limit: Union[float, NDArrayFloat, None] = None,
        x_mask: Optional[NDArrayFloat] = None,
        b_mask: Optional[NDArrayFloat] = None,
        b_test_mask: Optional[NDArrayFloat] = None,
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """
        Reconstruct the data, using the SIRT algorithm.

        Parameters
        ----------
        A : BaseTransform
            Projection operator.
        b : NDArrayFloat
            Data to reconstruct.
        iterations : int
            Number of iterations.
        x0 : Optional[NDArrayFloat], optional
            Initial solution. The default is None (all zeros).
        lower_limit : Union[float, NDArrayFloat], optional
            Lower clipping value. The default is None.
        upper_limit : Union[float, NDArrayFloat], optional
            Upper clipping value. The default is None.
        x_mask : Optional[NDArrayFloat], optional
            Solution mask. The default is None.
        b_mask : Optional[NDArrayFloat], optional
            Data mask. The default is None.
        b_test_mask : Optional[NDArrayFloat], optional
            Test data mask. The default is None.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction, and the residuals.
        """
        b = np.array(b)

        # Remove test pixels from the reconstruction mask, and masked pixels from the test mask.
        (b_mask, b_test_mask) = self._initialize_b_masks(b, b_mask, b_test_mask)

        # Back-projection diagonal re-scaling
        b_ones = np.ones_like(b)
        if b_mask is not None:
            b_ones *= b_mask
        tau = np.abs(A.T(b_ones))
        for reg in self.regularizer:
            tau += reg.initialize_sigma_tau(tau)
        # Replace (relatively) vanishing entries by 1, to avoid exploding step sizes.
        tau[(tau / np.max(tau)) < 1e-5] = 1
        tau = self.relaxation / tau

        # Forward-projection diagonal re-scaling
        x_ones = np.ones_like(tau)
        if x_mask is not None:
            x_ones *= x_mask
        sigma = np.abs(A(x_ones))
        sigma[(sigma / np.max(sigma)) < 1e-5] = 1
        sigma = 1 / sigma

        if x0 is None:
            x = np.zeros_like(x_ones)
        else:
            x = np.array(x0).copy()

        self.data_term.assign_data(b, sigma)

        info = SolutionInfo(self.info(), max_iterations=iterations, tolerance=self.tolerance)

        # Initial residuals, only when they are needed (cross-validation or stopping criterion).
        if b_test_mask is not None or self.tolerance is not None:
            Ax = A(x)

            res_0 = self.data_term.compute_residual(Ax, mask=b_mask)
            info.residual0 = self.data_term.compute_residual_norm(res_0)

            if b_test_mask is not None:
                if self.data_term_test.background != self.data_term.background:
                    print("WARNING - the data_term and and data_term_test should have the same background. Making them equal.")
                    self.data_term_test.background = self.data_term.background
                self.data_term_test.assign_data(b, sigma)

                res_test_0 = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residual0_cv = self.data_term_test.compute_residual_norm(res_test_0)

        reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
        algo_info = f"- Performing {self.upper()}-{self.data_term.upper()}{reg_info} iterations: "

        for ii in tqdm(range(iterations), desc=algo_info, disable=(not self.verbose), leave=self.leave_progress):
            info.iterations += 1

            Ax = A(x)
            res = self.data_term.compute_residual(Ax, mask=b_mask)
            # Residuals are recorded for the current `x`, before this iteration's update.
            if b_test_mask is not None or self.tolerance is not None:
                info.residuals[ii] = self.data_term.compute_residual_norm(res)
            if b_test_mask is not None:
                res_test = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residuals_cv[ii] = self.data_term_test.compute_residual_norm(res_test)

            if self.tolerance is not None and self.tolerance > info.residuals[ii]:
                if self.verbose:
                    print(f"Residual reached the desired tolerance of {self.tolerance}. Ending iterations..")
                break

            # NOTE(review): the regularizer dual variables are re-created at every
            # iteration here (in PDHG they are created once, before the loop), so they
            # do not accumulate across iterations. Presumably intentional — confirm.
            q = [reg.initialize_dual() for reg in self.regularizer]
            for q_r, reg in zip(q, self.regularizer):
                reg.update_dual(q_r, x)
                reg.apply_proximal(q_r)

            # `res` is added to `x` through the back-projection, so the data term's
            # residual presumably has the sign `b - A(x)` — TODO confirm in data_terms.
            upd = A.T(res * sigma)
            for q_r, reg in zip(q, self.regularizer):
                upd -= reg.compute_update_primal(q_r)
            x += upd * tau

            if lower_limit is not None or upper_limit is not None:
                x = x.clip(lower_limit, upper_limit)
            if x_mask is not None:
                x *= x_mask

        return x, info
class PDHG(Solver):
    """
    Initialize the PDHG solver class.

    PDHG stands for primal-dual hybrid gradient algorithm from Chambolle and Pock.

    Parameters
    ----------
    verbose : bool, optional
        Turn on verbose output. The default is False.
    leave_progress: bool, optional
        Leave the progress bar after the computation is finished. The default is True.
    tolerance : Optional[float], optional
        Tolerance on the data residual for computing when to stop iterations.
        The default is None.
    relaxation : float, optional
        The relaxation length. The default is 0.95.
    regularizer : Sequence[regularizers.BaseRegularizer] | regularizers.BaseRegularizer | None, optional
        Regularizer to be used. The default is None.
    data_term : Union[str, data_terms.DataFidelityBase], optional
        Data fidelity term for computing the data residual. The default is "l2".
    data_term_test : Optional[data_terms.DataFidelityBase], optional
        The data fidelity to be used for the test set.
        If None, it will use the same as for the rest of the data.
        The default is None.
    """

    def __init__(
        self,
        verbose: bool = False,
        leave_progress: bool = True,
        tolerance: Optional[float] = None,
        relaxation: float = 0.95,
        regularizer: Union[Sequence[regularizers.BaseRegularizer], regularizers.BaseRegularizer, None] = None,
        data_term: Union[str, data_terms.DataFidelityBase] = "l2",
        data_term_test: Union[str, data_terms.DataFidelityBase, None] = None,
    ):
        super().__init__(
            verbose=verbose,
            leave_progress=leave_progress,
            relaxation=relaxation,
            tolerance=tolerance,
            data_term=data_term,
            data_term_test=data_term_test,
        )
        self.regularizer = self._initialize_regularizer(regularizer)

    def info(self) -> str:
        """
        Return the PDHG info.

        Returns
        -------
        str
            PDHG info string: solver name, data term and regularizers.
        """
        reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
        return Solver.info(self) + "-" + self.data_term.info() + reg_info

    def power_method(
        self, A: operators.BaseTransform, b: NDArrayFloat, iterations: int = 5
    ) -> tuple[np.floating, Sequence[int], DTypeLike]:
        """
        Compute the l2-norm of the operator A, with the power method.

        Parameters
        ----------
        A : BaseTransform
            Operator whose l2-norm needs to be computed.
        b : NDArrayFloat
            The data vector (only its shape and dtype are used).
        iterations : int, optional
            Number of power method iterations. The default is 5.

        Returns
        -------
        Tuple[float, Tuple[int], DTypeLike]
            The l2-norm of A, and the shape and type of the solution.
        """
        # Random unit-norm starting vector in data space.
        x: NDArrayFloat = np.array(np.random.rand(*b.shape))
        x = x.astype(b.dtype)

        x /= np.linalg.norm(x)
        # Move to solution space; with iterations == 0, ||A^T x|| is the returned estimate.
        x = A.T(x)

        x_norm = np.linalg.norm(x)
        L = x_norm

        for _ in range(iterations):
            # Power iteration on A^T A: its dominant eigenvalue is ||A||^2, hence the sqrt.
            x /= x_norm
            x = A.T(A(x))

            x_norm = np.linalg.norm(x)
            L = np.sqrt(x_norm)

        return (L, x.shape, x.dtype)

    def __call__(  # noqa: C901
        self,
        A: operators.BaseTransform,
        b: NDArrayFloat,
        iterations: int,
        x0: Optional[NDArrayFloat] = None,
        lower_limit: Union[float, NDArrayFloat, None] = None,
        upper_limit: Union[float, NDArrayFloat, None] = None,
        x_mask: Optional[NDArrayFloat] = None,
        b_mask: Optional[NDArrayFloat] = None,
        b_test_mask: Optional[NDArrayFloat] = None,
        precondition: bool = True,
    ) -> tuple[NDArrayFloat, SolutionInfo]:
        """
        Reconstruct the data, using the PDHG algorithm.

        Parameters
        ----------
        A : BaseTransform
            Projection operator.
        b : NDArrayFloat
            Data to reconstruct.
        iterations : int
            Number of iterations.
        x0 : Optional[NDArrayFloat], optional
            Initial solution. The default is None (all zeros).
        lower_limit : Union[float, NDArrayFloat], optional
            Lower clipping value. The default is None.
        upper_limit : Union[float, NDArrayFloat], optional
            Upper clipping value. The default is None.
        x_mask : Optional[NDArrayFloat], optional
            Solution mask. The default is None.
        b_mask : Optional[NDArrayFloat], optional
            Data mask. The default is None.
        b_test_mask : Optional[NDArrayFloat], optional
            Test data mask. The default is None.
        precondition : bool, optional
            Whether to use the preconditioned version of the algorithm.
            It is turned off automatically if `A` (or `A.T`) has no `absolute()` method.
            The default is True.

        Returns
        -------
        Tuple[NDArrayFloat, SolutionInfo]
            The reconstruction, and the residuals.
        """
        b = np.array(b)

        if precondition:
            # Diagonal preconditioning needs the element-wise absolute value of the operator.
            try:
                At_abs = A.T.absolute()
                A_abs = A.absolute()
            except AttributeError:
                print(A)
                print("WARNING: Turning off preconditioning because system matrix does not support absolute")
                precondition = False

        (b_mask, b_test_mask) = self._initialize_b_masks(b, b_mask, b_test_mask)

        if precondition:
            # Primal step sizes: |A|^T applied to the (masked) data-space ones.
            tau = np.ones_like(b)
            if b_mask is not None:
                tau *= b_mask
            tau = np.abs(At_abs(tau))
            for reg in self.regularizer:
                tau += reg.initialize_sigma_tau(tau)
            tau[(tau / np.max(tau)) < 1e-5] = 1
            tau = self.relaxation / tau

            x_shape = tau.shape
            x_dtype = tau.dtype

            # Dual step sizes: |A| applied to the (masked) solution-space ones.
            sigma = np.ones_like(tau)
            if x_mask is not None:
                sigma *= x_mask
            sigma = np.abs(A_abs(sigma))
            sigma[(sigma / np.max(sigma)) < 1e-5] = 1
            sigma = self.relaxation / sigma
        else:
            # NOTE(review): this helper is not defined in the visible part of this
            # class; presumably it is defined elsewhere in the file — confirm.
            (x_shape, x_dtype, sigma, tau) = self._get_data_sigma_tau_unpreconditioned(A, b)

        if x0 is None:
            x0 = np.zeros(x_shape, dtype=x_dtype)
        else:
            x0 = np.array(x0).copy()
        x = x0
        x_relax = x.copy()

        self.data_term.assign_data(b, sigma)
        p = self.data_term.initialize_dual()
        q = [reg.initialize_dual() for reg in self.regularizer]

        info = SolutionInfo(self.info(), max_iterations=iterations, tolerance=self.tolerance)

        # Initial residuals, only when they are needed (cross-validation or stopping criterion).
        if b_test_mask is not None or self.tolerance is not None:
            Ax = A(x)

            res_0 = self.data_term.compute_residual(Ax, mask=b_mask)
            info.residual0 = self.data_term.compute_residual_norm(res_0)

            if b_test_mask is not None:
                if self.data_term_test.background != self.data_term.background:
                    print("WARNING - the data_term and and data_term_test should have the same background. Making them equal.")
                    self.data_term_test.background = self.data_term.background
                self.data_term_test.assign_data(b, sigma)

                res_test_0 = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residual0_cv = self.data_term_test.compute_residual_norm(res_test_0)

        reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
        algo_info = f"- Performing {self.upper()}-{self.data_term.upper()}{reg_info} iterations: "

        for ii in tqdm(range(iterations), desc=algo_info, disable=(not self.verbose), leave=self.leave_progress):
            info.iterations += 1

            # Dual update of the data term, evaluated at the extrapolated primal point.
            Ax_rlx = A(x_relax)
            self.data_term.update_dual(p, Ax_rlx)
            self.data_term.apply_proximal(p)

            if b_mask is not None:
                p *= b_mask

            # Dual updates of the regularizers (modify `q_r` in place).
            for q_r, reg in zip(q, self.regularizer):
                reg.update_dual(q_r, x_relax)
                reg.apply_proximal(q_r)

            # Primal update.
            upd = A.T(p)
            for q_r, reg in zip(q, self.regularizer):
                upd += reg.compute_update_primal(q_r)
            x_new = x - upd * tau

            if lower_limit is not None or upper_limit is not None:
                x_new = x_new.clip(lower_limit, upper_limit)
            if x_mask is not None:
                x_new *= x_mask

            # Extrapolation step: x_relax = x_new + 1 * (x_new - x).
            x_relax = x_new + (x_new - x)
            x = x_new

            # Residuals here are recorded for the updated `x`.
            if b_test_mask is not None or self.tolerance is not None:
                Ax = A(x)
                res = self.data_term.compute_residual(Ax, mask=b_mask)
                info.residuals[ii] = self.data_term.compute_residual_norm(res)
            if b_test_mask is not None:
                res_test = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
                info.residuals_cv[ii] = self.data_term_test.compute_residual_norm(res_test)

            if self.tolerance is not None and self.tolerance > info.residuals[ii]:
                if self.verbose:
                    print(f"Residual reached the desired tolerance of {self.tolerance}. Ending iterations..")
                break

        return x, info