diff --git a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py index 9d999fb6c..e6939abab 100644 --- a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py +++ b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py @@ -90,8 +90,8 @@ class ProbabilisticLinearSolver( >>> pls = ProbabilisticLinearSolver( ... policy=policies.ConjugateGradientPolicy(), - ... information_op=information_ops.ProjectedRHSInformationOp(), - ... belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + ... information_op=information_ops.ProjectedResidualInformationOp(), + ... belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), ... stopping_criterion=( ... stopping_criteria.MaxIterationsStoppingCriterion(100) ... | stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5) @@ -152,7 +152,7 @@ def solve_iterator( solver_state State of the probabilistic linear solver. """ - solver_state = LinearSolverState(problem=problem, prior=prior, rng=rng) + solver_state = LinearSolverState(problem=problem, prior=prior) while True: @@ -163,7 +163,7 @@ def solve_iterator( break # Compute action via policy - solver_state.action = self.policy(solver_state=solver_state) + solver_state.action = self.policy(solver_state=solver_state, rng=rng) # Make observation via information operator solver_state.observation = self.information_op(solver_state=solver_state) @@ -234,8 +234,8 @@ def __init__( ): super().__init__( policy=policies.ConjugateGradientPolicy(), - information_op=information_ops.ProjectedRHSInformationOp(), - belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + information_op=information_ops.ProjectedResidualInformationOp(), + belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), stopping_criterion=stopping_criterion, ) @@ -267,8 +267,8 @@ def __init__( ): super().__init__( policy=policies.RandomUnitVectorPolicy(), - information_op=information_ops.ProjectedRHSInformationOp(), - belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + information_op=information_ops.ProjectedResidualInformationOp(), + belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), stopping_criterion=stopping_criterion, ) diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 7fa0d24e3..e58d9d488 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -1,7 +1,7 @@ """State of a probabilistic linear solver.""" - +from collections import defaultdict import dataclasses -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, DefaultDict, List, Optional, Tuple import numpy as np @@ -24,21 +24,17 @@ class LinearSolverState: Linear system to be solved. prior Prior belief over the quantities of interest of the linear system. - rng - Random number generator. """ def __init__( self, problem: problems.LinearSystem, prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief", - rng: Optional[np.random.Generator] = None, ): - self.rng: Optional[np.random.Generator] = rng - self.problem: problems.LinearSystem = problem + self._problem: problems.LinearSystem = problem # Belief - self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior + self._prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior # Caches @@ -47,7 +43,7 @@ def __init__( self._residuals: List[np.ndarray] = [ self.problem.b - self.problem.A @ self.belief.x.mean, ] - self.cache: Dict[str, Any] = {} + self.cache: DefaultDict[str, Any] = defaultdict(list) # Solver info self._step: int = 0 @@ -55,6 +51,16 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(step={self.step})" + @property + def problem(self) -> problems.LinearSystem: + """Linear system to be solved.""" + return self._problem + + @property + def prior(self) -> "probnum.linalg.solvers.beliefs.LinearSystemBelief": + """Prior belief over the quantities of interest of the linear system.""" + return self._prior + @property def step(self) -> int: """Current step of the solver.""" @@ -110,7 +116,7 @@ def observations(self) -> Tuple[Any, ...]: @property def residual(self) -> np.ndarray: - r"""Cached residual :math:`b - Ax_i` for the current solution estimate :math:`x_i`.""" + r"""Residual :math:`r_{i} = b - Ax_{i}`.""" if self._residuals[self.step] is None: self._residuals[self.step] = ( self.problem.b - self.problem.A @ self.belief.x.mean diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py index fec2b1a72..01475f525 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py @@ -1,15 +1,13 @@ """Solution-based belief updates for the quantities of interest of a linear system.""" -from ._solution_based_proj_rhs_belief_update import ( - SolutionBasedProjectedRHSBeliefUpdate, -) +from ._projected_residual_belief_update import ProjectedResidualBeliefUpdate # Public classes and functions. Order is reflected in documentation. __all__ = [ - "SolutionBasedProjectedRHSBeliefUpdate", + "ProjectedResidualBeliefUpdate", ] # Set correct module paths. Corrects links and module paths in documentation. -SolutionBasedProjectedRHSBeliefUpdate.__module__ = ( +ProjectedResidualBeliefUpdate.__module__ = ( "probnum.linalg.solvers.belief_updates.solution_based" ) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py similarity index 69% rename from src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py rename to src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py index 1d6f4d337..df63fe25f 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py @@ -10,14 +10,14 @@ from .._linear_solver_belief_update import LinearSolverBeliefUpdate -class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): - r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information. +class ProjectedResidualBeliefUpdate(LinearSolverBeliefUpdate): + r"""Gaussian belief update given projected residual information. - Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s\^top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that + Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`s\^top r_i = s^\top (b - Ax_i) = s^\top A (x - x_i)`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, such that .. math :: \begin{align} - x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\ + x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top r_i,\\ \Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A s + \lambda)^\dagger s^\top A \Sigma_i, \end{align} @@ -28,11 +28,6 @@ class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): ---------- noise_var : Variance of the scalar observation noise. - - References - ---------- - .. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian - Analysis*, 2019, 14, 937-1012 """ def __init__(self, noise_var: FloatLike = 0.0) -> None: @@ -44,14 +39,12 @@ def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> LinearSystemBelief: - # Compute projected residual - action_A = solver_state.action @ solver_state.problem.A - pred = action_A @ solver_state.belief.x.mean - proj_resid = solver_state.observation - pred + proj_resid = solver_state.observation # Compute gain and covariance update + action_A = solver_state.action.T @ solver_state.problem.A cov_xy = solver_state.belief.x.cov @ action_A.T - gram = action_A @ cov_xy + self._noise_var + gram = action_A @ cov_xy + self.noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 gain = cov_xy * gram_pinv cov_update = np.outer(gain, cov_xy) @@ -68,3 +61,8 @@ def __call__( return LinearSystemBelief( x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b ) + + @property + def noise_var(self) -> float: + """Observation noise.""" + return self._noise_var diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 8fa830032..22eae5491 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -7,7 +7,7 @@ from functools import cached_property from typing import Mapping, Optional -from probnum import linops, randvars +from probnum import randvars # pylint: disable="invalid-name" @@ -69,7 +69,7 @@ def dim_mismatch_error(**kwargs): if A.shape[1] != x.shape[0]: raise dim_mismatch_error(A=A, x=x) - if x.ndim > 1: + if x.ndim > 1 and b is not None: if x.shape[1] != b.shape[1]: raise dim_mismatch_error(x=x, b=b) elif b is not None: @@ -79,7 +79,7 @@ def dim_mismatch_error(**kwargs): if Ainv is not None: if Ainv.ndim != 2: raise ValueError( - f"Belief over the inverse system matrix may have at most two dimensions, but has {A.ndim}." + f"Belief over the inverse system matrix may have at most two dimensions, but has {Ainv.ndim}." ) if A is not None: if A.shape != Ainv.shape: @@ -100,6 +100,23 @@ def dim_mismatch_error(**kwargs): f"Belief over right-hand-side may have either one or two dimensions but has {b.ndim}." ) + if x is not None and not isinstance(x, randvars.RandomVariable): + raise TypeError( + f"The belief about the solution 'x' must be a RandomVariable, but is {type(x)}." + ) + if A is not None and not isinstance(A, randvars.RandomVariable): + raise TypeError( + f"The belief about the matrix 'A' must be a RandomVariable, but is {type(A)}." + ) + if Ainv is not None and not isinstance(Ainv, randvars.RandomVariable): + raise TypeError( + f"The belief about the inverse matrix 'Ainv' must be a RandomVariable, but is {type(Ainv)}." + ) + if b is not None and not isinstance(b, randvars.RandomVariable): + raise TypeError( + f"The belief about the right-hand-side 'b' must be a RandomVariable, but is {type(b)}." + ) + self._x = x self._A = A self._Ainv = Ainv @@ -134,8 +151,6 @@ def A(self) -> randvars.RandomVariable: @property def Ainv(self) -> Optional[randvars.RandomVariable]: """Belief about the (pseudo-)inverse of the system matrix.""" - if self._Ainv is None: - return self._induced_Ainv() return self._Ainv @property @@ -150,13 +165,4 @@ def _induced_x(self) -> randvars.RandomVariable: to) the random variable :math:`x=Hb`. This assumes independence between :math:`H` and :math:`b`. """ - return self.Ainv @ self.b - - def _induced_Ainv(self) -> randvars.RandomVariable: - r"""Induced belief about the inverse from a belief about the solution. - - Computes a consistent belief about the inverse from a belief about the solution. - """ - return randvars.Constant( - linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0])) - ) + return randvars.asrandvar(self.Ainv @ self.b) diff --git a/src/probnum/linalg/solvers/information_ops/__init__.py b/src/probnum/linalg/solvers/information_ops/__init__.py index 6c3b14e73..1e7869a01 100644 --- a/src/probnum/linalg/solvers/information_ops/__init__.py +++ b/src/probnum/linalg/solvers/information_ops/__init__.py @@ -9,16 +9,16 @@ from ._linear_solver_information_op import LinearSolverInformationOp from ._matvec import MatVecInformationOp -from ._projected_rhs import ProjectedRHSInformationOp +from ._projected_residual import ProjectedResidualInformationOp # Public classes and functions. Order is reflected in documentation. __all__ = [ "LinearSolverInformationOp", "MatVecInformationOp", - "ProjectedRHSInformationOp", + "ProjectedResidualInformationOp", ] # Set correct module paths. Corrects links and module paths in documentation. LinearSolverInformationOp.__module__ = "probnum.linalg.solvers.information_ops" MatVecInformationOp.__module__ = "probnum.linalg.solvers.information_ops" -ProjectedRHSInformationOp.__module__ = "probnum.linalg.solvers.information_ops" +ProjectedResidualInformationOp.__module__ = "probnum.linalg.solvers.information_ops" diff --git a/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py index 41aafcce2..6a83d82ce 100644 --- a/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py +++ b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py @@ -14,7 +14,7 @@ class LinearSolverInformationOp(abc.ABC): See Also -------- MatVecInformationOp: Collect information via matrix-vector multiplication. - ProjectedRHSInformationOp: Collect information via a projection of the current residual. + ProjectedResidualInformationOp: Collect information via a projection of the current residual. """ @abc.abstractmethod diff --git a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py b/src/probnum/linalg/solvers/information_ops/_projected_residual.py similarity index 54% rename from src/probnum/linalg/solvers/information_ops/_projected_rhs.py rename to src/probnum/linalg/solvers/information_ops/_projected_residual.py index 501bcdcad..4d5e9a0d2 100644 --- a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py +++ b/src/probnum/linalg/solvers/information_ops/_projected_residual.py @@ -6,20 +6,20 @@ from ._linear_solver_information_op import LinearSolverInformationOp -class ProjectedRHSInformationOp(LinearSolverInformationOp): - r"""Projected right hand side :math:`s_i \mapsto b^\top s_i = (Ax)^\top s_i` of the linear system. +class ProjectedResidualInformationOp(LinearSolverInformationOp): + r"""Projected residual information operator. - Obtain information about a linear system by projecting the right hand side :math:`b=Ax` onto a given action :math:`s_i` resulting in :math:`y_i = s_i^\top b`. + Obtain information about a linear system by projecting the residual :math:`b-Ax_{i-1}` onto a given action :math:`s_i` resulting in :math:`s_i \mapsto s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1}) = s_i^\top A (x - x_{i-1})`. """ def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: - r"""Projected right hand side :math:`s_i^\top b = s_i^\top Ax` of the linear system. + r"""Projected residual :math:`s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1})` of the linear system. Parameters ---------- solver_state : Current state of the linear solver. """ - return solver_state.action @ solver_state.problem.b + return solver_state.action.T @ solver_state.residual diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 84f1b96ec..4c521a61d 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -43,22 +43,27 @@ def __init__( self._reorthogonalization_fn_action = reorthogonalization_fn_action def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: residual = solver_state.residual - if self._reorthogonalization_fn_residual is not None and solver_state.step == 0: - solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual] + if solver_state.step == 0: + if self._reorthogonalization_fn_residual is not None: + solver_state.cache["reorthogonalized_residuals"].append( + solver_state.residual + ) - if solver_state.step > 0: + return residual + else: # Reorthogonalization of the residual if self._reorthogonalization_fn_residual is not None: - residual, prev_residual = self._reorthogonalized_residuals( + residual, prev_residual = self._reorthogonalized_residual( solver_state=solver_state ) else: - residual = solver_state.residual prev_residual = solver_state.residuals[solver_state.step - 1] # A-conjugacy correction (in exact arithmetic) @@ -67,16 +72,13 @@ def __call__( # Reorthogonalization of the resulting action if self._reorthogonalization_fn_action is not None: - return self._reorthogonalized_action( + action = self._reorthogonalized_action( action=action, solver_state=solver_state ) - else: - action = residual - - return action + return action - def _reorthogonalized_residuals( + def _reorthogonalized_residual( self, solver_state: "probnum.linalg.solvers.LinearSolverState", ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py index a0bc6d2d2..43197b49b 100644 --- a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py +++ b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py @@ -1,5 +1,6 @@ """Base class for policies of probabilistic linear solvers returning actions.""" import abc +from typing import Optional import numpy as np @@ -22,7 +23,9 @@ class LinearSolverPolicy(abc.ABC): @abc.abstractmethod def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: """Return an action for a given solver state. @@ -30,6 +33,8 @@ def __call__( ---------- solver_state Current state of the linear solver. + rng + Random number generator. Returns ------- diff --git a/src/probnum/linalg/solvers/policies/_random_unit_vector.py b/src/probnum/linalg/solvers/policies/_random_unit_vector.py index c8762584d..5e299bcbd 100644 --- a/src/probnum/linalg/solvers/policies/_random_unit_vector.py +++ b/src/probnum/linalg/solvers/policies/_random_unit_vector.py @@ -1,4 +1,7 @@ """Policy returning randomly drawn standard unit vectors.""" + +from typing import Optional + import numpy as np import probnum # pylint: disable="unused-import" @@ -29,7 +32,9 @@ def __init__(self, probabilities: str = "uniform", replace=True) -> None: self.replace = replace def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: nrows = solver_state.problem.A.shape[0] @@ -47,7 +52,7 @@ def __call__( raise NotImplementedError # Sample unit vector - idx = solver_state.rng.choice( + idx = rng.choice( a=nrows, size=1, p=solver_state.cache["row_sample_probs"], diff --git a/tests/test_linalg/test_solvers/cases/belief_updates.py b/tests/test_linalg/test_solvers/cases/belief_updates.py index 628a4fe12..7629ad742 100644 --- a/tests/test_linalg/test_solvers/cases/belief_updates.py +++ b/tests/test_linalg/test_solvers/cases/belief_updates.py @@ -6,8 +6,8 @@ @parametrize(noise_var=[0.0, 0.001, 1.0]) -def case_solution_based_projected_rhs_belief_update(noise_var: float): - return solution_based.SolutionBasedProjectedRHSBeliefUpdate(noise_var=noise_var) +def case_solution_based_projected_residual_belief_update(noise_var: float): + return solution_based.ProjectedResidualBeliefUpdate(noise_var=noise_var) def case_matrix_based_linear_belief_update(): diff --git a/tests/test_linalg/test_solvers/cases/information_ops.py b/tests/test_linalg/test_solvers/cases/information_ops.py index b200bb610..b6323e297 100644 --- a/tests/test_linalg/test_solvers/cases/information_ops.py +++ b/tests/test_linalg/test_solvers/cases/information_ops.py @@ -7,5 +7,5 @@ def case_matvec(): return information_ops.MatVecInformationOp() -def case_projected_rhs(): - return information_ops.ProjectedRHSInformationOp() +def case_projected_residual(): + return information_ops.ProjectedResidualInformationOp() diff --git a/tests/test_linalg/test_solvers/cases/states.py b/tests/test_linalg/test_solvers/cases/states.py index 8d059bf37..589b1b30c 100644 --- a/tests/test_linalg/test_solvers/cases/states.py +++ b/tests/test_linalg/test_solvers/cases/states.py @@ -26,11 +26,9 @@ @case(tags=["initial"]) -def case_initial_state( - rng: np.random.Generator, -): +def case_initial_state(): """Initial state of a linear solver.""" - return linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + return linalg.solvers.LinearSolverState(problem=linsys, prior=prior) @case(tags=["has_action"]) @@ -38,7 +36,7 @@ def case_state( rng: np.random.Generator, ): """State of a linear solver.""" - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) return state @@ -61,7 +59,7 @@ def case_state_matrix_based( ), b=b, ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) state.observation = rng.standard_normal(size=state.problem.A.shape[1]) @@ -85,7 +83,7 @@ def case_state_symmetric_matrix_based( ), b=b, ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) state.observation = rng.standard_normal(size=state.problem.A.shape[1]) @@ -97,9 +95,7 @@ def case_state_solution_based( rng: np.random.Generator, ): """State of a solution-based linear solver.""" - initial_state = linalg.solvers.LinearSolverState( - problem=linsys, prior=prior, rng=rng - ) + initial_state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) initial_state.action = rng.standard_normal(size=initial_state.problem.A.shape[1]) initial_state.observation = rng.standard_normal() @@ -116,5 +112,5 @@ def case_state_converged( x=randvars.Constant(linsys.solution), b=randvars.Constant(linsys.b), ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=belief, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=belief) return state diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py similarity index 83% rename from tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py rename to tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py index 406bf8ae6..374dce6d0 100644 --- a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py @@ -16,7 +16,9 @@ @parametrize_with_cases( - "belief_update", cases=cases_belief_updates, glob="*solution_based_projected_rhs*" + "belief_update", + cases=cases_belief_updates, + glob="*solution_based_projected_residual*", ) @parametrize_with_cases( "state", @@ -24,7 +26,7 @@ has_tag=["has_action", "has_observation", "solution_based"], ) def test_returns_linear_system_belief( - belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + belief_update: belief_updates.solution_based.ProjectedResidualBeliefUpdate, state: LinearSolverState, ): belief = belief_update(solver_state=state) @@ -33,13 +35,13 @@ def test_returns_linear_system_belief( def test_negative_noise_variance_raises_error(): with pytest.raises(ValueError): - belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate( - noise_var=-1.0 - ) + belief_updates.solution_based.ProjectedResidualBeliefUpdate(noise_var=-1.0) @parametrize_with_cases( - "belief_update", cases=cases_belief_updates, glob="*solution_based_projected_rhs*" + "belief_update", + cases=cases_belief_updates, + glob="*solution_based_projected_residual*", ) @parametrize_with_cases( "state", @@ -47,7 +49,7 @@ def test_negative_noise_variance_raises_error(): has_tag=["has_action", "has_observation", "solution_based"], ) def test_beliefs_against_naive_implementation( - belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + belief_update: belief_updates.solution_based.ProjectedResidualBeliefUpdate, state: LinearSolverState, ): """Compare the updated belief to a naive implementation.""" @@ -61,8 +63,7 @@ def test_beliefs_against_naive_implementation( noise_var = belief_update._noise_var action_A = action @ state.problem.A - pred = action_A @ belief.x.mean - proj_resid = observ - pred + proj_resid = observ cov_xy = belief.x.cov @ action_A.T gram = action_A @ cov_xy + noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 diff --git a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py index 7b3ffa417..25ad93b3f 100644 --- a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py +++ b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py @@ -80,6 +80,25 @@ def test_non_two_dimensional_raises_value_error(): LinearSystemBelief(A=A, Ainv=Ainv, x=x, b=b[:, None]) +def test_non_randvar_arguments_raises_type_error(): + A = np.eye(5) + Ainv = np.eye(5) + x = np.ones((5, 1)) + b = np.ones((5, 1)) + + with pytest.raises(TypeError): + LinearSystemBelief(x=x) + + with pytest.raises(TypeError): + LinearSystemBelief(Ainv=Ainv) + + with pytest.raises(TypeError): + LinearSystemBelief(x=randvars.Constant(x), A=A) + + with pytest.raises(TypeError): + LinearSystemBelief(x=randvars.Constant(x), b=b) + + def test_induced_solution_belief(rng: np.random.Generator): """Test whether a consistent belief over the solution is inferred from a belief over the inverse.""" diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py similarity index 82% rename from tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py rename to tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py index 6915030c6..e5101b7f2 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py @@ -12,10 +12,12 @@ cases_states = case_modules + ".states" -@parametrize_with_cases("info_op", cases=cases_information_ops, glob="*projected_rhs") +@parametrize_with_cases( + "info_op", cases=cases_information_ops, glob="*projected_residual" +) @parametrize_with_cases("state", cases=cases_states, has_tag=["has_action"]) -def test_is_projected_rhs( +def test_is_projected_residual( info_op: information_ops.LinearSolverInformationOp, state: LinearSolverState ): observation = info_op(state) - np.testing.assert_equal(observation, state.action.T @ state.problem.b) + np.testing.assert_equal(observation, state.action.T @ state.residual) diff --git a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py index 4df954456..0d51acae2 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py +++ b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py @@ -16,7 +16,7 @@ @parametrize_with_cases("state", cases=cases_states) def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolverState): state = copy.deepcopy(state) - action = policy(state) + action = policy(state, rng=np.random.default_rng(1)) assert isinstance(action, np.ndarray) @@ -24,7 +24,7 @@ def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolve @parametrize_with_cases("state", cases=cases_states) def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): state = copy.deepcopy(state) - action = policy(state) + action = policy(state, rng=np.random.default_rng(1)) assert action.shape[0] == state.problem.A.shape[1] @@ -33,10 +33,9 @@ def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): def test_uses_solver_state_random_number_generator( policy: policies.LinearSolverPolicy, state: LinearSolverState ): - """Test whether randomized policies make use of the random number generator stored - in the linear solver state.""" - state = copy.deepcopy(state) - rng_state_pre = state.rng.bit_generator.state["state"]["state"] - _ = policy(state) - rng_state_post = state.rng.bit_generator.state["state"]["state"] + """Test whether randomized policies make use of the given random state.""" + rng = np.random.default_rng(1) + rng_state_pre = rng.bit_generator.state["state"]["state"] + _ = policy(state, rng=rng) + rng_state_post = rng.bit_generator.state["state"]["state"] assert rng_state_pre != rng_state_post diff --git a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py index e3f608d79..8201c2705 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py +++ b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pytest_cases import parametrize_with_cases +from pytest_cases import parametrize, parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, policies @@ -14,10 +14,12 @@ @parametrize_with_cases("policy", cases=cases_policies, glob="*unit_vector*") @parametrize_with_cases("state", cases=cases_states) +@parametrize("seed", [1, 3, 42]) def test_returns_unit_vector( - policy: policies.LinearSolverPolicy, state: LinearSolverState + policy: policies.LinearSolverPolicy, state: LinearSolverState, seed: int ): - action = policy(state) + rng = np.random.default_rng(seed) + action = policy(state, rng=rng) assert np.linalg.norm(action) == pytest.approx(1.0)