Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorthogonalization for a stable implementation of the probabilistic linear solver #580

Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
83c5543
initial interface
JonathanWenger Nov 8, 2021
34521cc
minor doc improvements
JonathanWenger Nov 9, 2021
21f662b
first methods added to PLS
JonathanWenger Nov 9, 2021
2887933
pylint errors fixed
JonathanWenger Nov 9, 2021
7251733
.
JonathanWenger Nov 9, 2021
dadf4ce
type hint fixed
JonathanWenger Nov 9, 2021
52a0dae
corrected generator type hint
JonathanWenger Nov 9, 2021
eef8718
made prior part of solve methods
JonathanWenger Nov 9, 2021
e8e454c
solve method implemented
JonathanWenger Nov 9, 2021
d3d3fef
pylint fix
JonathanWenger Nov 9, 2021
43e84c7
better documentation
JonathanWenger Nov 9, 2021
232dfba
minor doc fixes
JonathanWenger Nov 9, 2021
dc1a12b
solve iterator
JonathanWenger Nov 10, 2021
60e87cd
initial draft of doctest
JonathanWenger Nov 10, 2021
b0ee8b2
doctest added
JonathanWenger Nov 10, 2021
00d9034
.
JonathanWenger Nov 10, 2021
64e8108
debugging attempt
JonathanWenger Nov 10, 2021
23e6943
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 17, 2021
cadf13a
solvers added as classes
JonathanWenger Nov 17, 2021
71fb005
tests on random spd matrices
JonathanWenger Nov 17, 2021
f914ab0
bugfix solution-based belief update
JonathanWenger Nov 18, 2021
3dcbf7d
fixed notebooks
JonathanWenger Nov 18, 2021
32478a1
belief fixes for matrix-based soolvers
JonathanWenger Nov 18, 2021
1dcd2c0
test added and bugfix
JonathanWenger Nov 18, 2021
58bae7a
Merge branch 'normal-matmul' into linsolve-pls
JonathanWenger Nov 18, 2021
c605351
doctest fixed
JonathanWenger Nov 18, 2021
ef68df6
jupyter notebook fix
JonathanWenger Nov 18, 2021
991ffcb
symmetric belief update fixed
JonathanWenger Nov 18, 2021
625360c
precision in doctest reduced
JonathanWenger Nov 18, 2021
ddfe933
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 19, 2021
e475cac
.
JonathanWenger Nov 23, 2021
6ef6417
tests fixed
JonathanWenger Nov 23, 2021
2274cce
perfect information tests added and docs improved for solver state
JonathanWenger Nov 24, 2021
d736274
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 24, 2021
7425fca
orthogonalized residuals for bayescg
JonathanWenger Nov 24, 2021
d08a4bc
test for orthogonalization functions run
JonathanWenger Nov 25, 2021
8d76cf5
test for orthogonalization functions run
JonathanWenger Nov 25, 2021
5ad46d5
merged main
JonathanWenger Nov 30, 2021
e72a69d
improved and generalized implementation of gram-schmidt
JonathanWenger Dec 9, 2021
8f6ea6f
CG with reorthogonalization
JonathanWenger Dec 9, 2021
c4eec0d
test for reorthogonalization
JonathanWenger Dec 9, 2021
e78abfb
.
JonathanWenger Dec 9, 2021
b7c3b35
reorthogonalization only done in the policy for now
JonathanWenger Dec 9, 2021
bb60374
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 9, 2021
79c9942
extracted reorthogonalized action function
JonathanWenger Dec 9, 2021
727c1a5
tests run
JonathanWenger Dec 9, 2021
3086977
minor fix to the state
JonathanWenger Dec 9, 2021
3f9876a
test dependency via fixture fixed
JonathanWenger Dec 17, 2021
4b08072
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 17, 2021
dce0865
tests fixed + some fixture scope adjustment
JonathanWenger Dec 17, 2021
816b2cd
more tests for orthogonalization functions
JonathanWenger Dec 17, 2021
077a2ee
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 17, 2021
f8cc01b
pylint fixes
JonathanWenger Dec 17, 2021
419b7f5
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 20, 2021
a8417a6
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 12, 2022
399b9f4
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 28, 2022
df2a184
renamed inner product
JonathanWenger Jan 31, 2022
931925d
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 31, 2022
863d1f2
extended inner product to arbitrary arrays
JonathanWenger Jan 31, 2022
af05bb2
orthogonalization tests work
JonathanWenger Jan 31, 2022
8fde152
only test double gram schmidt
JonathanWenger Jan 31, 2022
3d547ae
test for noneuclidean inner product
JonathanWenger Jan 31, 2022
7c75438
fixed induced norm
JonathanWenger Jan 31, 2022
50d5cd7
started vectorizing the reorthogonalization functions
JonathanWenger Jan 31, 2022
2120c31
fixed orthogonalization
JonathanWenger Jan 31, 2022
fd9c53c
fixed bug in CG search dirs where residuals where not actually reorth…
JonathanWenger Jan 31, 2022
ac7271a
simplified CG policy
JonathanWenger Feb 10, 2022
3371bd7
matmul style broadcasting for inner_product
JonathanWenger Feb 10, 2022
e58c168
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 10, 2022
f270581
added broadcasting to orthogonalization
JonathanWenger Feb 10, 2022
ff353b4
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 11, 2022
60df256
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/probnum/linalg/solvers/_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""State of a probabilistic linear solver."""

import dataclasses
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -31,20 +31,23 @@ def __init__(
prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief",
rng: Optional[np.random.Generator] = None,
):
self.rng: Optional[np.random.Generator] = rng
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this to be part of the state? Considering the ad-prototype I would suggest passing it to the solve_iterator and forward it to the component which requires an rng, e.g. the policy.

self.problem: problems.LinearSystem = problem

# Belief
self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior
self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior

self._step: int = 0

# Caches
self._actions: List[np.ndarray] = [None]
self._observations: List[Any] = [None]
self._residuals: List[np.ndarray] = [
self.problem.A @ self.belief.x.mean - self.problem.b,
None,
]
self.rng: Optional[np.random.Generator] = rng
self.cache: Dict[str, Any] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we make this private and offer a setter and a getter function for each cache entry.
That way, we can ensure that the cache is only ever be written once per iteration (in the getter).
Also, we can prevent modification of previous cache entries.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding all the feedback points for the state, I opened a new issue #622 which collects them. They are all tightly coupled with the design of the linear solver state and will likely generate a large diff unrelated to this PR. Therefore I would like to keep them separate.


# Solver info
self._step: int = 0

def __repr__(self) -> str:
return f"{self.__class__.__name__}(step={self.step})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

# Compute projected residual
action_A = solver_state.action @ solver_state.problem.A
pred = action_A @ solver_state.belief.x.mean
proj_resid = solver_state.observation - pred

# Compute gain and covariance update
cov_xy = solver_state.belief.x.cov @ action_A.T
gram = action_A @ cov_xy + self._noise_var
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
Expand Down
13 changes: 12 additions & 1 deletion src/probnum/linalg/solvers/beliefs/_linear_system_belief.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from typing import Mapping, Optional

from probnum import randvars
from probnum import linops, randvars

# pylint: disable="invalid-name"

Expand Down Expand Up @@ -134,6 +134,8 @@ def A(self) -> randvars.RandomVariable:
@property
def Ainv(self) -> Optional[randvars.RandomVariable]:
"""Belief about the (pseudo-)inverse of the system matrix."""
if self._Ainv is None:
return self._induced_Ainv()
return self._Ainv

@property
Expand All @@ -149,3 +151,12 @@ def _induced_x(self) -> randvars.RandomVariable:
:math:`H` and :math:`b`.
"""
return self.Ainv @ self.b

