Skip to content

Commit

Permalink
[PLS Refactor] Linear solver stopping criteria (#499)
Browse files Browse the repository at this point in the history
* information operators added

* .

* added tests for information ops

* test parametrization for states

* fixed tests

* stopping criteria added

* pylint fix

* removed unnecessary flatten statements

* basic test added

* .

* one more test added

* Update src/probnum/linalg/solvers/stopping_criteria/__init__.py

Co-authored-by: Marvin Pförtner <[email protected]>

* Update src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py

Co-authored-by: Marvin Pförtner <[email protected]>

* Update src/probnum/linalg/solvers/stopping_criteria/_maxiter.py

Co-authored-by: Marvin Pförtner <[email protected]>

* implemented review comments

Co-authored-by: Marvin Pförtner <[email protected]>
Co-authored-by: Marvin Pförtner <[email protected]>
  • Loading branch information
3 people authored Sep 17, 2021
1 parent 879cdfe commit 0dc0878
Show file tree
Hide file tree
Showing 17 changed files with 346 additions and 12 deletions.
5 changes: 5 additions & 0 deletions docs/source/api/linalg/solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ probnum.linalg.solvers
:hidden:

solvers.information_ops

.. toctree::
:hidden:

solvers.stopping_criteria
5 changes: 5 additions & 0 deletions docs/source/api/linalg/solvers.stopping_criteria.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Stopping Criteria
-----------------
.. automodapi:: probnum.linalg.solvers.stopping_criteria
:no-heading:
:headings: "*"
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,5 @@ def __call__(
----------
solver_state :
Current state of the linear solver.
Returns
-------
observation :
Returns an observation of the linear system.
"""
raise NotImplementedError
8 changes: 7 additions & 1 deletion src/probnum/linalg/solvers/information_ops/_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@ class MatVecInfoOp(LinearSolverInfoOp):
def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> np.ndarray:
r"""Matrix vector product with the system matrix :math:`A`."""
r"""Matrix vector product with the system matrix :math:`A`.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
return solver_state.problem.A @ solver_state.action
8 changes: 7 additions & 1 deletion src/probnum/linalg/solvers/information_ops/_proj_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,11 @@ class ProjResidualInfoOp(LinearSolverInfoOp):
def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> np.ndarray:
r"""Projected residual :math:`s_i^\top (A x_i - b)` of the linear system."""
r"""Projected residual :math:`s_i^\top (A x_i - b)` of the linear system.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
return solver_state.action @ solver_state.residual
20 changes: 20 additions & 0 deletions src/probnum/linalg/solvers/stopping_criteria/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Stopping criteria for probabilistic linear solvers."""

from ._linear_solver_stopping_criterion import LinearSolverStopCrit
from ._maxiter import MaxIterationsStopCrit
from ._posterior_contraction import PosteriorContractionStopCrit
from ._residual_norm import ResidualNormStopCrit

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"LinearSolverStopCrit",
"MaxIterationsStopCrit",
"ResidualNormStopCrit",
"PosteriorContractionStopCrit",
]

# Set correct module paths. Corrects links and module paths in documentation.
LinearSolverStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria"
MaxIterationsStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria"
ResidualNormStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria"
PosteriorContractionStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria"
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Base class for linear solver stopping criteria."""

import abc

import probnum # pylint: disable="unused-import"


class LinearSolverStopCrit(abc.ABC):
r"""Stopping criterion of a (probabilistic) linear solver.
Checks whether quantities tracked by the :class:`~probnum.linalg.solvers.ProbabilisticLinearSolverState` meet a desired terminal condition.
See Also
--------
ResidualNormStopCrit : Stop based on the norm of the residual.
PosteriorContractionStopCrit : Stop based on the uncertainty about the quantity of interest.
MaxIterationsStopCrit : Stop after a maximum number of iterations.
"""

@abc.abstractmethod
def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> bool:
"""Check whether tracked quantities meet a desired terminal condition.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
raise NotImplementedError
38 changes: 38 additions & 0 deletions src/probnum/linalg/solvers/stopping_criteria/_maxiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Stopping criterion based on a maximum number of iterations."""
from typing import Optional

import probnum # pylint: disable="unused-import"

from ._linear_solver_stopping_criterion import LinearSolverStopCrit


class MaxIterationsStopCrit(LinearSolverStopCrit):
r"""Stop after a maximum number of iterations.
Stop when the solver has taken a maximum number of steps. If ``None`` is
specified, defaults to :math:`10n`, where :math:`n` is the dimension
of the solution to the linear system.
Parameters
----------
maxiter :
Maximum number of steps the solver should take.
"""

def __init__(self, maxiter: Optional[int] = None):
self.maxiter = maxiter

def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> bool:
"""Check whether the maximum number of iterations has been reached.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
if self.maxiter is None:
self.maxiter = solver_state.problem.A.shape[0] * 10

return solver_state.step >= self.maxiter
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Stopping criterion based on the uncertainty about a quantity of interest."""

import numpy as np

import probnum # pylint: disable="unused-import"
from probnum.typing import ScalarArgType

from ._linear_solver_stopping_criterion import LinearSolverStopCrit


class PosteriorContractionStopCrit(LinearSolverStopCrit):
r"""Posterior contraction stopping criterion.
Terminate when the uncertainty about the quantity of interest :math:`q` is
sufficiently small, i.e. if :math:`\sqrt{\operatorname{tr}(\mathbb{Cov}(q))}
\leq \max(\text{atol}, \text{rtol} \lVert b \rVert_2)`, where :math:`q` is either
the solution :math:`x`, the system matrix :math:`A` or its inverse :math:`A^{-1}`.
Parameters
----------
qoi :
Quantity of interest. One of ``{"x", "A", "Ainv"}``.
atol :
Absolute tolerance.
rtol :
Relative tolerance.
"""

def __init__(
self,
qoi: str = "x",
atol: ScalarArgType = 10 ** -5,
rtol: ScalarArgType = 10 ** -5,
):
self.qoi = qoi
self.atol = probnum.utils.as_numpy_scalar(atol)
self.rtol = probnum.utils.as_numpy_scalar(rtol)

def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> bool:
"""Check whether the uncertainty about the quantity of interest is smaller than
the specified tolerance.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
trace_cov_qoi = getattr(solver_state.belief, self.qoi).cov.trace()
b_norm = np.linalg.norm(solver_state.problem.b, ord=2)

return (
np.abs(trace_cov_qoi) <= self.atol ** 2
or np.abs(trace_cov_qoi) <= (self.rtol * b_norm) ** 2
)
46 changes: 46 additions & 0 deletions src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Stopping criterion based on the norm of the residual."""

import numpy as np

import probnum
from probnum.typing import ScalarArgType

from ._linear_solver_stopping_criterion import LinearSolverStopCrit


class ResidualNormStopCrit(LinearSolverStopCrit):
r"""Residual stopping criterion.
Terminate when the euclidean norm of the residual :math:`r_{i} = A x_{i} - b` is
sufficiently small, i.e. if it satisfies :math:`\lVert r_i \rVert_2 \leq \max(
\text{atol}, \text{rtol} \lVert b \rVert_2)`.
Parameters
----------
atol :
Absolute tolerance.
rtol :
Relative tolerance.
"""

def __init__(
self,
atol: ScalarArgType = 10 ** -5,
rtol: ScalarArgType = 10 ** -5,
):
self.atol = probnum.utils.as_numpy_scalar(atol)
self.rtol = probnum.utils.as_numpy_scalar(rtol)

def __call__(
self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState"
) -> bool:
"""Check whether the residual norm is smaller than the specified tolerance.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
residual_norm = np.linalg.norm(solver_state.residual, ord=2)
b_norm = np.linalg.norm(solver_state.problem.b, ord=2)
return residual_norm <= self.atol or residual_norm <= self.rtol * b_norm
30 changes: 25 additions & 5 deletions tests/test_linalg/test_solvers/cases/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
)

