Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed no-else-return message in probnum.linalg module #699

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ def x(self) -> randvars.RandomVariable:
"""Belief about the solution."""
if self._x is None:
return self._induced_x()
else:
return self._x
return self._x

@property
def A(self) -> randvars.RandomVariable:
Expand Down
21 changes: 9 additions & 12 deletions src/probnum/linalg/solvers/matrixbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def has_converged(self, iter, maxiter, **kwargs):
"Iteration terminated. Solver reached the maximum number of iterations."
)
return True, "maxiter"
else:
return False, ""
return False, ""

def solve(self, callback=None, **kwargs):
"""Solve the linear system :math:`Ax=b`.
Expand Down Expand Up @@ -315,15 +314,15 @@ def _get_prior_params(self, A0, Ainv0, x0, b):
A0_covfactor = self.A
return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor
# Construct matrix priors from initial guess x0
elif isinstance(x0, np.ndarray):
if isinstance(x0, np.ndarray):
A0_mean, Ainv0_mean = self._construct_symmetric_matrix_prior_means(
A=self.A, x0=x0, b=b
)
Ainv0_covfactor = Ainv0_mean
# Symmetric posterior correspondence
A0_covfactor = self.A
return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor
elif isinstance(x0, randvars.RandomVariable):
if isinstance(x0, randvars.RandomVariable):
raise NotImplementedError

# Prior on Ainv specified
Expand Down Expand Up @@ -360,7 +359,7 @@ def _get_prior_params(self, A0, Ainv0, x0, b):
return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor

# Prior on A specified
elif A0 is not None and not isinstance(Ainv0, randvars.RandomVariable):
if A0 is not None and not isinstance(Ainv0, randvars.RandomVariable):
if isinstance(A0, randvars.RandomVariable):
A0_mean = A0.mean
A0_covfactor = A0.cov.A
Expand Down Expand Up @@ -392,16 +391,15 @@ def _get_prior_params(self, A0, Ainv0, x0, b):
Ainv0_covfactor = Ainv0_mean
return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor
# Both matrix priors on A and H specified via random variables
elif isinstance(A0, randvars.RandomVariable) and isinstance(
if isinstance(A0, randvars.RandomVariable) and isinstance(
Ainv0, randvars.RandomVariable
):
A0_mean = A0.mean
A0_covfactor = A0.cov.A
Ainv0_mean = Ainv0.mean
Ainv0_covfactor = Ainv0.cov.A
return A0_mean, A0_covfactor, Ainv0_mean, Ainv0_covfactor
else:
raise NotImplementedError
raise NotImplementedError

def _compute_trace_Ainv_covfactor0(self, Y, unc_scale):
"""Computes the trace of the prior covariance factor for the inverse view.
Expand Down Expand Up @@ -512,15 +510,14 @@ def has_converged(self, iter, maxiter, resid=None, atol=None, rtol=None):
b_norm = np.linalg.norm(self.b)
if resid_norm <= atol:
return True, "resid_atol"
elif resid_norm <= rtol * b_norm:
if resid_norm <= rtol * b_norm:
return True, "resid_rtol"
# uncertainty-based
if np.sqrt(self.trace_sol_cov) <= atol:
return True, "tracecov_atol"
elif np.sqrt(self.trace_sol_cov) <= rtol * b_norm:
if np.sqrt(self.trace_sol_cov) <= rtol * b_norm:
return True, "tracecov_rtol"
else:
return False, ""
return False, ""

def _calibrate_uncertainty(self, S, sy, method):
"""Calibrate uncertainty based on the Rayleigh coefficients.
Expand Down
31 changes: 15 additions & 16 deletions src/probnum/linalg/solvers/policies/_conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,25 @@ def __call__(
)

return residual
# Reorthogonalization of the residual
if self._reorthogonalization_fn_residual is not None:
residual, prev_residual = self._reorthogonalized_residual(
solver_state=solver_state
)
else:
# Reorthogonalization of the residual
if self._reorthogonalization_fn_residual is not None:
residual, prev_residual = self._reorthogonalized_residual(
solver_state=solver_state
)
else:
prev_residual = solver_state.residuals[solver_state.step - 1]
prev_residual = solver_state.residuals[solver_state.step - 1]

# A-conjugacy correction (in exact arithmetic)
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
action = residual + beta * solver_state.actions[solver_state.step - 1]
# A-conjugacy correction (in exact arithmetic)
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
action = residual + beta * solver_state.actions[solver_state.step - 1]

# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
action = self._reorthogonalized_action(
action=action, solver_state=solver_state
)
# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
action = self._reorthogonalized_action(
action=action, solver_state=solver_state
)

return action
return action

def _reorthogonalized_residual(
self,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ commands =
# Per-package Linting Passes
pylint src/probnum/diffeq --disable="redefined-outer-name,too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unnecessary-pass,unused-variable,unused-argument,no-self-use,duplicate-code,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0
pylint src/probnum/filtsmooth --disable="no-member,arguments-differ,too-many-arguments,too-many-locals,too-few-public-methods,protected-access,unused-variable,unused-argument,no-self-use,duplicate-code,useless-param-doc" --jobs=0
pylint src/probnum/linalg --disable="no-member,abstract-method,arguments-differ,else-if-used,redefined-builtin,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unused-argument,attribute-defined-outside-init,no-else-return,no-else-raise,no-self-use,duplicate-code,missing-module-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0
pylint src/probnum/linalg --disable="no-member,abstract-method,arguments-differ,redefined-builtin,too-many-instance-attributes,too-many-arguments,too-many-locals,too-many-lines,too-many-statements,too-many-branches,too-complex,too-few-public-methods,protected-access,unused-argument,attribute-defined-outside-init,no-else-raise,no-self-use,else-if-used,duplicate-code,missing-module-docstring,missing-param-doc,missing-type-doc,missing-raises-doc,missing-return-type-doc" --jobs=0
pylint src/probnum/linops --disable="too-many-instance-attributes,too-many-arguments,too-many-locals,protected-access,no-else-return,no-else-raise,else-if-used,missing-class-docstring,missing-function-docstring,missing-raises-doc,duplicate-code" --jobs=0
pylint src/probnum/problems --disable="too-many-arguments,too-many-locals,unused-variable,unused-argument,consider-using-from-import,duplicate-code,missing-module-docstring,missing-function-docstring,missing-param-doc,missing-type-doc,missing-raises-doc" --jobs=0
pylint src/probnum/quad --disable="too-many-arguments,missing-module-docstring" --jobs=0
Expand Down