def _induced_Ainv(self) -> randvars.RandomVariable:
r"""Induced belief about the inverse from a belief about the solution.

Computes a consistent belief about the inverse from a belief about the solution.
"""
return randvars.Constant(
linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0]))
)
98 changes: 92 additions & 6 deletions src/probnum/linalg/solvers/policies/_conjugate_gradient.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Policy returning :math:`A`-conjugate actions."""

from typing import Callable, Iterable, Optional, Tuple

import numpy as np

import probnum # pylint: disable="unused-import"
from probnum import linops, randvars

from . import _linear_solver_policy

Expand All @@ -11,21 +14,104 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy):
r"""Policy returning :math:`A`-conjugate actions.

Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`.

Parameters
----------
reorthogonalization_fn_residual
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None`
the residuals are reorthogonalized before the action is computed.
reorthogonalization_fn_action
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None`
the computed action is reorthogonalized.
"""

def __init__(
self,
reorthogonalization_fn_residual: Optional[
Callable[
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
]
] = None,
reorthogonalization_fn_action: Optional[
Callable[
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
]
] = None,
) -> None:
self._reorthogonalization_fn_residual = reorthogonalization_fn_residual
self._reorthogonalization_fn_action = reorthogonalization_fn_action

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> np.ndarray:

action = -solver_state.residual.copy()
residual = solver_state.residual

if self._reorthogonalization_fn_residual is not None and solver_state.step == 0:
solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual]
Comment on lines +50 to +51
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These "initial checks" can be avoided and made less error-prone by turning state.cache into a collections.defaultdict(list)


if solver_state.step > 0:
# Reorthogonalization of the residual
if self._reorthogonalization_fn_residual is not None:
residual, prev_residual = self._reorthogonalized_residuals(
solver_state=solver_state
)
else:
residual = solver_state.residual
prev_residual = solver_state.residuals[solver_state.step - 1]

# A-conjugacy correction (in exact arithmetic)
beta = (
np.linalg.norm(solver_state.residual)
/ np.linalg.norm(solver_state.residuals[solver_state.step - 1])
) ** 2
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
action = -residual + beta * solver_state.actions[solver_state.step - 1]

JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved
# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
return self._reorthogonalized_action(
action=action, solver_state=solver_state
)

action += beta * solver_state.actions[solver_state.step - 1]
else:
action = -residual

return action

def _reorthogonalized_residuals(
self,
solver_state: "probnum.linalg.solvers.LinearSolverState",
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute the reorthogonalized residual and its predecessor."""
residual = self._reorthogonalization_fn_residual(
v=solver_state.residual,
orthogonal_basis=np.asarray(
solver_state.cache["reorthogonalized_residuals"]
),
inner_product=None,
)
solver_state.cache["reorthogonalized_residuals"].append(residual)
prev_residual = solver_state.cache["reorthogonalized_residuals"][
solver_state.step - 1
]
return residual, prev_residual

def _reorthogonalized_action(
self,
action: np.ndarray,
solver_state: "probnum.linalg.solvers.LinearSolverState",
) -> np.ndarray:
"""Reorthogonalize the computed action."""
if isinstance(solver_state.prior.x, randvars.Normal):
inprod_matrix = (
solver_state.problem.A
@ solver_state.prior.x.cov
@ solver_state.problem.A.T
)
elif isinstance(solver_state.prior.x, randvars.Constant):
inprod_matrix = solver_state.problem.A

orthogonal_basis = np.asarray(solver_state.actions[0 : solver_state.step])

return self._reorthogonalization_fn_action(
v=action,
orthogonal_basis=orthogonal_basis,
inner_product=inprod_matrix,
)
12 changes: 11 additions & 1 deletion src/probnum/utils/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
"""Utility functions that involve numerical linear algebra."""

from ._cholesky_updates import cholesky_update, tril_to_positive_tril
from ._inner_product import induced_norm, inner_product
from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt

__all__ = ["cholesky_update", "tril_to_positive_tril"]
__all__ = [
"inner_product",
"induced_norm",
"cholesky_update",
"tril_to_positive_tril",
"gram_schmidt",
"modified_gram_schmidt",
"double_gram_schmidt",
]
81 changes: 81 additions & 0 deletions src/probnum/utils/linalg/_inner_product.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Functions defining useful inner products."""
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

import numpy as np

if TYPE_CHECKING:
from probnum import linops


JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved
def inner_product(
v: np.ndarray,
w: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
) -> np.ndarray:
r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`.