# Prior
Ainv = randvars.Normal(
mean=linops.Identity(n), cov=linops.SymmetricKronecker(linops.Identity(n))
)
b = randvars.Constant(linsys.b)
prior = linalg.solvers.beliefs.LinearSystemBelief(
A=randvars.Constant(linsys.A),
Ainv=None,
x=randvars.Normal(
mean=np.zeros(linsys.A.shape[1]), cov=linops.Identity(shape=linsys.A.shape)
),
b=randvars.Constant(linsys.b),
Ainv=Ainv,
x=(Ainv @ b[:, None]).reshape(
(n,)
), # TODO: This can be replaced by Ainv @ b once https://github.com/probabilistic-numerics/probnum/issues/456 is fixed
b=b,
)


Expand All @@ -44,3 +48,19 @@ def case_state(
initial_state.action = rng.standard_normal(size=initial_state.problem.A.shape[1])

return initial_state


def case_state_converged(
rng: np.random.Generator,
):
"""State of a linear solver, which has converged at initialization."""
belief = linalg.solvers.beliefs.LinearSystemBelief(
A=randvars.Constant(linsys.A),
Ainv=randvars.Constant(linops.aslinop(linsys.A).inv().todense()),
x=randvars.Constant(linsys.solution),
b=randvars.Constant(linsys.b),
)
state = linalg.solvers.ProbabilisticLinearSolverState(
problem=linsys, prior=belief, rng=rng
)
return state
18 changes: 18 additions & 0 deletions tests/test_linalg/test_solvers/cases/stopping_criteria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Stopping criteria test cases."""

from pytest_cases import parametrize

from probnum.linalg.solvers import stopping_criteria


def case_maxiter():
return stopping_criteria.MaxIterationsStopCrit()


def case_residual_norm():
return stopping_criteria.ResidualNormStopCrit()


@parametrize("qoi", ["x", "Ainv", "A"])
def case_posterior_contraction(qoi: str):
return stopping_criteria.PosteriorContractionStopCrit(qoi=qoi)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Tests for stopping criteria of linear solvers."""

import pathlib

from pytest_cases import parametrize_with_cases

from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria

case_modules = (pathlib.Path(__file__).parent / "cases").stem
cases_stopping_criteria = case_modules + ".stopping_criteria"
cases_states = case_modules + ".states"


@parametrize_with_cases("stop_crit", cases=cases_stopping_criteria)
@parametrize_with_cases("state", cases=cases_states)
def test_returns_bool(
stop_crit: stopping_criteria.LinearSolverStopCrit,
state: ProbabilisticLinearSolverState,
):
assert stop_crit(solver_state=state) in [True, False]
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for the maximum iterations stopping criterion."""

import pathlib

from pytest_cases import parametrize_with_cases

from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria

case_modules = (pathlib.Path(__file__).parent / "cases").stem
cases_stopping_criteria = case_modules + ".stopping_criteria"
cases_states = case_modules + ".states"


@parametrize_with_cases("state", cases=cases_states, glob="*initial_state")
def test_maxiter_None(state: ProbabilisticLinearSolverState):
"""Test whether if ``maxiter=None``, the maximum number of iterations is set to
:math:`10n`, where :math:`n` is the dimension of the linear system."""
stop_crit = stopping_criteria.MaxIterationsStopCrit()

for _ in range(10 * state.problem.A.shape[1]):
assert not stop_crit(state)
state.next_step()

assert stop_crit(state)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Tests for the posterior contraction stopping criterion."""

import pathlib

from pytest_cases import parametrize_with_cases

from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria

case_modules = (pathlib.Path(__file__).parent / "cases").stem
cases_stopping_criteria = case_modules + ".stopping_criteria"
cases_states = case_modules + ".states"


@parametrize_with_cases(
"stop_crit", cases=cases_stopping_criteria, glob="*posterior_contraction"
)
@parametrize_with_cases("state", cases=cases_states, glob="*converged")
def test_has_converged(
stop_crit: stopping_criteria.LinearSolverStopCrit,
state: ProbabilisticLinearSolverState,
):
assert stop_crit(solver_state=state)
Loading

0 comments on commit 0dc0878

Please sign in to comment.