-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reorthogonalization for a stable implementation of the probabilistic …
…linear solver (#580) * initial interface * minor doc improvements * first methods added to PLS * pylint errors fixed * . * type hint fixed * corrected generator type hint * made prior part of solve methods * solve method implemented * pylint fix * better documentation * minor doc fixes * solve iterator * initial draft of doctest * doctest added * . * debugging attempt * solvers added as classes * tests on random spd matrices * bugfix solution-based belief update * fixed notebooks * belief fixes for matrix-based solvers * test added and bugfix * doctest fixed * jupyter notebook fix * symmetric belief update fixed * precision in doctest reduced * . * tests fixed * perfect information tests added and docs improved for solver state * orthogonalized residuals for bayescg * test for orthogonalization functions run * test for orthogonalization functions run * improved and generalized implementation of gram-schmidt * CG with reorthogonalization * test for reorthogonalization * . * reorthogonalization only done in the policy for now * extracted reorthogonalized action function * tests run * minor fix to the state * test dependency via fixture fixed * tests fixed + some fixture scope adjustment * more tests for orthogonalization functions * pylint fixes * renamed inner product * extended inner product to arbitrary arrays * orthogonalization tests work * only test double gram schmidt * test for non-Euclidean inner product * fixed induced norm * started vectorizing the reorthogonalization functions * fixed orthogonalization * fixed bug in CG search dirs where residuals were not actually reorthogonalized * simplified CG policy * matmul style broadcasting for inner_product * added broadcasting to orthogonalization
- Loading branch information
1 parent
49f26ae
commit 2c5cce8
Showing
14 changed files
with
690 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,15 @@ | ||
"""Utility functions that involve numerical linear algebra.""" | ||
|
||
from ._cholesky_updates import cholesky_update, tril_to_positive_tril | ||
from ._inner_product import induced_norm, inner_product | ||
from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt | ||
|
||
# Public API of the linear-algebra utilities package. This single assignment
# supersedes the earlier two-entry list (cholesky_update, tril_to_positive_tril),
# which was overwritten immediately and therefore had no effect.
__all__ = [
    "inner_product",
    "induced_norm",
    "cholesky_update",
    "tril_to_positive_tril",
    "gram_schmidt",
    "modified_gram_schmidt",
    "double_gram_schmidt",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Functions defining useful inner products.""" | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Optional, Union | ||
|
||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
from probnum import linops | ||
|
||
|
||
def inner_product(
    v: np.ndarray,
    w: np.ndarray,
    A: Optional[Union[np.ndarray, "linops.LinearOperator"]] = None,
) -> np.ndarray:
    r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`.

    For n-d arrays the function computes the inner product over the last axis of the
    two arrays ``v`` and ``w``.

    Parameters
    ----------
    v
        First array.
    w
        Second array.
    A
        Symmetric positive (semi-)definite matrix defining the geometry. If ``None``,
        the plain product :math:`v^T w` is computed.

    Returns
    -------
    inprod :
        Inner product(s) of ``v`` and ``w``.

    Notes
    -----
    Note that the broadcasting behavior of :func:`inner_product` differs from
    :func:`numpy.inner`. Rather it follows the broadcasting rules of
    :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors.
    """
    # Insert singleton axes so that ``@`` sees stacks of row vectors (..., 1, n)
    # and stacks of column vectors (..., n, 1).
    v_T = v[..., None, :]
    w = w[..., :, None]

    if A is None:
        vw_inprod = v_T @ w
    else:
        vw_inprod = v_T @ (A @ w)

    # Each product is a (..., 1, 1) stack; drop the two singleton axes again.
    return np.squeeze(vw_inprod, axis=(-2, -1))
|
||
|
||
def induced_norm(
    v: np.ndarray,
    A: Optional[Union[np.ndarray, "linops.LinearOperator"]] = None,
    axis: int = -1,
) -> np.ndarray:
    r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`.

    Computes the induced norm over the given axis of the array.

    Parameters
    ----------
    v
        Array.
    A
        Symmetric positive (semi-)definite linear operator defining the geometry.
        If ``None``, the Euclidean 2-norm is computed.
    axis
        Specifies the axis along which to compute the vector norms.

    Returns
    -------
    norm :
        Vector norm of ``v`` along the given ``axis``.
    """

    if A is None:
        return np.linalg.norm(v, ord=2, axis=axis, keepdims=False)

    # Bring the vector axis to the end so that ``A`` acts on it as a stack of
    # column vectors via matmul broadcasting.
    v = np.moveaxis(v, axis, -1)
    w = np.squeeze(A @ v[..., :, None], axis=-1)

    # v^T (A v), evaluated per vector as an elementwise product summed over the
    # (now last) vector axis.
    return np.sqrt(np.sum(v * w, axis=-1))
Oops, something went wrong.