Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorthogonalization for a stable implementation of the probabilistic linear solver #580

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
83c5543
initial interface
JonathanWenger Nov 8, 2021
34521cc
minor doc improvements
JonathanWenger Nov 9, 2021
21f662b
first methods added to PLS
JonathanWenger Nov 9, 2021
2887933
pylint errors fixed
JonathanWenger Nov 9, 2021
7251733
.
JonathanWenger Nov 9, 2021
dadf4ce
type hint fixed
JonathanWenger Nov 9, 2021
52a0dae
corrected generator type hint
JonathanWenger Nov 9, 2021
eef8718
made prior part of solve methods
JonathanWenger Nov 9, 2021
e8e454c
solve method implemented
JonathanWenger Nov 9, 2021
d3d3fef
pylint fix
JonathanWenger Nov 9, 2021
43e84c7
better documentation
JonathanWenger Nov 9, 2021
232dfba
minor doc fixes
JonathanWenger Nov 9, 2021
dc1a12b
solve iterator
JonathanWenger Nov 10, 2021
60e87cd
initial draft of doctest
JonathanWenger Nov 10, 2021
b0ee8b2
doctest added
JonathanWenger Nov 10, 2021
00d9034
.
JonathanWenger Nov 10, 2021
64e8108
debugging attempt
JonathanWenger Nov 10, 2021
23e6943
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 17, 2021
cadf13a
solvers added as classes
JonathanWenger Nov 17, 2021
71fb005
tests on random spd matrices
JonathanWenger Nov 17, 2021
f914ab0
bugfix solution-based belief update
JonathanWenger Nov 18, 2021
3dcbf7d
fixed notebooks
JonathanWenger Nov 18, 2021
32478a1
belief fixes for matrix-based soolvers
JonathanWenger Nov 18, 2021
1dcd2c0
test added and bugfix
JonathanWenger Nov 18, 2021
58bae7a
Merge branch 'normal-matmul' into linsolve-pls
JonathanWenger Nov 18, 2021
c605351
doctest fixed
JonathanWenger Nov 18, 2021
ef68df6
jupyter notebook fix
JonathanWenger Nov 18, 2021
991ffcb
symmetric belief update fixed
JonathanWenger Nov 18, 2021
625360c
precision in doctest reduced
JonathanWenger Nov 18, 2021
ddfe933
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 19, 2021
e475cac
.
JonathanWenger Nov 23, 2021
6ef6417
tests fixed
JonathanWenger Nov 23, 2021
2274cce
perfect information tests added and docs improved for solver state
JonathanWenger Nov 24, 2021
d736274
Merge branch 'main' into linsolve-pls
JonathanWenger Nov 24, 2021
7425fca
orthogonalized residuals for bayescg
JonathanWenger Nov 24, 2021
d08a4bc
test for orthogonalization functions run
JonathanWenger Nov 25, 2021
8d76cf5
test for orthogonalization functions run
JonathanWenger Nov 25, 2021
5ad46d5
merged main
JonathanWenger Nov 30, 2021
e72a69d
improved and generalized implementation of gram-schmidt
JonathanWenger Dec 9, 2021
8f6ea6f
CG with reorthogonalization
JonathanWenger Dec 9, 2021
c4eec0d
test for reorthogonalization
JonathanWenger Dec 9, 2021
e78abfb
.
JonathanWenger Dec 9, 2021
b7c3b35
reorthogonalization only done in the policy for now
JonathanWenger Dec 9, 2021
bb60374
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 9, 2021
79c9942
extracted reorthogonalized action function
JonathanWenger Dec 9, 2021
727c1a5
tests run
JonathanWenger Dec 9, 2021
3086977
minor fix to the state
JonathanWenger Dec 9, 2021
3f9876a
test dependency via fixture fixed
JonathanWenger Dec 17, 2021
4b08072
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 17, 2021
dce0865
tests fixed + some fixture scope adjustment
JonathanWenger Dec 17, 2021
816b2cd
more tests for orthogonalization functions
JonathanWenger Dec 17, 2021
077a2ee
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 17, 2021
f8cc01b
pylint fixes
JonathanWenger Dec 17, 2021
419b7f5
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Dec 20, 2021
a8417a6
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 12, 2022
399b9f4
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 28, 2022
df2a184
renamed inner product
JonathanWenger Jan 31, 2022
931925d
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Jan 31, 2022
863d1f2
extended inner product to arbitrary arrays
JonathanWenger Jan 31, 2022
af05bb2
orthogonalization tests work
JonathanWenger Jan 31, 2022
8fde152
only test double gram schmidt
JonathanWenger Jan 31, 2022
3d547ae
test for noneuclidean inner product
JonathanWenger Jan 31, 2022
7c75438
fixed induced norm
JonathanWenger Jan 31, 2022
50d5cd7
started vectorizing the reorthogonalization functions
JonathanWenger Jan 31, 2022
2120c31
fixed orthogonalization
JonathanWenger Jan 31, 2022
fd9c53c
fixed bug in CG search dirs where residuals where not actually reorth…
JonathanWenger Jan 31, 2022
ac7271a
simplified CG policy
JonathanWenger Feb 10, 2022
3371bd7
matmul style broadcasting for inner_product
JonathanWenger Feb 10, 2022
e58c168
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 10, 2022
f270581
added broadcasting to orthogonalization
JonathanWenger Feb 10, 2022
ff353b4
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 11, 2022
60df256
Merge branch 'main' into pls-reorthogonalization
JonathanWenger Feb 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/probnum/linalg/solvers/policies/_conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> np.ndarray:

