diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 62be3da21..613a0e7cf 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -157,8 +157,7 @@ def x(self) -> randvars.RandomVariable: """Belief about the solution.""" if self._x is None: return self._induced_x() - else: - return self._x + return self._x @property def A(self) -> randvars.RandomVariable: diff --git a/src/probnum/linalg/solvers/matrixbased.py b/src/probnum/linalg/solvers/matrixbased.py index 0b5091175..5ab0dd166 100644 --- a/src/probnum/linalg/solvers/matrixbased.py +++ b/src/probnum/linalg/solvers/matrixbased.py @@ -149,8 +149,7 @@ def has_converged(self, iter, maxiter, **kwargs): "Iteration terminated. Solver reached the maximum number of iterations." ) return True, "maxiter" - else: - return False, "" + return False, "" def solve(self, callback=None, **kwargs): """Solve the linear system :math:`Ax=b`. @@ -315,7 +314,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b): A0_covfactor = self.A return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor # Construct matrix priors from initial guess x0 - elif isinstance(x0, np.ndarray): + if isinstance(x0, np.ndarray): A0_mean, Ainv0_mean = self._construct_symmetric_matrix_prior_means( A=self.A, x0=x0, b=b ) @@ -323,7 +322,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b): # Symmetric posterior correspondence A0_covfactor = self.A return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor - elif isinstance(x0, randvars.RandomVariable): + if isinstance(x0, randvars.RandomVariable): raise NotImplementedError # Prior on Ainv specified @@ -360,7 +359,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b): return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor # Prior on A specified - elif A0 is not None and not isinstance(Ainv0, randvars.RandomVariable): + if A0 is not None and not isinstance(Ainv0, randvars.RandomVariable): if isinstance(A0, randvars.RandomVariable): A0_mean = A0.mean A0_covfactor = A0.cov.A @@ -392,7 +391,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b): Ainv0_covfactor = Ainv0_mean return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor # Both matrix priors on A and H specified via random variables - elif isinstance(A0, randvars.RandomVariable) and isinstance( + if isinstance(A0, randvars.RandomVariable) and isinstance( Ainv0, randvars.RandomVariable ): A0_mean = A0.mean @@ -400,8 +399,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b): Ainv0_mean = Ainv0.mean Ainv0_covfactor = Ainv0.cov.A return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor - else: - raise NotImplementedError + raise NotImplementedError def _compute_trace_Ainv_covfactor0(self, Y, unc_scale): """Computes the trace of the prior covariance factor for the inverse view. @@ -512,15 +510,14 @@ def has_converged(self, iter, maxiter, resid=None, atol=None, rtol=None): b_norm = np.linalg.norm(self.b) if resid_norm <= atol: return True, "resid_atol" - elif resid_norm <= rtol * b_norm: + if resid_norm <= rtol * b_norm: return True, "resid_rtol" # uncertainty-based if np.sqrt(self.trace_sol_cov) <= atol: return True, "tracecov_atol" - elif np.sqrt(self.trace_sol_cov) <= rtol * b_norm: + if np.sqrt(self.trace_sol_cov) <= rtol * b_norm: return True, "tracecov_rtol" - else: - return False, "" + return False, "" def _calibrate_uncertainty(self, S, sy, method): """Calibrate uncertainty based on the Rayleigh coefficients. diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 4896a315e..d961bf3c6 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -61,26 +61,25 @@ def __call__( ) return residual + # Reorthogonalization of the residual + if self._reorthogonalization_fn_residual is not None: + residual, prev_residual = self._reorthogonalized_residual( + solver_state=solver_state + ) else: - # Reorthogonalization of the residual - if self._reorthogonalization_fn_residual is not None: - residual, prev_residual = self._reorthogonalized_residual( - solver_state=solver_state - ) - else: - prev_residual = solver_state.residuals[solver_state.step - 1] + prev_residual = solver_state.residuals[solver_state.step - 1] - # A-conjugacy correction (in exact arithmetic) - beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 - action = residual + beta * solver_state.actions[solver_state.step - 1] + # A-conjugacy correction (in exact arithmetic) + beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 + action = residual + beta * solver_state.actions[solver_state.step - 1] - # Reorthogonalization of the resulting action - if self._reorthogonalization_fn_action is not None: - action = self._reorthogonalized_action( - action=action, solver_state=solver_state - ) + # Reorthogonalization of the resulting action + if self._reorthogonalization_fn_action is not None: + action = self._reorthogonalized_action( + action=action, solver_state=solver_state + ) - return action + return action def _reorthogonalized_residual( self, diff --git a/tox.ini b/tox.ini index 4705caac6..1eb773fea 100644 --- a/tox.ini +++ b/tox.ini @@ -68,7 +68,7 @@ commands = # Per-package Linting Passes pylint src/probnum/diffeq --disable="redefined-outer-name,too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unnecessary-pass,unused-variable,unused-argument,no-self-use,duplicate-code,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0 pylint src/probnum/filtsmooth --disable="no-member,arguments-differ,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unused-variable,unused-argument,no-self-use,duplicate-code,useless-param-doc" --jobs=0 - pylint src/probnum/linalg --disable="no-member,abstract-method,arguments-differ,else-if-used,redefined-builtin,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unused-argument,attribute-defined-outside-init,no-else-return,no-else-raise,no-self-use,duplicate-code,missing-module-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0 + pylint src/probnum/linalg --disable="no-member,abstract-method,arguments-differ,redefined-builtin,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unused-argument,attribute-defined-outside-init,no-else-raise,no-self-use,else-if-used,duplicate-code,missing-module-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0 pylint src/probnum/linops --disable="too-many-instance-attributes,too-many-arguments,too-many-locals,protected-access,no-else-return,no-else-raise,else-if-used,missing-class-docstring,missing-function-docstring,missing-raises-doc,duplicate-code" --jobs=0 pylint src/probnum/problems --disable="too-many-arguments,too-many-locals,unused-variable,unused-argument,consider-using-from-import,duplicate-code,missing-module-docstring,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc" --jobs=0 pylint src/probnum/quad --disable="too-many-arguments,missing-module-docstring" --jobs=0