-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PLS Refactor] Linear solver stopping criteria (#499)
* information operators added * . * added tests for information ops * test parametrization for states * fixed tests * stopping criteria added * pylint fix * removed unnecessary flatten statements * basic test added * . * one more test added * Update src/probnum/linalg/solvers/stopping_criteria/__init__.py Co-authored-by: Marvin Pförtner <[email protected]> * Update src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py Co-authored-by: Marvin Pförtner <[email protected]> * Update src/probnum/linalg/solvers/stopping_criteria/_maxiter.py Co-authored-by: Marvin Pförtner <[email protected]> * implemented review comments Co-authored-by: Marvin Pförtner <[email protected]> Co-authored-by: Marvin Pförtner <[email protected]>
- Loading branch information
1 parent
879cdfe
commit 0dc0878
Showing
17 changed files
with
346 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,8 @@ probnum.linalg.solvers | |
:hidden: | ||
|
||
solvers.information_ops | ||
|
||
.. toctree:: | ||
:hidden: | ||
|
||
solvers.stopping_criteria |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Stopping Criteria | ||
----------------- | ||
.. automodapi:: probnum.linalg.solvers.stopping_criteria | ||
:no-heading: | ||
:headings: "*" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""Stopping criteria for probabilistic linear solvers.""" | ||
|
||
from ._linear_solver_stopping_criterion import LinearSolverStopCrit | ||
from ._maxiter import MaxIterationsStopCrit | ||
from ._posterior_contraction import PosteriorContractionStopCrit | ||
from ._residual_norm import ResidualNormStopCrit | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"LinearSolverStopCrit", | ||
"MaxIterationsStopCrit", | ||
"ResidualNormStopCrit", | ||
"PosteriorContractionStopCrit", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
LinearSolverStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" | ||
MaxIterationsStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" | ||
ResidualNormStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" | ||
PosteriorContractionStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" |
31 changes: 31 additions & 0 deletions
31
src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Base class for linear solver stopping criteria.""" | ||
|
||
import abc | ||
|
||
import probnum # pylint: disable="unused-import" | ||
|
||
|
||
class LinearSolverStopCrit(abc.ABC): | ||
r"""Stopping criterion of a (probabilistic) linear solver. | ||
Checks whether quantities tracked by the :class:`~probnum.linalg.solvers.ProbabilisticLinearSolverState` meet a desired terminal condition. | ||
See Also | ||
-------- | ||
ResidualNormStopCrit : Stop based on the norm of the residual. | ||
PosteriorContractionStopCrit : Stop based on the uncertainty about the quantity of interest. | ||
MaxIterationsStopCrit : Stop after a maximum number of iterations. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" | ||
) -> bool: | ||
"""Check whether tracked quantities meet a desired terminal condition. | ||
Parameters | ||
---------- | ||
solver_state : | ||
Current state of the linear solver. | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""Stopping criterion based on a maximum number of iterations.""" | ||
from typing import Optional | ||
|
||
import probnum # pylint: disable="unused-import" | ||
|
||
from ._linear_solver_stopping_criterion import LinearSolverStopCrit | ||
|
||
|
||
class MaxIterationsStopCrit(LinearSolverStopCrit): | ||
r"""Stop after a maximum number of iterations. | ||
Stop when the solver has taken a maximum number of steps. If ``None`` is | ||
specified, defaults to :math:`10n`, where :math:`n` is the dimension | ||
of the solution to the linear system. | ||
Parameters | ||
---------- | ||
maxiter : | ||
Maximum number of steps the solver should take. | ||
""" | ||
|
||
def __init__(self, maxiter: Optional[int] = None): | ||
self.maxiter = maxiter | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" | ||
) -> bool: | ||
"""Check whether the maximum number of iterations has been reached. | ||
Parameters | ||
---------- | ||
solver_state : | ||
Current state of the linear solver. | ||
""" | ||
if self.maxiter is None: | ||
self.maxiter = solver_state.problem.A.shape[0] * 10 | ||
|
||
return solver_state.step >= self.maxiter |
56 changes: 56 additions & 0 deletions
56
src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Stopping criterion based on the uncertainty about a quantity of interest.""" | ||
|
||
import numpy as np | ||
|
||
import probnum # pylint: disable="unused-import" | ||
from probnum.typing import ScalarArgType | ||
|
||
from ._linear_solver_stopping_criterion import LinearSolverStopCrit | ||
|
||
|
||
class PosteriorContractionStopCrit(LinearSolverStopCrit): | ||
r"""Posterior contraction stopping criterion. | ||
Terminate when the uncertainty about the quantity of interest :math:`q` is | ||
sufficiently small, i.e. if :math:`\sqrt{\operatorname{tr}(\mathbb{Cov}(q))} | ||
\leq \max(\text{atol}, \text{rtol} \lVert b \rVert_2)`, where :math:`q` is either | ||
the solution :math:`x`, the system matrix :math:`A` or its inverse :math:`A^{-1}`. | ||
Parameters | ||
---------- | ||
qoi : | ||
Quantity of interest. One of ``{"x", "A", "Ainv"}``. | ||
atol : | ||
Absolute tolerance. | ||
rtol : | ||
Relative tolerance. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
qoi: str = "x", | ||
atol: ScalarArgType = 10 ** -5, | ||
rtol: ScalarArgType = 10 ** -5, | ||
): | ||
self.qoi = qoi | ||
self.atol = probnum.utils.as_numpy_scalar(atol) | ||
self.rtol = probnum.utils.as_numpy_scalar(rtol) | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" | ||
) -> bool: | ||
"""Check whether the uncertainty about the quantity of interest is smaller than | ||
the specified tolerance. | ||
Parameters | ||
---------- | ||
solver_state : | ||
Current state of the linear solver. | ||
""" | ||
trace_cov_qoi = getattr(solver_state.belief, self.qoi).cov.trace() | ||
b_norm = np.linalg.norm(solver_state.problem.b, ord=2) | ||
|
||
return ( | ||
np.abs(trace_cov_qoi) <= self.atol ** 2 | ||
or np.abs(trace_cov_qoi) <= (self.rtol * b_norm) ** 2 | ||
) |
46 changes: 46 additions & 0 deletions
46
src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Stopping criterion based on the norm of the residual.""" | ||
|
||
import numpy as np | ||
|
||
import probnum | ||
from probnum.typing import ScalarArgType | ||
|
||
from ._linear_solver_stopping_criterion import LinearSolverStopCrit | ||
|
||
|
||
class ResidualNormStopCrit(LinearSolverStopCrit): | ||
r"""Residual stopping criterion. | ||
Terminate when the euclidean norm of the residual :math:`r_{i} = A x_{i} - b` is | ||
sufficiently small, i.e. if it satisfies :math:`\lVert r_i \rVert_2 \leq \max( | ||
\text{atol}, \text{rtol} \lVert b \rVert_2)`. | ||
Parameters | ||
---------- | ||
atol : | ||
Absolute tolerance. | ||
rtol : | ||
Relative tolerance. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
atol: ScalarArgType = 10 ** -5, | ||
rtol: ScalarArgType = 10 ** -5, | ||
): | ||
self.atol = probnum.utils.as_numpy_scalar(atol) | ||
self.rtol = probnum.utils.as_numpy_scalar(rtol) | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" | ||
) -> bool: | ||
"""Check whether the residual norm is smaller than the specified tolerance. | ||
Parameters | ||
---------- | ||
solver_state : | ||
Current state of the linear solver. | ||
""" | ||
residual_norm = np.linalg.norm(solver_state.residual, ord=2) | ||
b_norm = np.linalg.norm(solver_state.problem.b, ord=2) | ||
return residual_norm <= self.atol or residual_norm <= self.rtol * b_norm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Stopping criteria test cases.""" | ||
|
||
from pytest_cases import parametrize | ||
|
||
from probnum.linalg.solvers import stopping_criteria | ||
|
||
|
||
def case_maxiter(): | ||
return stopping_criteria.MaxIterationsStopCrit() | ||
|
||
|
||
def case_residual_norm(): | ||
return stopping_criteria.ResidualNormStopCrit() | ||
|
||
|
||
@parametrize("qoi", ["x", "Ainv", "A"]) | ||
def case_posterior_contraction(qoi: str): | ||
return stopping_criteria.PosteriorContractionStopCrit(qoi=qoi) |
Empty file.
20 changes: 20 additions & 0 deletions
20
.../test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""Tests for stopping criteria of linear solvers.""" | ||
|
||
import pathlib | ||
|
||
from pytest_cases import parametrize_with_cases | ||
|
||
from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria | ||
|
||
case_modules = (pathlib.Path(__file__).parent / "cases").stem | ||
cases_stopping_criteria = case_modules + ".stopping_criteria" | ||
cases_states = case_modules + ".states" | ||
|
||
|
||
@parametrize_with_cases("stop_crit", cases=cases_stopping_criteria) | ||
@parametrize_with_cases("state", cases=cases_states) | ||
def test_returns_bool( | ||
stop_crit: stopping_criteria.LinearSolverStopCrit, | ||
state: ProbabilisticLinearSolverState, | ||
): | ||
assert stop_crit(solver_state=state) in [True, False] |
24 changes: 24 additions & 0 deletions
24
tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Tests for the maximum iterations stopping criterion.""" | ||
|
||
import pathlib | ||
|
||
from pytest_cases import parametrize_with_cases | ||
|
||
from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria | ||
|
||
case_modules = (pathlib.Path(__file__).parent / "cases").stem | ||
cases_stopping_criteria = case_modules + ".stopping_criteria" | ||
cases_states = case_modules + ".states" | ||
|
||
|
||
@parametrize_with_cases("state", cases=cases_states, glob="*initial_state") | ||
def test_maxiter_None(state: ProbabilisticLinearSolverState): | ||
"""Test whether if ``maxiter=None``, the maximum number of iterations is set to | ||
:math:`10n`, where :math:`n` is the dimension of the linear system.""" | ||
stop_crit = stopping_criteria.MaxIterationsStopCrit() | ||
|
||
for _ in range(10 * state.problem.A.shape[1]): | ||
assert not stop_crit(state) | ||
state.next_step() | ||
|
||
assert stop_crit(state) |
22 changes: 22 additions & 0 deletions
22
tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
"""Tests for the posterior contraction stopping criterion.""" | ||
|
||
import pathlib | ||
|
||
from pytest_cases import parametrize_with_cases | ||
|
||
from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria | ||
|
||
case_modules = (pathlib.Path(__file__).parent / "cases").stem | ||
cases_stopping_criteria = case_modules + ".stopping_criteria" | ||
cases_states = case_modules + ".states" | ||
|
||
|
||
@parametrize_with_cases( | ||
"stop_crit", cases=cases_stopping_criteria, glob="*posterior_contraction" | ||
) | ||
@parametrize_with_cases("state", cases=cases_states, glob="*converged") | ||
def test_has_converged( | ||
stop_crit: stopping_criteria.LinearSolverStopCrit, | ||
state: ProbabilisticLinearSolverState, | ||
): | ||
assert stop_crit(solver_state=state) |
Oops, something went wrong.