For n-d arrays the function computes the inner product over the last axis of the
two arrays ``v`` and ``w``.

Parameters
----------
v
First array.
w
Second array.
A
Symmetric positive (semi-)definite matrix defining the geometry.

Returns
-------
inprod :
Inner product(s) of ``v`` and ``w``.

Notes
-----
Note that the broadcasting behavior of :func:`inner_product` differs from :func:`numpy.inner`. Rather it follows the broadcasting rules of :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors.
"""
v_T = v[..., None, :]
w = w[..., :, None]

if A is None:
vw_inprod = v_T @ w
else:
vw_inprod = v_T @ (A @ w)

return np.squeeze(vw_inprod, axis=(-2, -1))


def induced_norm(
v: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
axis: int = -1,
) -> np.ndarray:
r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`.

Computes the induced norm over the given axis of the array.

Parameters
----------
v
Array.
A
Symmetric positive (semi-)definite linear operator defining the geometry.
axis
Specifies the axis along which to compute the vector norms.

Returns
-------
norm :
Vector norm of ``v`` along the given ``axis``.
"""

if A is None:
return np.linalg.norm(v, ord=2, axis=axis, keepdims=False)

v = np.moveaxis(v, axis, -1)
w = np.squeeze(A @ v[..., :, None], axis=-1)

return np.sqrt(np.sum(v * w, axis=-1))
Loading