Skip to content

Commit

Permalink
Reorthogonalization for a stable implementation of the probabilistic …
Browse files Browse the repository at this point in the history
…linear solver (#580)

* initial interface

* minor doc improvements

* first methods added to PLS

* pylint errors fixed

* .

* type hint fixed

* corrected generator type hint

* made prior part of solve methods

* solve method implemented

* pylint fix

* better documentation

* minor doc fixes

* solve iterator

* initial draft of doctest

* doctest added

* .

* debugging attempt

* solvers added as classes

* tests on random spd matrices

* bugfix solution-based belief update

* fixed notebooks

* belief fixes for matrix-based soolvers

* test added and bugfix

* doctest fixed

* jupyter notebook fix

* symmetric belief update fixed

* precision in doctest reduced

* .

* tests fixed

* perfect information tests added and docs improved for solver state

* orthogonalized residuals for bayescg

* test for orthogonalization functions run

* test for orthogonalization functions run

* improved and generalized implementation of gram-schmidt

* CG with reorthogonalization

* test for reorthogonalization

* .

* reorthogonalization only done in the policy for now

* extracted reorthogonalized action function

* tests run

* minor fix to the state

* test dependency via fixture fixed

* tests fixed + some fixture scope adjustment

* more tests for orthogonalization functions

* pylint fixes

* renamed inner product

* extended inner product to arbitrary arrays

* orthogonalization tests work

* only test double gram schmidt

* test for noneuclidean inner product

* fixed induced norm

* started vectorizing the reorthogonalization functions

* fixed orthogonalization

* fixed bug in CG search dirs where residuals where not actually reorthogonalized

* simplified CG policy

* matmul style broadcasting for inner_product

* added broadcasting to orthogonalization
  • Loading branch information
JonathanWenger authored Feb 11, 2022
1 parent 49f26ae commit 2c5cce8
Show file tree
Hide file tree
Showing 14 changed files with 690 additions and 29 deletions.
13 changes: 8 additions & 5 deletions src/probnum/linalg/solvers/_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""State of a probabilistic linear solver."""

import dataclasses
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -31,20 +31,23 @@ def __init__(
prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief",
rng: Optional[np.random.Generator] = None,
):
self.rng: Optional[np.random.Generator] = rng
self.problem: problems.LinearSystem = problem

# Belief
self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior
self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior

self._step: int = 0

# Caches
self._actions: List[np.ndarray] = [None]
self._observations: List[Any] = [None]
self._residuals: List[np.ndarray] = [
self.problem.A @ self.belief.x.mean - self.problem.b,
None,
]
self.rng: Optional[np.random.Generator] = rng
self.cache: Dict[str, Any] = {}

# Solver info
self._step: int = 0

def __repr__(self) -> str:
return f"{self.__class__.__name__}(step={self.step})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

# Compute projected residual
action_A = solver_state.action @ solver_state.problem.A
pred = action_A @ solver_state.belief.x.mean
proj_resid = solver_state.observation - pred

# Compute gain and covariance update
cov_xy = solver_state.belief.x.cov @ action_A.T
gram = action_A @ cov_xy + self._noise_var
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
Expand Down
13 changes: 12 additions & 1 deletion src/probnum/linalg/solvers/beliefs/_linear_system_belief.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from typing import Mapping, Optional

from probnum import randvars
from probnum import linops, randvars

# pylint: disable="invalid-name"

Expand Down Expand Up @@ -134,6 +134,8 @@ def A(self) -> randvars.RandomVariable:
@property
def Ainv(self) -> Optional[randvars.RandomVariable]:
"""Belief about the (pseudo-)inverse of the system matrix."""
if self._Ainv is None:
return self._induced_Ainv()
return self._Ainv

@property
Expand All @@ -149,3 +151,12 @@ def _induced_x(self) -> randvars.RandomVariable:
:math:`H` and :math:`b`.
"""
return self.Ainv @ self.b

def _induced_Ainv(self) -> randvars.RandomVariable:
r"""Induced belief about the inverse from a belief about the solution.
Computes a consistent belief about the inverse from a belief about the solution.
"""
return randvars.Constant(
linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0]))
)
98 changes: 92 additions & 6 deletions src/probnum/linalg/solvers/policies/_conjugate_gradient.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Policy returning :math:`A`-conjugate actions."""

from typing import Callable, Iterable, Optional, Tuple

import numpy as np

import probnum # pylint: disable="unused-import"
from probnum import linops, randvars

from . import _linear_solver_policy

Expand All @@ -11,21 +14,104 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy):
r"""Policy returning :math:`A`-conjugate actions.
Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`.
Parameters
----------
reorthogonalization_fn_residual
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None`
the residuals are reorthogonalized before the action is computed.
reorthogonalization_fn_action
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None`
the computed action is reorthogonalized.
"""

def __init__(
self,
reorthogonalization_fn_residual: Optional[
Callable[
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
]
] = None,
reorthogonalization_fn_action: Optional[
Callable[
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
]
] = None,
) -> None:
self._reorthogonalization_fn_residual = reorthogonalization_fn_residual
self._reorthogonalization_fn_action = reorthogonalization_fn_action

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> np.ndarray:

action = -solver_state.residual.copy()
residual = solver_state.residual

if self._reorthogonalization_fn_residual is not None and solver_state.step == 0:
solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual]

if solver_state.step > 0:
# Reorthogonalization of the residual
if self._reorthogonalization_fn_residual is not None:
residual, prev_residual = self._reorthogonalized_residuals(
solver_state=solver_state
)
else:
residual = solver_state.residual
prev_residual = solver_state.residuals[solver_state.step - 1]

# A-conjugacy correction (in exact arithmetic)
beta = (
np.linalg.norm(solver_state.residual)
/ np.linalg.norm(solver_state.residuals[solver_state.step - 1])
) ** 2
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
action = -residual + beta * solver_state.actions[solver_state.step - 1]

# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
return self._reorthogonalized_action(
action=action, solver_state=solver_state
)

action += beta * solver_state.actions[solver_state.step - 1]
else:
action = -residual

return action

def _reorthogonalized_residuals(
self,
solver_state: "probnum.linalg.solvers.LinearSolverState",
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute the reorthogonalized residual and its predecessor."""
residual = self._reorthogonalization_fn_residual(
v=solver_state.residual,
orthogonal_basis=np.asarray(
solver_state.cache["reorthogonalized_residuals"]
),
inner_product=None,
)
solver_state.cache["reorthogonalized_residuals"].append(residual)
prev_residual = solver_state.cache["reorthogonalized_residuals"][
solver_state.step - 1
]
return residual, prev_residual

def _reorthogonalized_action(
self,
action: np.ndarray,
solver_state: "probnum.linalg.solvers.LinearSolverState",
) -> np.ndarray:
"""Reorthogonalize the computed action."""
if isinstance(solver_state.prior.x, randvars.Normal):
inprod_matrix = (
solver_state.problem.A
@ solver_state.prior.x.cov
@ solver_state.problem.A.T
)
elif isinstance(solver_state.prior.x, randvars.Constant):
inprod_matrix = solver_state.problem.A

orthogonal_basis = np.asarray(solver_state.actions[0 : solver_state.step])

return self._reorthogonalization_fn_action(
v=action,
orthogonal_basis=orthogonal_basis,
inner_product=inprod_matrix,
)
12 changes: 11 additions & 1 deletion src/probnum/utils/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
"""Utility functions that involve numerical linear algebra."""

from ._cholesky_updates import cholesky_update, tril_to_positive_tril
from ._inner_product import induced_norm, inner_product
from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt

__all__ = ["cholesky_update", "tril_to_positive_tril"]
__all__ = [
"inner_product",
"induced_norm",
"cholesky_update",
"tril_to_positive_tril",
"gram_schmidt",
"modified_gram_schmidt",
"double_gram_schmidt",
]
81 changes: 81 additions & 0 deletions src/probnum/utils/linalg/_inner_product.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Functions defining useful inner products."""
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

import numpy as np

if TYPE_CHECKING:
from probnum import linops


def inner_product(
v: np.ndarray,
w: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
) -> np.ndarray:
r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`.
For n-d arrays the function computes the inner product over the last axis of the
two arrays ``v`` and ``w``.
Parameters
----------
v
First array.
w
Second array.
A
Symmetric positive (semi-)definite matrix defining the geometry.
Returns
-------
inprod :
Inner product(s) of ``v`` and ``w``.
Notes
-----
Note that the broadcasting behavior of :func:`inner_product` differs from :func:`numpy.inner`. Rather it follows the broadcasting rules of :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors.
"""
v_T = v[..., None, :]
w = w[..., :, None]

if A is None:
vw_inprod = v_T @ w
else:
vw_inprod = v_T @ (A @ w)

return np.squeeze(vw_inprod, axis=(-2, -1))


def induced_norm(
v: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
axis: int = -1,
) -> np.ndarray:
r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`.
Computes the induced norm over the given axis of the array.
Parameters
----------
v
Array.
A
Symmetric positive (semi-)definite linear operator defining the geometry.
axis
Specifies the axis along which to compute the vector norms.
Returns
-------
norm :
Vector norm of ``v`` along the given ``axis``.
"""

if A is None:
return np.linalg.norm(v, ord=2, axis=axis, keepdims=False)

v = np.moveaxis(v, axis, -1)
w = np.squeeze(A @ v[..., :, None], axis=-1)

return np.sqrt(np.sum(v * w, axis=-1))
Loading

0 comments on commit 2c5cce8

Please sign in to comment.