-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reorthogonalization for a stable implementation of the probabilistic linear solver #580
Changes from 70 commits
83c5543
34521cc
21f662b
2887933
7251733
dadf4ce
52a0dae
eef8718
e8e454c
d3d3fef
43e84c7
232dfba
dc1a12b
60e87cd
b0ee8b2
00d9034
64e8108
23e6943
cadf13a
71fb005
f914ab0
3dcbf7d
32478a1
1dcd2c0
58bae7a
c605351
ef68df6
991ffcb
625360c
ddfe933
e475cac
6ef6417
2274cce
d736274
7425fca
d08a4bc
8d76cf5
5ad46d5
e72a69d
8f6ea6f
c4eec0d
e78abfb
b7c3b35
bb60374
79c9942
727c1a5
3086977
3f9876a
4b08072
dce0865
816b2cd
077a2ee
f8cc01b
419b7f5
a8417a6
399b9f4
df2a184
931925d
863d1f2
af05bb2
8fde152
3d547ae
7c75438
50d5cd7
2120c31
fd9c53c
ac7271a
3371bd7
e58c168
f270581
ff353b4
60df256
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
"""State of a probabilistic linear solver.""" | ||
|
||
import dataclasses | ||
from typing import Any, List, Optional, Tuple | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
|
@@ -31,20 +31,23 @@ def __init__( | |
prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief", | ||
rng: Optional[np.random.Generator] = None, | ||
): | ||
self.rng: Optional[np.random.Generator] = rng | ||
self.problem: problems.LinearSystem = problem | ||
|
||
# Belief | ||
self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior | ||
self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior | ||
|
||
self._step: int = 0 | ||
|
||
# Caches | ||
self._actions: List[np.ndarray] = [None] | ||
self._observations: List[Any] = [None] | ||
self._residuals: List[np.ndarray] = [ | ||
self.problem.A @ self.belief.x.mean - self.problem.b, | ||
None, | ||
] | ||
self.rng: Optional[np.random.Generator] = rng | ||
self.cache: Dict[str, Any] = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about we make this private and offer a setter and a getter function for each cache entry. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding all the feedback points for the state, I opened a new issue #622 which collects them. They are all tightly coupled with the design of the linear solver state and will likely generate a large diff unrelated to this PR. Therefore I would like to keep them separate. |
||
|
||
# Solver info | ||
self._step: int = 0 | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(step={self.step})" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
"""Policy returning :math:`A`-conjugate actions.""" | ||
|
||
from typing import Callable, Iterable, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
import probnum # pylint: disable="unused-import" | ||
from probnum import linops, randvars | ||
|
||
from . import _linear_solver_policy | ||
|
||
|
@@ -11,21 +14,104 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy): | |
r"""Policy returning :math:`A`-conjugate actions. | ||
|
||
Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`. | ||
|
||
Parameters | ||
---------- | ||
reorthogonalization_fn_residual | ||
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` | ||
the residuals are reorthogonalized before the action is computed. | ||
reorthogonalization_fn_action | ||
Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` | ||
the computed action is reorthogonalized. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
reorthogonalization_fn_residual: Optional[ | ||
Callable[ | ||
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray | ||
] | ||
] = None, | ||
reorthogonalization_fn_action: Optional[ | ||
Callable[ | ||
[np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray | ||
] | ||
] = None, | ||
) -> None: | ||
self._reorthogonalization_fn_residual = reorthogonalization_fn_residual | ||
self._reorthogonalization_fn_action = reorthogonalization_fn_action | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.LinearSolverState" | ||
) -> np.ndarray: | ||
|
||
action = -solver_state.residual.copy() | ||
residual = solver_state.residual | ||
|
||
if self._reorthogonalization_fn_residual is not None and solver_state.step == 0: | ||
solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual] | ||
Comment on lines
+50
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These "initial checks" can be avoided and made less error-prone by turning |
||
|
||
if solver_state.step > 0: | ||
# Reorthogonalization of the residual | ||
if self._reorthogonalization_fn_residual is not None: | ||
residual, prev_residual = self._reorthogonalized_residuals( | ||
solver_state=solver_state | ||
) | ||
else: | ||
residual = solver_state.residual | ||
prev_residual = solver_state.residuals[solver_state.step - 1] | ||
|
||
# A-conjugacy correction (in exact arithmetic) | ||
beta = ( | ||
np.linalg.norm(solver_state.residual) | ||
/ np.linalg.norm(solver_state.residuals[solver_state.step - 1]) | ||
) ** 2 | ||
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 | ||
action = -residual + beta * solver_state.actions[solver_state.step - 1] | ||
|
||
JonathanWenger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Reorthogonalization of the resulting action | ||
if self._reorthogonalization_fn_action is not None: | ||
return self._reorthogonalized_action( | ||
action=action, solver_state=solver_state | ||
) | ||
|
||
action += beta * solver_state.actions[solver_state.step - 1] | ||
else: | ||
action = -residual | ||
|
||
return action | ||
|
||
def _reorthogonalized_residuals( | ||
self, | ||
solver_state: "probnum.linalg.solvers.LinearSolverState", | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
"""Compute the reorthogonalized residual and its predecessor.""" | ||
residual = self._reorthogonalization_fn_residual( | ||
v=solver_state.residual, | ||
orthogonal_basis=np.asarray( | ||
solver_state.cache["reorthogonalized_residuals"] | ||
), | ||
inner_product=None, | ||
) | ||
solver_state.cache["reorthogonalized_residuals"].append(residual) | ||
prev_residual = solver_state.cache["reorthogonalized_residuals"][ | ||
solver_state.step - 1 | ||
] | ||
return residual, prev_residual | ||
|
||
def _reorthogonalized_action( | ||
self, | ||
action: np.ndarray, | ||
solver_state: "probnum.linalg.solvers.LinearSolverState", | ||
) -> np.ndarray: | ||
"""Reorthogonalize the computed action.""" | ||
if isinstance(solver_state.prior.x, randvars.Normal): | ||
inprod_matrix = ( | ||
solver_state.problem.A | ||
@ solver_state.prior.x.cov | ||
@ solver_state.problem.A.T | ||
) | ||
elif isinstance(solver_state.prior.x, randvars.Constant): | ||
inprod_matrix = solver_state.problem.A | ||
|
||
orthogonal_basis = np.asarray(solver_state.actions[0 : solver_state.step]) | ||
|
||
return self._reorthogonalization_fn_action( | ||
v=action, | ||
orthogonal_basis=orthogonal_basis, | ||
inner_product=inprod_matrix, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,15 @@ | ||
"""Utility functions that involve numerical linear algebra.""" | ||
|
||
from ._cholesky_updates import cholesky_update, tril_to_positive_tril | ||
from ._inner_product import induced_norm, inner_product | ||
from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt | ||
|
||
__all__ = ["cholesky_update", "tril_to_positive_tril"] | ||
__all__ = [ | ||
"inner_product", | ||
"induced_norm", | ||
"cholesky_update", | ||
"tril_to_positive_tril", | ||
"gram_schmidt", | ||
"modified_gram_schmidt", | ||
"double_gram_schmidt", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Functions defining useful inner products.""" | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Optional, Union | ||
|
||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
from probnum import linops | ||
|
||
|
||
JonathanWenger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def inner_product( | ||
v: np.ndarray, | ||
w: np.ndarray, | ||
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, | ||
) -> np.ndarray: | ||
r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`. | ||
|
||
For n-d arrays the function computes the inner product over the last axis of the | ||
two arrays ``v`` and ``w``. | ||
|
||
Parameters | ||
---------- | ||
v | ||
First array. | ||
w | ||
Second array. | ||
A | ||
Symmetric positive (semi-)definite matrix defining the geometry. | ||
|
||
Returns | ||
------- | ||
inprod : | ||
Inner product(s) of ``v`` and ``w``. | ||
|
||
Notes | ||
----- | ||
Note that the broadcasting behavior of :func:`inner_product` differs from :func:`numpy.inner`. Rather it follows the broadcasting rules of :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors. | ||
""" | ||
v_T = v[..., None, :] | ||
w = w[..., :, None] | ||
|
||
if A is None: | ||
vw_inprod = v_T @ w | ||
else: | ||
vw_inprod = v_T @ (A @ w) | ||
|
||
return np.squeeze(vw_inprod, axis=(-2, -1)) | ||
|
||
|
||
def induced_norm( | ||
v: np.ndarray, | ||
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, | ||
axis: int = -1, | ||
) -> np.ndarray: | ||
r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`. | ||
|
||
Computes the induced norm over the given axis of the array. | ||
|
||
Parameters | ||
---------- | ||
v | ||
Array. | ||
A | ||
Symmetric positive (semi-)definite linear operator defining the geometry. | ||
axis | ||
Specifies the axis along which to compute the vector norms. | ||
|
||
Returns | ||
------- | ||
norm : | ||
Vector norm of ``v`` along the given ``axis``. | ||
""" | ||
|
||
if A is None: | ||
return np.linalg.norm(v, ord=2, axis=axis, keepdims=False) | ||
|
||
v = np.moveaxis(v, axis, -1) | ||
w = np.squeeze(A @ v[..., :, None], axis=-1) | ||
|
||
return np.sqrt(np.sum(v * w, axis=-1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want this to be part of the state? Considering the
ad-prototype
I would suggest passing it to thesolve_iterator
and forward it to the component which requires an rng, e.g. the policy.