action = -solver_state.residual.copy()
residual = solver_state.residual.copy()
JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved

if self._reorthogonalization_fn_residual is not None and solver_state.step == 0:
solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual]
Expand All @@ -62,15 +62,17 @@ def __call__(

# A-conjugacy correction (in exact arithmetic)
beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
action += beta * solver_state.actions[solver_state.step - 1]
action = -residual + beta * solver_state.actions[solver_state.step - 1]

JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved
# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
return self._reorthogonalized_action(
action=action, solver_state=solver_state
)

return action
return action

return -residual
JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved

def _reorthogonalized_residuals(
self,
Expand All @@ -82,7 +84,7 @@ def _reorthogonalized_residuals(
orthogonal_basis=np.asarray(
solver_state.cache["reorthogonalized_residuals"]
),
inprod=None,
inner_product=None,
)
solver_state.cache["reorthogonalized_residuals"].append(residual)
prev_residual = solver_state.cache["reorthogonalized_residuals"][
Expand Down Expand Up @@ -110,5 +112,5 @@ def _reorthogonalized_action(
return self._reorthogonalization_fn_action(
v=action,
orthogonal_basis=orthogonal_basis,
inprod=inprod_matrix,
inner_product=inprod_matrix,
)
6 changes: 3 additions & 3 deletions src/probnum/utils/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Utility functions that involve numerical linear algebra."""

from ._cholesky_updates import cholesky_update, tril_to_positive_tril
from ._inner_product import euclidean_inprod, euclidean_norm
from ._inner_product import induced_norm, inner_product
from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt

__all__ = [
"euclidean_inprod",
"euclidean_norm",
"inner_product",
"induced_norm",
"cholesky_update",
"tril_to_positive_tril",
"gram_schmidt",
Expand Down
53 changes: 32 additions & 21 deletions src/probnum/utils/linalg/_inner_product.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,76 @@
"""Functions defining useful inner products."""
from __future__ import annotations

from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import numpy as np

from probnum import linops
if TYPE_CHECKING:
from probnum import linops


JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved
def euclidean_inprod(
def inner_product(
v: np.ndarray,
w: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
) -> np.ndarray:
r"""(Modified) Euclidean inner product :math:`\langle v, w \rangle_A := v^T A w`.
r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`.

For arrays the function computes the inner product over the last axes of the
two arrays ``v`` and ``w``.

Parameters
----------
v
First vector.
First array.
w
Second vector.
Second array.
A
Symmetric positive (semi-)definite matrix defining the geometry.

Returns
-------
inprod
Inner product.
inprod :
*shape=(\*v.shape[:-1], \*w.shape[:-1])* -- Inner product of ``v`` and ``w``. If they are both 1-D arrays then a scalar is returned; otherwise an array is returned.
JonathanWenger marked this conversation as resolved.
Show resolved Hide resolved
"""

v_T = v[..., None, :]
w = w[..., :, None]

if A is None:
vw_inprod = v_T @ w
vw_inprod = np.dot(v, w)
else:
vw_inprod = v_T @ (A @ w)
vw_inprod = np.dot(v, A @ w)

return np.squeeze(vw_inprod, axis=(-2, -1))
return np.squeeze(vw_inprod, axis=(-1))


def euclidean_norm(
def induced_norm(
v: np.ndarray,
A: Optional[Union[np.ndarray, linops.LinearOperator]] = None,
axis: int = -1,
) -> np.ndarray:
r"""(Modified) Euclidean norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`.
r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`.

Computes the induced norm over the given axis of the array.

Parameters
----------
v
Vector.
Array.
A
Symmetric positive (semi-)definite matrix defining the geometry.
Symmetric positive (semi-)definite linear operator defining the geometry.
axis
Specifies the axis along which to compute the vector norms.

Returns
-------
norm
Vector norm.
norm :
Vector norm of ``v`` along the given ``axis``.
"""

if A is None:
return np.linalg.norm(v, ord=2, axis=-1, keepdims=False)
return np.linalg.norm(v, ord=2, axis=axis, keepdims=False)

v = np.moveaxis(v, axis, -1)
w = np.squeeze(A @ v[..., :, None], axis=-1)

return np.sqrt(euclidean_inprod(v, v, A))
return np.sqrt(np.sum(v * w, axis=-1))
86 changes: 53 additions & 33 deletions src/probnum/utils/linalg/_orthogonalize.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Orthogonalization of vectors."""

from functools import partial
from typing import Callable, Iterable, Optional, Union

import numpy as np

from probnum import linops

from ._inner_product import euclidean_inprod, euclidean_norm
from ._inner_product import induced_norm
from ._inner_product import inner_product as inner_product_fn


def gram_schmidt(
v: np.ndarray,
orthogonal_basis: Iterable[np.ndarray],
inprod: Optional[
inner_product: Optional[
Union[
np.ndarray,
linops.LinearOperator,
Expand All @@ -32,24 +34,27 @@ def gram_schmidt(
Vector to orthogonalize.
orthogonal_basis
Orthogonal basis.
inprod
Inner product.
inner_product
Inner product defining orthogonality. Can be either a :class`numpy.ndarray` or a :class:`Callable`
defining the inner product. Defaults to the euclidean inner product.
normalize
Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`.

Returns
-------
v_orth
v_orth :
Orthogonalized vector.
"""
if inprod is None:
inprod_fn = euclidean_inprod
norm_fn = euclidean_norm
elif isinstance(inprod, (np.ndarray, linops.LinearOperator)):
inprod_fn = lambda v, w: euclidean_inprod(v, w, A=inprod)
norm_fn = lambda v: euclidean_norm(v, A=inprod)
orthogonal_basis = np.atleast_2d(orthogonal_basis)

if inner_product is None:
inprod_fn = inner_product_fn
norm_fn = partial(induced_norm, axis=-1)
elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)):
inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product)
norm_fn = lambda v: induced_norm(v, A=inner_product, axis=-1)
else:
inprod_fn = inprod
inprod_fn = inner_product
norm_fn = lambda v: np.sqrt(inprod_fn(v, v))

v_orth = v.copy()
Expand All @@ -66,7 +71,7 @@ def gram_schmidt(
def modified_gram_schmidt(
v: np.ndarray,
orthogonal_basis: Iterable[np.ndarray],
inprod: Optional[
inner_product: Optional[
Union[
np.ndarray,
linops.LinearOperator,
Expand All @@ -86,24 +91,27 @@ def modified_gram_schmidt(
Vector to orthogonalize.
orthogonal_basis
Orthogonal basis.
inprod
Inner product.
inner_product
Inner product defining orthogonality. Can be either a :class:`numpy.ndarray` or a :class:`Callable`
defining the inner product. Defaults to the euclidean inner product.
normalize
Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`.

Returns
-------
v_orth
v_orth :
Orthogonalized vector.
"""
if inprod is None:
inprod_fn = euclidean_inprod
norm_fn = euclidean_norm
elif isinstance(inprod, (np.ndarray, linops.LinearOperator)):
inprod_fn = lambda v, w: euclidean_inprod(v, w, A=inprod)
norm_fn = lambda v: euclidean_norm(v, A=inprod)
orthogonal_basis = np.atleast_2d(orthogonal_basis)

if inner_product is None:
inprod_fn = inner_product_fn
norm_fn = induced_norm
elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)):
inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product)
norm_fn = lambda v: induced_norm(v, A=inner_product)
else:
inprod_fn = inprod
inprod_fn = inner_product
norm_fn = lambda v: np.sqrt(inprod_fn(v, v))

v_orth = v.copy()
Expand All @@ -120,34 +128,40 @@ def modified_gram_schmidt(
def double_gram_schmidt(
v: np.ndarray,
orthogonal_basis: Iterable[np.ndarray],
inprod: Optional[
inner_product: Optional[
Union[
np.ndarray,
linops.LinearOperator,
Callable[[np.ndarray, np.ndarray], np.ndarray],
]
] = None,
normalize: bool = False,
gram_schmidt_fn: Callable = modified_gram_schmidt,
) -> np.ndarray:
r"""Perform the modified Gram-Schmidt process twice.
r"""Perform the (modified) Gram-Schmidt process twice.

Computes a vector :math:`v'` such that :math:`\langle v', b_i \rangle = 0` for
all basis vectors :math:`b_i \in B` in the orthogonal basis. This performs the modified Gram-Schmidt orthogonalization process twice, which is generally more stable than just reorthogonalizing once. [1]_ [2]_
all basis vectors :math:`b_i \in B` in the orthogonal basis. This performs the
(modified) Gram-Schmidt orthogonalization process twice, which is generally more
stable than just reorthogonalizing once. [1]_ [2]_

Parameters
----------
v
Vector to orthogonalize.
orthogonal_basis
Orthogonal basis.
inprod
Inner product.
inner_product
Inner product defining orthogonality. Can be either a :class:`numpy.ndarray` or a :class:`Callable`
defining the inner product. Defaults to the euclidean inner product.
normalize
Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`.
gram_schmidt_fn
Gram-Schmidt process to use. One of :meth:`gram_schmidt` or :meth:`modified_gram_schmidt`.

Returns
-------
v_orth
v_orth :
Orthogonalized vector.

References
Expand All @@ -157,9 +171,15 @@ def double_gram_schmidt(
.. [2] L. Giraud, J. Langou, and M. Rozloznik, The loss of orthogonality in the
Gram-Schmidt orthogonalization process, Comput. Math. Appl., 50 (2005)
"""
v_orth = modified_gram_schmidt(
v=v, orthogonal_basis=orthogonal_basis, inprod=inprod, normalize=normalize
v_orth = gram_schmidt_fn(
v=v,
orthogonal_basis=orthogonal_basis,
inner_product=inner_product,
normalize=normalize,
)
return modified_gram_schmidt(
v=v_orth, orthogonal_basis=orthogonal_basis, inprod=inprod, normalize=normalize
return gram_schmidt_fn(
v=v_orth,
orthogonal_basis=orthogonal_basis,
inner_product=inner_product,
normalize=normalize,
)
Loading