From 2c5cce865335dc1f56fb4aed3b0bb25558961f35 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Fri, 11 Feb 2022 07:27:27 -0500 Subject: [PATCH] Reorthogonalization for a stable implementation of the probabilistic linear solver (#580) * initial interface * minor doc improvements * first methods added to PLS * pylint errors fixed * . * type hint fixed * corrected generator type hint * made prior part of solve methods * solve method implemented * pylint fix * better documentation * minor doc fixes * solve iterator * initial draft of doctest * doctest added * . * debugging attempt * solvers added as classes * tests on random spd matrices * bugfix solution-based belief update * fixed notebooks * belief fixes for matrix-based soolvers * test added and bugfix * doctest fixed * jupyter notebook fix * symmetric belief update fixed * precision in doctest reduced * . * tests fixed * perfect information tests added and docs improved for solver state * orthogonalized residuals for bayescg * test for orthogonalization functions run * test for orthogonalization functions run * improved and generalized implementation of gram-schmidt * CG with reorthogonalization * test for reorthogonalization * . * reorthogonalization only done in the policy for now * extracted reorthogonalized action function * tests run * minor fix to the state * test dependency via fixture fixed * tests fixed + some fixture scope adjustment * more tests for orthogonalization functions * pylint fixes * renamed inner product * extended inner product to arbitrary arrays * orthogonalization tests work * only test double gram schmidt * test for noneuclidean inner product * fixed induced norm * started vectorizing the reorthogonalization functions * fixed orthogonalization * fixed bug in CG search dirs where residuals where not actually reorthogonalized * simplified CG policy * matmul style broadcasting for inner_product * added broadcasting to orthogonalization --- src/probnum/linalg/solvers/_state.py | 13 +- .../_solution_based_proj_rhs_belief_update.py | 3 + .../solvers/beliefs/_linear_system_belief.py | 13 +- .../solvers/policies/_conjugate_gradient.py | 98 +++++++++- src/probnum/utils/linalg/__init__.py | 12 +- src/probnum/utils/linalg/_inner_product.py | 81 ++++++++ src/probnum/utils/linalg/_orthogonalize.py | 185 ++++++++++++++++++ tests/test_linalg/cases/linear_systems.py | 7 +- .../test_solvers/cases/policies.py | 13 ++ tests/test_linalg/test_solvers/conftest.py | 6 +- .../test_policies/test_conjugate_gradient.py | 36 ++-- .../test_linear_solver_policy.py | 4 + .../test_linalg/test_inner_product.py | 87 ++++++++ .../test_linalg/test_orthogonalize.py | 161 +++++++++++++++ 14 files changed, 690 insertions(+), 29 deletions(-) create mode 100644 src/probnum/utils/linalg/_inner_product.py create mode 100644 src/probnum/utils/linalg/_orthogonalize.py create mode 100644 tests/test_utils/test_linalg/test_inner_product.py create mode 100644 tests/test_utils/test_linalg/test_orthogonalize.py diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 5c5738d02..cd96958cc 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -1,7 +1,7 @@ """State of a probabilistic linear solver.""" import dataclasses -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -31,20 +31,23 @@ def __init__( prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief", rng: Optional[np.random.Generator] = None, ): + self.rng: Optional[np.random.Generator] = rng self.problem: problems.LinearSystem = problem + # Belief self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior - self._step: int = 0 - + # Caches self._actions: List[np.ndarray] = [None] self._observations: List[Any] = [None] self._residuals: List[np.ndarray] = [ self.problem.A @ self.belief.x.mean - self.problem.b, - None, ] - self.rng: Optional[np.random.Generator] = rng + self.cache: Dict[str, Any] = {} + + # Solver info + self._step: int = 0 def __repr__(self) -> str: return f"{self.__class__.__name__}(step={self.step})" diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py index e926a88a2..1d6f4d337 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py @@ -44,9 +44,12 @@ def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> LinearSystemBelief: + # Compute projected residual action_A = solver_state.action @ solver_state.problem.A pred = action_A @ solver_state.belief.x.mean proj_resid = solver_state.observation - pred + + # Compute gain and covariance update cov_xy = solver_state.belief.x.cov @ action_A.T gram = action_A @ cov_xy + self._noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index a4b240c35..8fa830032 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -7,7 +7,7 @@ from functools import cached_property from typing import Mapping, Optional -from probnum import randvars +from probnum import linops, randvars # pylint: disable="invalid-name" @@ -134,6 +134,8 @@ def A(self) -> randvars.RandomVariable: @property def Ainv(self) -> Optional[randvars.RandomVariable]: """Belief about the (pseudo-)inverse of the system matrix.""" + if self._Ainv is None: + return self._induced_Ainv() return self._Ainv @property @@ -149,3 +151,12 @@ def _induced_x(self) -> randvars.RandomVariable: :math:`H` and :math:`b`. """ return self.Ainv @ self.b + + def _induced_Ainv(self) -> randvars.RandomVariable: + r"""Induced belief about the inverse from a belief about the solution. + + Computes a consistent belief about the inverse from a belief about the solution. + """ + return randvars.Constant( + linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0])) + ) diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 3069f9b3c..019e52e2c 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -1,8 +1,11 @@ """Policy returning :math:`A`-conjugate actions.""" +from typing import Callable, Iterable, Optional, Tuple + import numpy as np import probnum # pylint: disable="unused-import" +from probnum import linops, randvars from . import _linear_solver_policy @@ -11,21 +14,104 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy): r"""Policy returning :math:`A`-conjugate actions. Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`. + + Parameters + ---------- + reorthogonalization_fn_residual + Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` + the residuals are reorthogonalized before the action is computed. + reorthogonalization_fn_action + Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` + the computed action is reorthogonalized. """ + def __init__( + self, + reorthogonalization_fn_residual: Optional[ + Callable[ + [np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray + ] + ] = None, + reorthogonalization_fn_action: Optional[ + Callable[ + [np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray + ] + ] = None, + ) -> None: + self._reorthogonalization_fn_residual = reorthogonalization_fn_residual + self._reorthogonalization_fn_action = reorthogonalization_fn_action + def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: - action = -solver_state.residual.copy() + residual = solver_state.residual + + if self._reorthogonalization_fn_residual is not None and solver_state.step == 0: + solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual] if solver_state.step > 0: + # Reorthogonalization of the residual + if self._reorthogonalization_fn_residual is not None: + residual, prev_residual = self._reorthogonalized_residuals( + solver_state=solver_state + ) + else: + residual = solver_state.residual + prev_residual = solver_state.residuals[solver_state.step - 1] + # A-conjugacy correction (in exact arithmetic) - beta = ( - np.linalg.norm(solver_state.residual) - / np.linalg.norm(solver_state.residuals[solver_state.step - 1]) - ) ** 2 + beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 + action = -residual + beta * solver_state.actions[solver_state.step - 1] + + # Reorthogonalization of the resulting action + if self._reorthogonalization_fn_action is not None: + return self._reorthogonalized_action( + action=action, solver_state=solver_state + ) - action += beta * solver_state.actions[solver_state.step - 1] + else: + action = -residual return action + + def _reorthogonalized_residuals( + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + ) -> Tuple[np.ndarray, np.ndarray]: + """Compute the reorthogonalized residual and its predecessor.""" + residual = self._reorthogonalization_fn_residual( + v=solver_state.residual, + orthogonal_basis=np.asarray( + solver_state.cache["reorthogonalized_residuals"] + ), + inner_product=None, + ) + solver_state.cache["reorthogonalized_residuals"].append(residual) + prev_residual = solver_state.cache["reorthogonalized_residuals"][ + solver_state.step - 1 + ] + return residual, prev_residual + + def _reorthogonalized_action( + self, + action: np.ndarray, + solver_state: "probnum.linalg.solvers.LinearSolverState", + ) -> np.ndarray: + """Reorthogonalize the computed action.""" + if isinstance(solver_state.prior.x, randvars.Normal): + inprod_matrix = ( + solver_state.problem.A + @ solver_state.prior.x.cov + @ solver_state.problem.A.T + ) + elif isinstance(solver_state.prior.x, randvars.Constant): + inprod_matrix = solver_state.problem.A + + orthogonal_basis = np.asarray(solver_state.actions[0 : solver_state.step]) + + return self._reorthogonalization_fn_action( + v=action, + orthogonal_basis=orthogonal_basis, + inner_product=inprod_matrix, + ) diff --git a/src/probnum/utils/linalg/__init__.py b/src/probnum/utils/linalg/__init__.py index 8daa20e2c..a817cdd0f 100644 --- a/src/probnum/utils/linalg/__init__.py +++ b/src/probnum/utils/linalg/__init__.py @@ -1,5 +1,15 @@ """Utility functions that involve numerical linear algebra.""" from ._cholesky_updates import cholesky_update, tril_to_positive_tril +from ._inner_product import induced_norm, inner_product +from ._orthogonalize import double_gram_schmidt, gram_schmidt, modified_gram_schmidt -__all__ = ["cholesky_update", "tril_to_positive_tril"] +__all__ = [ + "inner_product", + "induced_norm", + "cholesky_update", + "tril_to_positive_tril", + "gram_schmidt", + "modified_gram_schmidt", + "double_gram_schmidt", +] diff --git a/src/probnum/utils/linalg/_inner_product.py b/src/probnum/utils/linalg/_inner_product.py new file mode 100644 index 000000000..0593ab198 --- /dev/null +++ b/src/probnum/utils/linalg/_inner_product.py @@ -0,0 +1,81 @@ +"""Functions defining useful inner products.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np + +if TYPE_CHECKING: + from probnum import linops + + +def inner_product( + v: np.ndarray, + w: np.ndarray, + A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, +) -> np.ndarray: + r"""Inner product :math:`\langle v, w \rangle_A := v^T A w`. + + For n-d arrays the function computes the inner product over the last axis of the + two arrays ``v`` and ``w``. + + Parameters + ---------- + v + First array. + w + Second array. + A + Symmetric positive (semi-)definite matrix defining the geometry. + + Returns + ------- + inprod : + Inner product(s) of ``v`` and ``w``. + + Notes + ----- + Note that the broadcasting behavior of :func:`inner_product` differs from :func:`numpy.inner`. Rather it follows the broadcasting rules of :func:`numpy.matmul` in that n-d arrays are treated as stacks of vectors. + """ + v_T = v[..., None, :] + w = w[..., :, None] + + if A is None: + vw_inprod = v_T @ w + else: + vw_inprod = v_T @ (A @ w) + + return np.squeeze(vw_inprod, axis=(-2, -1)) + + +def induced_norm( + v: np.ndarray, + A: Optional[Union[np.ndarray, linops.LinearOperator]] = None, + axis: int = -1, +) -> np.ndarray: + r"""Induced norm :math:`\lVert v \rVert_A := \sqrt{v^T A v}`. + + Computes the induced norm over the given axis of the array. + + Parameters + ---------- + v + Array. + A + Symmetric positive (semi-)definite linear operator defining the geometry. + axis + Specifies the axis along which to compute the vector norms. + + Returns + ------- + norm : + Vector norm of ``v`` along the given ``axis``. + """ + + if A is None: + return np.linalg.norm(v, ord=2, axis=axis, keepdims=False) + + v = np.moveaxis(v, axis, -1) + w = np.squeeze(A @ v[..., :, None], axis=-1) + + return np.sqrt(np.sum(v * w, axis=-1)) diff --git a/src/probnum/utils/linalg/_orthogonalize.py b/src/probnum/utils/linalg/_orthogonalize.py new file mode 100644 index 000000000..4c499b6a7 --- /dev/null +++ b/src/probnum/utils/linalg/_orthogonalize.py @@ -0,0 +1,185 @@ +"""Orthogonalization of vectors.""" + +from functools import partial +from typing import Callable, Iterable, Optional, Union + +import numpy as np + +from probnum import linops + +from ._inner_product import induced_norm +from ._inner_product import inner_product as inner_product_fn + + +def gram_schmidt( + v: np.ndarray, + orthogonal_basis: Iterable[np.ndarray], + inner_product: Optional[ + Union[ + np.ndarray, + linops.LinearOperator, + Callable[[np.ndarray, np.ndarray], np.ndarray], + ] + ] = None, + normalize: bool = False, +) -> np.ndarray: + r"""Orthogonalize a vector with respect to an orthogonal basis and inner product. + + Computes a vector :math:`v'` such that :math:`\langle v', b_i \rangle = 0` for + all basis vectors :math:`b_i \in B` in the orthogonal basis. + + Parameters + ---------- + v + Vector (or stack of vectors) to orthogonalize against ``orthogonal_basis``. + orthogonal_basis + Orthogonal basis. + inner_product + Inner product defining orthogonality. Can be either a :class`numpy.ndarray` or a :class:`Callable` + defining the inner product. Defaults to the euclidean inner product. + normalize + Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`. + + Returns + ------- + v_orth : + Orthogonalized vector. + """ + orthogonal_basis = np.atleast_2d(orthogonal_basis) + + if inner_product is None: + inprod_fn = inner_product_fn + norm_fn = partial(induced_norm, axis=-1) + elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)): + inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product) + norm_fn = lambda v: induced_norm(v, A=inner_product, axis=-1) + else: + inprod_fn = inner_product + norm_fn = lambda v: np.sqrt(inprod_fn(v, v)) + + v_orth = v.copy() + + for u in orthogonal_basis: + v_orth -= (inprod_fn(u, v)[..., None] / inprod_fn(u, u)) * u + + if normalize: + v_orth /= norm_fn(v_orth)[..., None] + + return v_orth + + +def modified_gram_schmidt( + v: np.ndarray, + orthogonal_basis: Iterable[np.ndarray], + inner_product: Optional[ + Union[ + np.ndarray, + linops.LinearOperator, + Callable[[np.ndarray, np.ndarray], np.ndarray], + ] + ] = None, + normalize: bool = False, +) -> np.ndarray: + r"""Stabilized Gram-Schmidt process. + + Computes a vector :math:`v'` such that :math:`\langle v', b_i \rangle = 0` for + all basis vectors :math:`b_i \in B` in the orthogonal basis in a numerically stable fashion. + + Parameters + ---------- + v + Vector (or stack of vectors) to orthogonalize against ``orthogonal_basis``. + orthogonal_basis + Orthogonal basis. + inner_product + Inner product defining orthogonality. Can be either a :class:`numpy.ndarray` or a :class:`Callable` + defining the inner product. Defaults to the euclidean inner product. + normalize + Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`. + + Returns + ------- + v_orth : + Orthogonalized vector. + """ + orthogonal_basis = np.atleast_2d(orthogonal_basis) + + if inner_product is None: + inprod_fn = inner_product_fn + norm_fn = induced_norm + elif isinstance(inner_product, (np.ndarray, linops.LinearOperator)): + inprod_fn = lambda v, w: inner_product_fn(v, w, A=inner_product) + norm_fn = lambda v: induced_norm(v, A=inner_product) + else: + inprod_fn = inner_product + norm_fn = lambda v: np.sqrt(inprod_fn(v, v)) + + v_orth = v.copy() + + for u in orthogonal_basis: + v_orth -= (inprod_fn(u, v_orth)[..., None] / inprod_fn(u, u)) * u + + if normalize: + v_orth /= norm_fn(v_orth)[..., None] + + return v_orth + + +def double_gram_schmidt( + v: np.ndarray, + orthogonal_basis: Iterable[np.ndarray], + inner_product: Optional[ + Union[ + np.ndarray, + linops.LinearOperator, + Callable[[np.ndarray, np.ndarray], np.ndarray], + ] + ] = None, + normalize: bool = False, + gram_schmidt_fn: Callable = modified_gram_schmidt, +) -> np.ndarray: + r"""Perform the (modified) Gram-Schmidt process twice. + + Computes a vector :math:`v'` such that :math:`\langle v', b_i \rangle = 0` for + all basis vectors :math:`b_i \in B` in the orthogonal basis. This performs the + (modified) Gram-Schmidt orthogonalization process twice, which is generally more + stable than just reorthogonalizing once. [1]_ [2]_ + + Parameters + ---------- + v + Vector (or stack of vectors) to orthogonalize against ``orthogonal_basis``. + orthogonal_basis + Orthogonal basis. + inner_product + Inner product defining orthogonality. Can be either a :class:`numpy.ndarray` or a :class:`Callable` + defining the inner product. Defaults to the euclidean inner product. + normalize + Normalize the output vector, s.t. :math:`\langle v', v' \rangle = 1`. + gram_schmidt_fn + Gram-Schmidt process to use. One of :meth:`gram_schmidt` or :meth:`modified_gram_schmidt`. + + Returns + ------- + v_orth : + Orthogonalized vector. + + References + ---------- + .. [1] L. Giraud, J. Langou, M. Rozloznik, and J. van den Eshof, Rounding error + analysis of the classical Gram-Schmidt orthogonalization process, Numer. Math., 101 (2005), pp. 87–100 + .. [2] L. Giraud, J. Langou, and M. Rozloznik, The loss of orthogonality in the + Gram-Schmidt orthogonalization process, Comput. Math. Appl., 50 (2005) + """ + v_orth = gram_schmidt_fn( + v=v, + orthogonal_basis=orthogonal_basis, + inner_product=inner_product, + normalize=normalize, + ) + return gram_schmidt_fn( + v=v_orth, + orthogonal_basis=orthogonal_basis, + inner_product=inner_product, + normalize=normalize, + ) diff --git a/tests/test_linalg/cases/linear_systems.py b/tests/test_linalg/cases/linear_systems.py index 55c3bdc98..4554ae717 100644 --- a/tests/test_linalg/cases/linear_systems.py +++ b/tests/test_linalg/cases/linear_systems.py @@ -12,7 +12,7 @@ cases_matrices = ".matrices" -@pytest_cases.parametrize_with_cases("matrix", cases=cases_matrices) +@pytest_cases.parametrize_with_cases("matrix", cases=cases_matrices, scope="module") def case_linsys( matrix: Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], rng: np.random.Generator, @@ -22,7 +22,10 @@ def case_linsys( @pytest_cases.parametrize_with_cases( - "spd_matrix", cases=cases_matrices, has_tag=["symmetric", "positive_definite"] + "spd_matrix", + cases=cases_matrices, + has_tag=["symmetric", "positive_definite"], + scope="module", ) def case_spd_linsys( spd_matrix: Union[np.ndarray, scipy.sparse.spmatrix, linops.LinearOperator], diff --git a/tests/test_linalg/test_solvers/cases/policies.py b/tests/test_linalg/test_solvers/cases/policies.py index ac58cf89e..70b4ed5cb 100644 --- a/tests/test_linalg/test_solvers/cases/policies.py +++ b/tests/test_linalg/test_solvers/cases/policies.py @@ -2,12 +2,25 @@ from pytest_cases import case from probnum.linalg.solvers import policies +from probnum.utils.linalg import double_gram_schmidt, modified_gram_schmidt def case_conjugate_gradient(): return policies.ConjugateGradientPolicy() +def case_conjugate_gradient_reorthogonalized_residuals(): + return policies.ConjugateGradientPolicy( + reorthogonalization_fn_residual=double_gram_schmidt + ) + + +def case_conjugate_gradient_reorthogonalized_actions(): + return policies.ConjugateGradientPolicy( + reorthogonalization_fn_action=modified_gram_schmidt + ) + + @case(tags=["random"]) def case_random_unit_vector(): return policies.RandomUnitVectorPolicy() diff --git a/tests/test_linalg/test_solvers/conftest.py b/tests/test_linalg/test_solvers/conftest.py index c371e4451..80cf5e88d 100644 --- a/tests/test_linalg/test_solvers/conftest.py +++ b/tests/test_linalg/test_solvers/conftest.py @@ -2,21 +2,21 @@ from pytest_cases import fixture, parametrize -@fixture() +@fixture(scope="module") @parametrize("m", [1, 10, 100]) def nrows(m: int) -> int: """Number of rows of a matrix.""" return m -@fixture() +@fixture(scope="module") @parametrize("n", [1, 10, 100]) def ncols(n: int) -> int: """Number of columns of a matrix.""" return n -@fixture +@fixture(scope="module") @parametrize("k", [1, 2, 100]) def nrhs(k: int) -> int: """Number of right-hand-sides of a linear system.""" diff --git a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py index 4a39495ad..96cc5917e 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py +++ b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py @@ -1,4 +1,5 @@ """Tests for a policy returning random unit vectors.""" +import copy import pathlib import numpy as np @@ -18,6 +19,7 @@ def test_initial_action_is_negative_gradient( policy: policies.ConjugateGradientPolicy, state: LinearSolverState ): assert state.step == 0 + state = copy.deepcopy(state) action = policy(state) np.testing.assert_allclose(action, -state.residual) @@ -29,33 +31,45 @@ def test_conjugate_actions( ): """Tests whether actions generated by the policy are A-conjugate via a naive CG implementation.""" - A = state.problem.A + + solver_state = copy.deepcopy(state) + A = solver_state.problem.A for _ in range(A.shape[1]): # Action - s = policy(state) - state.action = s + s = policy(solver_state) + solver_state.action = s # Observation y = A @ s - state.observation = y + solver_state.observation = y # Residual - r = state.residual + r = solver_state.residual # Step size alpha = np.linalg.norm(r) ** 2 / np.inner(s, y) # Solution update - x = state.belief.x.mean - state.belief.x = randvars.Constant(x + alpha * s) + x = solver_state.belief.x.mean + solver_state.belief.x = randvars.Constant(x + alpha * s) + + solver_state.next_step() + + actions = np.array(solver_state.actions[:-1]).T + innerprods_actions = actions.T @ A @ actions - state.next_step() + np.testing.assert_allclose( + innerprods_actions, np.diag(np.diag(innerprods_actions)), atol=1e-7, rtol=1e7 + ) - actions = np.array(state.actions[:-1]).T - innerprods = actions.T @ A @ actions + residuals = np.array(solver_state.residuals[:-1]).T + innerprods_residuals = residuals.T @ residuals np.testing.assert_allclose( - innerprods, np.diag(np.diag(innerprods)), atol=1e-7, rtol=1e7 + innerprods_residuals, + np.diag(np.diag(innerprods_residuals)), + atol=1e-7, + rtol=1e7, ) diff --git a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py index e2b0a4f11..4df954456 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py +++ b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py @@ -1,4 +1,5 @@ """Tests for probabilistic linear solver policies.""" +import copy import pathlib import numpy as np @@ -14,6 +15,7 @@ @parametrize_with_cases("policy", cases=cases_policies) @parametrize_with_cases("state", cases=cases_states) def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolverState): + state = copy.deepcopy(state) action = policy(state) assert isinstance(action, np.ndarray) @@ -21,6 +23,7 @@ def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolve @parametrize_with_cases("policy", cases=cases_policies) @parametrize_with_cases("state", cases=cases_states) def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): + state = copy.deepcopy(state) action = policy(state) assert action.shape[0] == state.problem.A.shape[1] @@ -32,6 +35,7 @@ def test_uses_solver_state_random_number_generator( ): """Test whether randomized policies make use of the random number generator stored in the linear solver state.""" + state = copy.deepcopy(state) rng_state_pre = state.rng.bit_generator.state["state"]["state"] _ = policy(state) rng_state_post = state.rng.bit_generator.state["state"]["state"] diff --git a/tests/test_utils/test_linalg/test_inner_product.py b/tests/test_utils/test_linalg/test_inner_product.py new file mode 100644 index 000000000..57822628f --- /dev/null +++ b/tests/test_utils/test_linalg/test_inner_product.py @@ -0,0 +1,87 @@ +"""Tests for general inner products.""" + +import numpy as np +import pytest + +from probnum.problems.zoo.linalg import random_spd_matrix +from probnum.utils.linalg import induced_norm, inner_product + + +@pytest.fixture(scope="module", params=[1, 10, 50]) +def n(request) -> int: + """Vector size.""" + return request.param + + +@pytest.fixture(scope="module", params=[1, 3, 5]) +def m(request) -> int: + """Number of simultaneous vectors.""" + return request.param + + +@pytest.fixture(scope="module", params=[1, 3]) +def p(request) -> int: + """Number of matrices.""" + return request.param + + +@pytest.fixture(scope="module") +def vector0(n: int) -> np.ndarray: + rng = np.random.default_rng(86 + n) + return rng.standard_normal(size=(n,)) + + +@pytest.fixture(scope="module") +def vector1(n: int) -> np.ndarray: + rng = np.random.default_rng(567 + n) + return rng.standard_normal(size=(n,)) + + +@pytest.fixture(scope="module") +def array0(p: int, m: int, n: int) -> np.ndarray: + rng = np.random.default_rng(86 + p + m + n) + return rng.standard_normal(size=(p, m, n)) + + +@pytest.fixture(scope="module") +def array1(m: int, n: int) -> np.ndarray: + rng = np.random.default_rng(567 + m + n) + return rng.standard_normal(size=(m, n)) + + +def test_inner_product_vectors(vector0: np.ndarray, vector1: np.ndarray): + assert inner_product(v=vector0, w=vector1) == pytest.approx( + np.inner(vector0, vector1) + ) + + +def test_inner_product_arrays(array0: np.ndarray, array1: np.ndarray): + assert inner_product(v=array0, w=array1) == pytest.approx( + np.einsum("...i,...i", array0, array1) + ) + + +def test_euclidean_norm_vector(vector0: np.ndarray): + assert np.linalg.norm(vector0, ord=2) == pytest.approx(induced_norm(v=vector0)) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_euclidean_norm_array(array0: np.ndarray, axis: int): + assert np.linalg.norm(array0, axis=axis, ord=2) == pytest.approx( + induced_norm(v=array0, axis=axis) + ) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_induced_norm_array(array0: np.ndarray, axis: int): + inprod_mat = random_spd_matrix( + rng=np.random.default_rng(254), dim=array0.shape[axis] + ) + array0_moved_axis = np.moveaxis(array0, axis, -1) + A_array_0_moved_axis = np.squeeze( + inprod_mat @ array0_moved_axis[..., :, None], axis=-1 + ) + + assert np.sqrt( + np.sum(array0_moved_axis * A_array_0_moved_axis, axis=-1) + ) == pytest.approx(induced_norm(v=array0, A=inprod_mat, axis=axis)) diff --git a/tests/test_utils/test_linalg/test_orthogonalize.py b/tests/test_utils/test_linalg/test_orthogonalize.py new file mode 100644 index 000000000..2e4c25f2a --- /dev/null +++ b/tests/test_utils/test_linalg/test_orthogonalize.py @@ -0,0 +1,161 @@ +"""Tests for orthogonalization functions.""" + +from functools import partial +from typing import Callable, Union + +import numpy as np +import pytest + +from probnum import linops +from probnum.problems.zoo.linalg import random_spd_matrix +from probnum.utils.linalg import ( + double_gram_schmidt, + gram_schmidt, + modified_gram_schmidt, +) + +n = 100 + + +@pytest.fixture(scope="module", params=[1, 10, 50]) +def basis_size(request) -> int: + """Number of basis vectors.""" + return request.param + + +@pytest.fixture(scope="module") +def vector() -> np.ndarray: + rng = np.random.default_rng(526367 + n) + return rng.standard_normal(size=(n,)) + + +@pytest.fixture(scope="module") +def vectors() -> np.ndarray: + rng = np.random.default_rng(234 + n) + return rng.standard_normal(size=(2, 10, n)) + + +@pytest.fixture( + scope="module", + params=[ + np.eye(n), + linops.Identity(n), + linops.Scaling(factors=1.0, shape=(n, n)), + np.inner, + ], +) +def inprod(request) -> int: + return request.param + + +@pytest.fixture( + scope="module", + params=[ + partial(double_gram_schmidt, gram_schmidt_fn=gram_schmidt), + partial(double_gram_schmidt, gram_schmidt_fn=modified_gram_schmidt), + ], +) +def orthogonalization_fn(request) -> int: + return request.param + + +def test_is_orthogonal( + vector: np.ndarray, + basis_size: int, + inprod: Union[ + np.ndarray, + linops.LinearOperator, + Callable[[np.ndarray, np.ndarray], np.ndarray], + ], + orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], +): + # Compute orthogonal basis + seed = abs(32 + hash(basis_size)) + basis = np.random.default_rng(seed).normal(size=(vector.shape[0], basis_size)) + orthogonal_basis, _ = np.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, orthogonal_basis=orthogonal_basis, inner_product=inprod + ) + np.testing.assert_allclose( + orthogonal_basis @ ortho_vector, + np.zeros((basis_size,)), + atol=1e-12, + rtol=1e-12, + ) + + +def test_is_normalized( + vector: np.ndarray, + basis_size: int, + orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], +): + # Compute orthogonal basis + seed = abs(9467 + hash(basis_size)) + basis = np.random.default_rng(seed).normal(size=(vector.shape[0], basis_size)) + orthogonal_basis, _ = np.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, orthogonal_basis=orthogonal_basis, normalize=True + ) + + assert np.inner(ortho_vector, ortho_vector) == pytest.approx(1.0) + + +@pytest.mark.parametrize( + "inner_product_matrix", + [ + np.diag(np.random.default_rng(123).standard_gamma(1.0, size=(n,))), + 5 * np.eye(n), + random_spd_matrix(rng=np.random.default_rng(46), dim=n), + ], +) +def test_noneuclidean_innerprod( + vector: np.ndarray, + basis_size: int, + inner_product_matrix: np.ndarray, + orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], +): + evals, evecs = np.linalg.eigh(inner_product_matrix) + orthogonal_basis = evecs * 1 / np.sqrt(evals) + orthogonal_basis = orthogonal_basis[:, 0:basis_size].T + + # Orthogonalize vector + ortho_vector = orthogonalization_fn( + v=vector, + orthogonal_basis=orthogonal_basis, + inner_product=inner_product_matrix, + normalize=False, + ) + + np.testing.assert_allclose( + orthogonal_basis @ inner_product_matrix @ ortho_vector, + np.zeros((basis_size,)), + atol=1e-12, + rtol=1e-12, + ) + + +def test_broadcasting( + vectors: np.ndarray, + basis_size: int, + orthogonalization_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], +): + # Compute orthogonal basis + seed = abs(32 + hash(basis_size)) + basis = np.random.default_rng(seed).normal(size=(vectors.shape[-1], basis_size)) + orthogonal_basis, _ = np.linalg.qr(basis) + orthogonal_basis = orthogonal_basis.T + + # Orthogonalize vector + ortho_vectors = orthogonalization_fn(v=vectors, orthogonal_basis=orthogonal_basis) + np.testing.assert_allclose( + np.squeeze(orthogonal_basis @ ortho_vectors[..., None], axis=-1), + np.zeros(vectors.shape[:-1] + (basis_size,)), + atol=1e-12, + rtol=1e-12, + )