From 86e27f79ace8ce81c9da7f0f282cdeeac5bf9669 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 10:38:01 -0500 Subject: [PATCH 01/14] unify usage of the term residual --- benchmarks/linearsolvers.py | 2 +- src/probnum/linalg/_bayescg.py | 2 +- src/probnum/linalg/solvers/_state.py | 13 ++++++++----- .../solvers/information_ops/_projected_rhs.py | 2 +- .../solvers/policies/_conjugate_gradient.py | 15 ++++++++------- .../solvers/stopping_criteria/_residual_norm.py | 2 +- .../test_policies/test_conjugate_gradient.py | 2 +- .../test_asymmetric.py | 2 +- .../test_symmetric.py | 2 +- tests/test_linalg/test_solvers/test_state.py | 2 +- 10 files changed, 24 insertions(+), 20 deletions(-) diff --git a/benchmarks/linearsolvers.py b/benchmarks/linearsolvers.py index d1343e7a9..473806aa8 100644 --- a/benchmarks/linearsolvers.py +++ b/benchmarks/linearsolvers.py @@ -71,7 +71,7 @@ def peakmem_solve(self, linsys, dim): problinsolve(A=self.linsys.A, b=self.linsys.b) def track_residual_norm(self, linsys, dim): - return np.linalg.norm(self.linsys.A @ self.xhat.mean - self.linsys.b) + return np.linalg.norm(self.linsys.b - self.linsys.A @ self.xhat.mean) def track_error_2norm(self, linsys, dim): return np.linalg.norm(self.linsys.solution - self.xhat.mean) diff --git a/src/probnum/linalg/_bayescg.py b/src/probnum/linalg/_bayescg.py index 0e98f0e46..9d96e295b 100644 --- a/src/probnum/linalg/_bayescg.py +++ b/src/probnum/linalg/_bayescg.py @@ -33,7 +33,7 @@ def bayescg( Maximum number of iterations. Defaults to :math:`10n`, where :math:`n` is the dimension of :math:`A`. atol - Absolute residual tolerance. If :math:`\lVert r_i \rVert = \lVert Ax_i - b + Absolute residual tolerance. If :math:`\lVert r_i \rVert = \lVert b - Ax_i \rVert < \text{atol}`, the iteration terminates. rtol Relative residual tolerance. 
If :math:`\lVert r_i \rVert < \text{rtol} diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index cd96958cc..7fa0d24e3 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -13,7 +13,10 @@ class LinearSolverState: """State of a probabilistic linear solver. - The solver state separates the state of a probabilistic linear solver from the algorithm itself, making the solver stateless. The state contains the problem to be solved, the current belief over the quantities of interest and any miscellaneous quantities computed during an iteration of a probabilistic linear solver. The solver state is passed between the different components of the solver and may be used internally to cache quantities which are used more than once. + The solver state separates the state of a probabilistic linear solver from the algorithm itself, making the solver stateless. + The state contains the problem to be solved, the current belief over the quantities of interest and any miscellaneous quantities + computed during an iteration of a probabilistic linear solver. The solver state is passed between the different components of the + solver and may be used internally to cache quantities which are used more than once. 
Parameters ---------- @@ -42,7 +45,7 @@ def __init__( self._actions: List[np.ndarray] = [None] self._observations: List[Any] = [None] self._residuals: List[np.ndarray] = [ - self.problem.A @ self.belief.x.mean - self.problem.b, + self.problem.b - self.problem.A @ self.belief.x.mean, ] self.cache: Dict[str, Any] = {} @@ -107,16 +110,16 @@ def observations(self) -> Tuple[Any, ...]: @property def residual(self) -> np.ndarray: - r"""Cached residual :math:`Ax_i-b` for the current solution estimate :math:`x_i`.""" + r"""Cached residual :math:`b - Ax_i` for the current solution estimate :math:`x_i`.""" if self._residuals[self.step] is None: self._residuals[self.step] = ( - self.problem.A @ self.belief.x.mean - self.problem.b + self.problem.b - self.problem.A @ self.belief.x.mean ) return self._residuals[self.step] @property def residuals(self) -> Tuple[np.ndarray, ...]: - r"""Residuals :math:`\{Ax_i - b\}_i`.""" + r"""Residuals :math:`\{b - Ax_i\}_i`.""" return tuple(self._residuals) def next_step(self) -> None: diff --git a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py b/src/probnum/linalg/solvers/information_ops/_projected_rhs.py index f9d7fcdbf..501bcdcad 100644 --- a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py +++ b/src/probnum/linalg/solvers/information_ops/_projected_rhs.py @@ -1,4 +1,4 @@ -"""Information operator returning a projection of the residual.""" +"""Information operator returning a projection of the right-hand-side.""" import numpy as np import probnum # pylint: disable="unused-import" diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 019e52e2c..84f1b96ec 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -13,16 +13,17 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy): r"""Policy returning :math:`A`-conjugate actions. 
- Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`. + Selects the negative gradient / residual as an initial action :math:`s_0 = b - A x_0` and then successively + generates :math:`A`-conjugate actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff :math:`i \neq j`. Parameters ---------- reorthogonalization_fn_residual - Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` - the residuals are reorthogonalized before the action is computed. + Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and + returns a reorthogonalized vector. If not `None` the residuals are reorthogonalized before the action is computed. reorthogonalization_fn_action - Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and returns a reorthogonalized vector. If not `None` - the computed action is reorthogonalized. + Reorthogonalization function, which takes a vector, an orthogonal basis and optionally an inner product and + returns a reorthogonalized vector. If not `None` the computed action is reorthogonalized. 
""" def __init__( @@ -62,7 +63,7 @@ def __call__( # A-conjugacy correction (in exact arithmetic) beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2 - action = -residual + beta * solver_state.actions[solver_state.step - 1] + action = residual + beta * solver_state.actions[solver_state.step - 1] # Reorthogonalization of the resulting action if self._reorthogonalization_fn_action is not None: @@ -71,7 +72,7 @@ def __call__( ) else: - action = -residual + action = residual return action diff --git a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py index 0ed36e14e..bbc3cad7c 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py @@ -11,7 +11,7 @@ class ResidualNormStoppingCriterion(LinearSolverStoppingCriterion): r"""Residual stopping criterion. - Terminate when the euclidean norm of the residual :math:`r_{i} = A x_{i} - b` is + Terminate when the euclidean norm of the residual :math:`r_{i} = b - A x_{i}` is sufficiently small, i.e. if it satisfies :math:`\lVert r_i \rVert_2 \leq \max( \text{atol}, \text{rtol} \lVert b \rVert_2)`. 
diff --git a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py index 96cc5917e..60f5d142b 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py +++ b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py @@ -21,7 +21,7 @@ def test_initial_action_is_negative_gradient( assert state.step == 0 state = copy.deepcopy(state) action = policy(state) - np.testing.assert_allclose(action, -state.residual) + np.testing.assert_allclose(action, state.residual) @parametrize_with_cases("policy", cases=cases_policies, glob="*conjugate_*") diff --git a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py index 61bb4e6ac..8f8b8ebac 100644 --- a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py +++ b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_asymmetric.py @@ -27,7 +27,7 @@ def test_small_residual( prior=prior, problem=problem, rng=np.random.default_rng(42) ) - residual_norm = np.linalg.norm(problem.A @ belief.x.mean - problem.b, ord=2) + residual_norm = np.linalg.norm(problem.b - problem.A @ belief.x.mean, ord=2) assert residual_norm < 1e-5 or residual_norm < 1e-5 * np.linalg.norm( problem.b, ord=2 diff --git a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py index 3f0cdd5c7..734342e7d 100644 --- a/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py +++ b/tests/test_linalg/test_solvers/test_probabilistic_linear_solver/test_symmetric.py @@ -26,7 +26,7 @@ def test_small_residual( prior=prior, problem=problem, rng=np.random.default_rng(42) ) - residual_norm = np.linalg.norm(problem.A @ belief.x.mean - problem.b, ord=2) + 
residual_norm = np.linalg.norm(problem.b - problem.A @ belief.x.mean, ord=2) assert residual_norm < 1e-5 or residual_norm < 1e-5 * np.linalg.norm( problem.b, ord=2 diff --git a/tests/test_linalg/test_solvers/test_state.py b/tests/test_linalg/test_solvers/test_state.py index cd5cd2476..00b64eff8 100644 --- a/tests/test_linalg/test_solvers/test_state.py +++ b/tests/test_linalg/test_solvers/test_state.py @@ -12,7 +12,7 @@ def test_residual(state: LinearSolverState): """Test whether the state computes the residual correctly.""" linsys = state.problem - residual = linsys.A @ state.belief.x.mean - linsys.b + residual = linsys.b - linsys.A @ state.belief.x.mean np.testing.assert_allclose(residual, state.residual) From 459613323507a34053c9eb36b846ba94de220af1 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 12:06:43 -0500 Subject: [PATCH 02/14] replaced projected rhs information op with projected residual information op --- .../solvers/_probabilistic_linear_solver.py | 12 ++++++------ .../belief_updates/solution_based/__init__.py | 8 +++----- ...y => _projected_residual_belief_update.py} | 19 ++++++------------- .../solvers/information_ops/__init__.py | 6 +++--- .../_linear_solver_information_op.py | 2 +- ...rojected_rhs.py => _projected_residual.py} | 10 +++++----- .../test_solvers/cases/belief_updates.py | 4 ++-- .../test_solvers/cases/information_ops.py | 4 ++-- ... 
test_projected_residual_belief_update.py} | 19 ++++++++++--------- ...cted_rhs.py => test_projected_residual.py} | 8 +++++--- 10 files changed, 43 insertions(+), 49 deletions(-) rename src/probnum/linalg/solvers/belief_updates/solution_based/{_solution_based_proj_rhs_belief_update.py => _projected_residual_belief_update.py} (73%) rename src/probnum/linalg/solvers/information_ops/{_projected_rhs.py => _projected_residual.py} (54%) rename tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/{test_solution_based_proj_rhs_belief_update.py => test_projected_residual_belief_update.py} (83%) rename tests/test_linalg/test_solvers/test_information_ops/{test_projected_rhs.py => test_projected_residual.py} (82%) diff --git a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py index 9d999fb6c..fa2f7275f 100644 --- a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py +++ b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py @@ -90,8 +90,8 @@ class ProbabilisticLinearSolver( >>> pls = ProbabilisticLinearSolver( ... policy=policies.ConjugateGradientPolicy(), - ... information_op=information_ops.ProjectedRHSInformationOp(), - ... belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + ... information_op=information_ops.ProjectedResidualInformationOp(), + ... belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), ... stopping_criterion=( ... stopping_criteria.MaxIterationsStoppingCriterion(100) ... 
| stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5) @@ -234,8 +234,8 @@ def __init__( ): super().__init__( policy=policies.ConjugateGradientPolicy(), - information_op=information_ops.ProjectedRHSInformationOp(), - belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + information_op=information_ops.ProjectedResidualInformationOp(), + belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), stopping_criterion=stopping_criterion, ) @@ -267,8 +267,8 @@ def __init__( ): super().__init__( policy=policies.RandomUnitVectorPolicy(), - information_op=information_ops.ProjectedRHSInformationOp(), - belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), + information_op=information_ops.ProjectedResidualInformationOp(), + belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(), stopping_criterion=stopping_criterion, ) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py index fec2b1a72..01475f525 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py @@ -1,15 +1,13 @@ """Solution-based belief updates for the quantities of interest of a linear system.""" -from ._solution_based_proj_rhs_belief_update import ( - SolutionBasedProjectedRHSBeliefUpdate, -) +from ._projected_residual_belief_update import ProjectedResidualBeliefUpdate # Public classes and functions. Order is reflected in documentation. __all__ = [ - "SolutionBasedProjectedRHSBeliefUpdate", + "ProjectedResidualBeliefUpdate", ] # Set correct module paths. Corrects links and module paths in documentation. 
-SolutionBasedProjectedRHSBeliefUpdate.__module__ = ( +ProjectedResidualBeliefUpdate.__module__ = ( "probnum.linalg.solvers.belief_updates.solution_based" ) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py similarity index 73% rename from src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py rename to src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py index 1d6f4d337..3dcd67960 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py @@ -10,14 +10,14 @@ from .._linear_solver_belief_update import LinearSolverBeliefUpdate -class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): - r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information. +class ProjectedResidualBeliefUpdate(LinearSolverBeliefUpdate): + r"""Gaussian belief update given projected residual information. - Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s\^top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that + Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s^\top r_i = s^\top (b - Ax_i) = s^\top A (x - x_i)`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, such that .. 
math :: \begin{align} - x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\ + x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top r_i,\\ \Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A s + \lambda)^\dagger s^\top A \Sigma_i, \end{align} @@ -28,11 +28,6 @@ class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): ---------- noise_var : Variance of the scalar observation noise. - - References - ---------- - .. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian - Analysis*, 2019, 14, 937-1012 """ def __init__(self, noise_var: FloatLike = 0.0) -> None: @@ -44,12 +39,10 @@ def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> LinearSystemBelief: - # Compute projected residual - action_A = solver_state.action @ solver_state.problem.A - pred = action_A @ solver_state.belief.x.mean - proj_resid = solver_state.observation - pred + proj_resid = solver_state.observation # Compute gain and covariance update + action_A = solver_state.action @ solver_state.problem.A cov_xy = solver_state.belief.x.cov @ action_A.T gram = action_A @ cov_xy + self._noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 diff --git a/src/probnum/linalg/solvers/information_ops/__init__.py b/src/probnum/linalg/solvers/information_ops/__init__.py index 6c3b14e73..1e7869a01 100644 --- a/src/probnum/linalg/solvers/information_ops/__init__.py +++ b/src/probnum/linalg/solvers/information_ops/__init__.py @@ -9,16 +9,16 @@ from ._linear_solver_information_op import LinearSolverInformationOp from ._matvec import MatVecInformationOp -from ._projected_rhs import ProjectedRHSInformationOp +from ._projected_residual import ProjectedResidualInformationOp # Public classes and functions. Order is reflected in documentation. 
__all__ = [ "LinearSolverInformationOp", "MatVecInformationOp", - "ProjectedRHSInformationOp", + "ProjectedResidualInformationOp", ] # Set correct module paths. Corrects links and module paths in documentation. LinearSolverInformationOp.__module__ = "probnum.linalg.solvers.information_ops" MatVecInformationOp.__module__ = "probnum.linalg.solvers.information_ops" -ProjectedRHSInformationOp.__module__ = "probnum.linalg.solvers.information_ops" +ProjectedResidualInformationOp.__module__ = "probnum.linalg.solvers.information_ops" diff --git a/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py index 41aafcce2..6a83d82ce 100644 --- a/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py +++ b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py @@ -14,7 +14,7 @@ class LinearSolverInformationOp(abc.ABC): See Also -------- MatVecInformationOp: Collect information via matrix-vector multiplication. - ProjectedRHSInformationOp: Collect information via a projection of the current residual. + ProjectedResidualInformationOp: Collect information via a projection of the current residual. """ @abc.abstractmethod diff --git a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py b/src/probnum/linalg/solvers/information_ops/_projected_residual.py similarity index 54% rename from src/probnum/linalg/solvers/information_ops/_projected_rhs.py rename to src/probnum/linalg/solvers/information_ops/_projected_residual.py index 501bcdcad..ed566ce15 100644 --- a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py +++ b/src/probnum/linalg/solvers/information_ops/_projected_residual.py @@ -6,20 +6,20 @@ from ._linear_solver_information_op import LinearSolverInformationOp -class ProjectedRHSInformationOp(LinearSolverInformationOp): - r"""Projected right hand side :math:`s_i \mapsto b^\top s_i = (Ax)^\top s_i` of the linear system. 
+class ProjectedResidualInformationOp(LinearSolverInformationOp): + r"""Projected residual information operator. - Obtain information about a linear system by projecting the right hand side :math:`b=Ax` onto a given action :math:`s_i` resulting in :math:`y_i = s_i^\top b`. + Obtain information about a linear system by projecting the residual :math:`b-Ax_{i-1}` onto a given action :math:`s_i` resulting in :math:`s_i \mapsto s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1}) = s_i^\top A (x - x_{i-1})`. """ def __call__( self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: - r"""Projected right hand side :math:`s_i^\top b = s_i^\top Ax` of the linear system. + r"""Projected residual :math:`s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1})` of the linear system. Parameters ---------- solver_state : Current state of the linear solver. """ - return solver_state.action @ solver_state.problem.b + return solver_state.action @ solver_state.residual diff --git a/tests/test_linalg/test_solvers/cases/belief_updates.py b/tests/test_linalg/test_solvers/cases/belief_updates.py index 628a4fe12..7629ad742 100644 --- a/tests/test_linalg/test_solvers/cases/belief_updates.py +++ b/tests/test_linalg/test_solvers/cases/belief_updates.py @@ -6,8 +6,8 @@ @parametrize(noise_var=[0.0, 0.001, 1.0]) -def case_solution_based_projected_rhs_belief_update(noise_var: float): - return solution_based.SolutionBasedProjectedRHSBeliefUpdate(noise_var=noise_var) +def case_solution_based_projected_residual_belief_update(noise_var: float): + return solution_based.ProjectedResidualBeliefUpdate(noise_var=noise_var) def case_matrix_based_linear_belief_update(): diff --git a/tests/test_linalg/test_solvers/cases/information_ops.py b/tests/test_linalg/test_solvers/cases/information_ops.py index b200bb610..b6323e297 100644 --- a/tests/test_linalg/test_solvers/cases/information_ops.py +++ b/tests/test_linalg/test_solvers/cases/information_ops.py @@ -7,5 +7,5 @@ def case_matvec(): return 
information_ops.MatVecInformationOp() -def case_projected_rhs(): - return information_ops.ProjectedRHSInformationOp() +def case_projected_residual(): + return information_ops.ProjectedResidualInformationOp() diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py similarity index 83% rename from tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py rename to tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py index 406bf8ae6..374dce6d0 100644 --- a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_projected_residual_belief_update.py @@ -16,7 +16,9 @@ @parametrize_with_cases( - "belief_update", cases=cases_belief_updates, glob="*solution_based_projected_rhs*" + "belief_update", + cases=cases_belief_updates, + glob="*solution_based_projected_residual*", ) @parametrize_with_cases( "state", @@ -24,7 +26,7 @@ has_tag=["has_action", "has_observation", "solution_based"], ) def test_returns_linear_system_belief( - belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + belief_update: belief_updates.solution_based.ProjectedResidualBeliefUpdate, state: LinearSolverState, ): belief = belief_update(solver_state=state) @@ -33,13 +35,13 @@ def test_returns_linear_system_belief( def test_negative_noise_variance_raises_error(): with pytest.raises(ValueError): - belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate( - noise_var=-1.0 - ) + belief_updates.solution_based.ProjectedResidualBeliefUpdate(noise_var=-1.0) @parametrize_with_cases( - "belief_update", cases=cases_belief_updates, 
glob="*solution_based_projected_rhs*" + "belief_update", + cases=cases_belief_updates, + glob="*solution_based_projected_residual*", ) @parametrize_with_cases( "state", @@ -47,7 +49,7 @@ def test_negative_noise_variance_raises_error(): has_tag=["has_action", "has_observation", "solution_based"], ) def test_beliefs_against_naive_implementation( - belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + belief_update: belief_updates.solution_based.ProjectedResidualBeliefUpdate, state: LinearSolverState, ): """Compare the updated belief to a naive implementation.""" @@ -61,8 +63,7 @@ def test_beliefs_against_naive_implementation( noise_var = belief_update._noise_var action_A = action @ state.problem.A - pred = action_A @ belief.x.mean - proj_resid = observ - pred + proj_resid = observ cov_xy = belief.x.cov @ action_A.T gram = action_A @ cov_xy + noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py similarity index 82% rename from tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py rename to tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py index 6915030c6..e5101b7f2 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_projected_residual.py @@ -12,10 +12,12 @@ cases_states = case_modules + ".states" -@parametrize_with_cases("info_op", cases=cases_information_ops, glob="*projected_rhs") +@parametrize_with_cases( + "info_op", cases=cases_information_ops, glob="*projected_residual" +) @parametrize_with_cases("state", cases=cases_states, has_tag=["has_action"]) -def test_is_projected_rhs( +def test_is_projected_residual( info_op: information_ops.LinearSolverInformationOp, state: LinearSolverState ): observation = info_op(state) - 
np.testing.assert_equal(observation, state.action.T @ state.problem.b) + np.testing.assert_equal(observation, state.action.T @ state.residual) From c27f0e4331072c94a70893391d5cc696b8b19675 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 13:00:55 -0500 Subject: [PATCH 03/14] removed random number generator from state --- .../solvers/_probabilistic_linear_solver.py | 4 ++-- src/probnum/linalg/solvers/_state.py | 18 ++++++++++++------ .../solvers/policies/_conjugate_gradient.py | 8 +++++--- .../solvers/policies/_linear_solver_policy.py | 7 ++++++- .../solvers/policies/_random_unit_vector.py | 9 +++++++-- tests/test_linalg/test_solvers/cases/states.py | 18 +++++++----------- .../test_policies/test_linear_solver_policy.py | 15 +++++++-------- .../test_policies/test_random_unit_vector.py | 8 +++++--- 8 files changed, 51 insertions(+), 36 deletions(-) diff --git a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py index fa2f7275f..e6939abab 100644 --- a/src/probnum/linalg/solvers/_probabilistic_linear_solver.py +++ b/src/probnum/linalg/solvers/_probabilistic_linear_solver.py @@ -152,7 +152,7 @@ def solve_iterator( solver_state State of the probabilistic linear solver. 
""" - solver_state = LinearSolverState(problem=problem, prior=prior, rng=rng) + solver_state = LinearSolverState(problem=problem, prior=prior) while True: @@ -163,7 +163,7 @@ def solve_iterator( break # Compute action via policy - solver_state.action = self.policy(solver_state=solver_state) + solver_state.action = self.policy(solver_state=solver_state, rng=rng) # Make observation via information operator solver_state.observation = self.information_op(solver_state=solver_state) diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 7fa0d24e3..3db84d65d 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -24,21 +24,17 @@ class LinearSolverState: Linear system to be solved. prior Prior belief over the quantities of interest of the linear system. - rng - Random number generator. """ def __init__( self, problem: problems.LinearSystem, prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief", - rng: Optional[np.random.Generator] = None, ): - self.rng: Optional[np.random.Generator] = rng - self.problem: problems.LinearSystem = problem + self._problem: problems.LinearSystem = problem # Belief - self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior + self._prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior # Caches @@ -55,6 +51,16 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(step={self.step})" + @property + def problem(self) -> problems.LinearSystem: + """Linear system to be solved.""" + return self._problem + + @property + def prior(self) -> "probnum.linalg.solvers.beliefs.LinearSystemBelief": + """Prior belief over the quantities of interest of the linear system.""" + return self._prior + @property def step(self) -> int: """Current step of the solver.""" diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py 
b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 84f1b96ec..f88833f4a 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -43,7 +43,9 @@ def __init__( self._reorthogonalization_fn_action = reorthogonalization_fn_action def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: residual = solver_state.residual @@ -54,7 +56,7 @@ def __call__( if solver_state.step > 0: # Reorthogonalization of the residual if self._reorthogonalization_fn_residual is not None: - residual, prev_residual = self._reorthogonalized_residuals( + residual, prev_residual = self._reorthogonalized_residual( solver_state=solver_state ) else: @@ -76,7 +78,7 @@ def __call__( return action - def _reorthogonalized_residuals( + def _reorthogonalized_residual( self, solver_state: "probnum.linalg.solvers.LinearSolverState", ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py index a0bc6d2d2..43197b49b 100644 --- a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py +++ b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py @@ -1,5 +1,6 @@ """Base class for policies of probabilistic linear solvers returning actions.""" import abc +from typing import Optional import numpy as np @@ -22,7 +23,9 @@ class LinearSolverPolicy(abc.ABC): @abc.abstractmethod def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: """Return an action for a given solver state. @@ -30,6 +33,8 @@ def __call__( ---------- solver_state Current state of the linear solver. 
+ rng + Random number generator. Returns ------- diff --git a/src/probnum/linalg/solvers/policies/_random_unit_vector.py b/src/probnum/linalg/solvers/policies/_random_unit_vector.py index c8762584d..5e299bcbd 100644 --- a/src/probnum/linalg/solvers/policies/_random_unit_vector.py +++ b/src/probnum/linalg/solvers/policies/_random_unit_vector.py @@ -1,4 +1,7 @@ """Policy returning randomly drawn standard unit vectors.""" + +from typing import Optional + import numpy as np import probnum # pylint: disable="unused-import" @@ -29,7 +32,9 @@ def __init__(self, probabilities: str = "uniform", replace=True) -> None: self.replace = replace def __call__( - self, solver_state: "probnum.linalg.solvers.LinearSolverState" + self, + solver_state: "probnum.linalg.solvers.LinearSolverState", + rng: Optional[np.random.Generator] = None, ) -> np.ndarray: nrows = solver_state.problem.A.shape[0] @@ -47,7 +52,7 @@ def __call__( raise NotImplementedError # Sample unit vector - idx = solver_state.rng.choice( + idx = rng.choice( a=nrows, size=1, p=solver_state.cache["row_sample_probs"], diff --git a/tests/test_linalg/test_solvers/cases/states.py b/tests/test_linalg/test_solvers/cases/states.py index 8d059bf37..589b1b30c 100644 --- a/tests/test_linalg/test_solvers/cases/states.py +++ b/tests/test_linalg/test_solvers/cases/states.py @@ -26,11 +26,9 @@ @case(tags=["initial"]) -def case_initial_state( - rng: np.random.Generator, -): +def case_initial_state(): """Initial state of a linear solver.""" - return linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + return linalg.solvers.LinearSolverState(problem=linsys, prior=prior) @case(tags=["has_action"]) @@ -38,7 +36,7 @@ def case_state( rng: np.random.Generator, ): """State of a linear solver.""" - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) return state 
@@ -61,7 +59,7 @@ def case_state_matrix_based( ), b=b, ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) state.observation = rng.standard_normal(size=state.problem.A.shape[1]) @@ -85,7 +83,7 @@ def case_state_symmetric_matrix_based( ), b=b, ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) state.action = rng.standard_normal(size=state.problem.A.shape[1]) state.observation = rng.standard_normal(size=state.problem.A.shape[1]) @@ -97,9 +95,7 @@ def case_state_solution_based( rng: np.random.Generator, ): """State of a solution-based linear solver.""" - initial_state = linalg.solvers.LinearSolverState( - problem=linsys, prior=prior, rng=rng - ) + initial_state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior) initial_state.action = rng.standard_normal(size=initial_state.problem.A.shape[1]) initial_state.observation = rng.standard_normal() @@ -116,5 +112,5 @@ def case_state_converged( x=randvars.Constant(linsys.solution), b=randvars.Constant(linsys.b), ) - state = linalg.solvers.LinearSolverState(problem=linsys, prior=belief, rng=rng) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=belief) return state diff --git a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py index 4df954456..0d51acae2 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py +++ b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py @@ -16,7 +16,7 @@ @parametrize_with_cases("state", cases=cases_states) def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolverState): state = copy.deepcopy(state) - action = policy(state) + 
action = policy(state, rng=np.random.default_rng(1)) assert isinstance(action, np.ndarray) @@ -24,7 +24,7 @@ def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolve @parametrize_with_cases("state", cases=cases_states) def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): state = copy.deepcopy(state) - action = policy(state) + action = policy(state, rng=np.random.default_rng(1)) assert action.shape[0] == state.problem.A.shape[1] @@ -33,10 +33,9 @@ def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): def test_uses_solver_state_random_number_generator( policy: policies.LinearSolverPolicy, state: LinearSolverState ): - """Test whether randomized policies make use of the random number generator stored - in the linear solver state.""" - state = copy.deepcopy(state) - rng_state_pre = state.rng.bit_generator.state["state"]["state"] - _ = policy(state) - rng_state_post = state.rng.bit_generator.state["state"]["state"] + """Test whether randomized policies make use of the given random state.""" + rng = np.random.default_rng(1) + rng_state_pre = rng.bit_generator.state["state"]["state"] + _ = policy(state, rng=rng) + rng_state_post = rng.bit_generator.state["state"]["state"] assert rng_state_pre != rng_state_post diff --git a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py index e3f608d79..8201c2705 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py +++ b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pytest_cases import parametrize_with_cases +from pytest_cases import parametrize, parametrize_with_cases from probnum.linalg.solvers import LinearSolverState, policies @@ -14,10 +14,12 @@ @parametrize_with_cases("policy", cases=cases_policies, glob="*unit_vector*") 
@parametrize_with_cases("state", cases=cases_states) +@parametrize("seed", [1, 3, 42]) def test_returns_unit_vector( - policy: policies.LinearSolverPolicy, state: LinearSolverState + policy: policies.LinearSolverPolicy, state: LinearSolverState, seed: int ): - action = policy(state) + rng = np.random.default_rng(seed) + action = policy(state, rng=rng) assert np.linalg.norm(action) == pytest.approx(1.0) From 83590ce2fd147852ead901f7583d2bb0a177f749 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 13:40:52 -0500 Subject: [PATCH 04/14] turned state.cache into a defaultdict --- src/probnum/linalg/solvers/_state.py | 10 +++++----- .../solvers/policies/_conjugate_gradient.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 3db84d65d..6d02ff02c 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -1,7 +1,7 @@ """State of a probabilistic linear solver.""" - +from collections import defaultdict import dataclasses -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, DefaultDict, List, Optional, Tuple import numpy as np @@ -43,7 +43,7 @@ def __init__( self._residuals: List[np.ndarray] = [ self.problem.b - self.problem.A @ self.belief.x.mean, ] - self.cache: Dict[str, Any] = {} + self.cache: DefaultDict[str, Any] = defaultdict(list) # Solver info self._step: int = 0 @@ -116,7 +116,7 @@ def observations(self) -> Tuple[Any, ...]: @property def residual(self) -> np.ndarray: - r"""Cached residual :math:`b - Ax_i` for the current solution estimate :math:`x_i`.""" + r"""Cached residual :math:`r_{i-1} = b - Ax_{i-1}` for the current solution estimate :math:`x_{i-1}`.""" if self._residuals[self.step] is None: self._residuals[self.step] = ( self.problem.b - self.problem.A @ self.belief.x.mean @@ -125,7 +125,7 @@ def residual(self) -> np.ndarray: @property def residuals(self) -> 
Tuple[np.ndarray, ...]: - r"""Residuals :math:`\{b - Ax_i\}_i`.""" + r"""Residuals :math:`\{b - Ax_j\}_j^{i-1}`.""" return tuple(self._residuals) def next_step(self) -> None: diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index f88833f4a..4c521a61d 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -50,17 +50,20 @@ def __call__( residual = solver_state.residual - if self._reorthogonalization_fn_residual is not None and solver_state.step == 0: - solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual] + if solver_state.step == 0: + if self._reorthogonalization_fn_residual is not None: + solver_state.cache["reorthogonalized_residuals"].append( + solver_state.residual + ) - if solver_state.step > 0: + return residual + else: # Reorthogonalization of the residual if self._reorthogonalization_fn_residual is not None: residual, prev_residual = self._reorthogonalized_residual( solver_state=solver_state ) else: - residual = solver_state.residual prev_residual = solver_state.residuals[solver_state.step - 1] # A-conjugacy correction (in exact arithmetic) @@ -69,14 +72,11 @@ def __call__( # Reorthogonalization of the resulting action if self._reorthogonalization_fn_action is not None: - return self._reorthogonalized_action( + action = self._reorthogonalized_action( action=action, solver_state=solver_state ) - else: - action = residual - - return action + return action def _reorthogonalized_residual( self, From 384775d927d0186c4d15bcbde505ab5baa0ad47a Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 14:05:46 -0500 Subject: [PATCH 05/14] minor docs improvements --- src/probnum/linalg/solvers/_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 
6d02ff02c..e58d9d488 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -116,7 +116,7 @@ def observations(self) -> Tuple[Any, ...]: @property def residual(self) -> np.ndarray: - r"""Cached residual :math:`r_{i-1} = b - Ax_{i-1}` for the current solution estimate :math:`x_{i-1}`.""" + r"""Residual :math:`r_{i} = b - Ax_{i}`.""" if self._residuals[self.step] is None: self._residuals[self.step] = ( self.problem.b - self.problem.A @ self.belief.x.mean @@ -125,7 +125,7 @@ def residual(self) -> np.ndarray: @property def residuals(self) -> Tuple[np.ndarray, ...]: - r"""Residuals :math:`\{b - Ax_j\}_j^{i-1}`.""" + r"""Residuals :math:`\{b - Ax_i\}_i`.""" return tuple(self._residuals) def next_step(self) -> None: From f65366f833fd15abad564426f09a0847079f57fb Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 14:12:38 -0500 Subject: [PATCH 06/14] minor --- .../solution_based/_projected_residual_belief_update.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py index 3dcd67960..7fe8dc0bd 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py @@ -44,7 +44,7 @@ def __call__( # Compute gain and covariance update action_A = solver_state.action @ solver_state.problem.A cov_xy = solver_state.belief.x.cov @ action_A.T - gram = action_A @ cov_xy + self._noise_var + gram = action_A @ cov_xy + self.noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 gain = cov_xy * gram_pinv cov_update = np.outer(gain, cov_xy) @@ -61,3 +61,8 @@ def __call__( return LinearSystemBelief( x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b ) + + @property + def 
noise_var(self) -> float: + """Observation noise.""" + return self._noise_var From 0e512bb580e9c4a5e42dee0cbfe8f1840624aea5 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 15:03:05 -0500 Subject: [PATCH 07/14] raise type errors if linear system belief is constructed from wrong types --- .../solvers/beliefs/_linear_system_belief.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 8fa830032..391983fc6 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -100,6 +100,23 @@ def dim_mismatch_error(**kwargs): f"Belief over right-hand-side may have either one or two dimensions but has {b.ndim}." ) + if x is not None and not isinstance(randvars.RandomVariable): + raise TypeError( + f"The belief about the solution 'x' must be a RandomVariable, but is {type(x)}." + ) + if A is not None and not isinstance(randvars.RandomVariable): + raise TypeError( + f"The belief about the matrix 'A' must be a RandomVariable, but is {type(A)}." + ) + if Ainv is not None and not isinstance(randvars.RandomVariable): + raise TypeError( + f"The belief about the inverse matrix 'Ainv' must be a RandomVariable, but is {type(Ainv)}." + ) + if b is not None and not isinstance(randvars.RandomVariable): + raise TypeError( + f"The belief about the right-hand-side 'b' must be a RandomVariable, but is {type(b)}." 
+ ) + self._x = x self._A = A self._Ainv = Ainv From e9dd397d4d28b9f997ef3c7a2a83294ef970b8a4 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 15:10:42 -0500 Subject: [PATCH 08/14] type error check bugs fixed --- .../solvers/beliefs/_linear_system_belief.py | 8 ++++---- .../test_beliefs/test_linear_system_belief.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 391983fc6..6c670a54f 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -100,19 +100,19 @@ def dim_mismatch_error(**kwargs): f"Belief over right-hand-side may have either one or two dimensions but has {b.ndim}." ) - if x is not None and not isinstance(randvars.RandomVariable): + if x is not None and not isinstance(x, randvars.RandomVariable): raise TypeError( f"The belief about the solution 'x' must be a RandomVariable, but is {type(x)}." ) - if A is not None and not isinstance(randvars.RandomVariable): + if A is not None and not isinstance(A, randvars.RandomVariable): raise TypeError( f"The belief about the matrix 'A' must be a RandomVariable, but is {type(A)}." ) - if Ainv is not None and not isinstance(randvars.RandomVariable): + if Ainv is not None and not isinstance(Ainv, randvars.RandomVariable): raise TypeError( f"The belief about the inverse matrix 'Ainv' must be a RandomVariable, but is {type(Ainv)}." ) - if b is not None and not isinstance(randvars.RandomVariable): + if b is not None and not isinstance(b, randvars.RandomVariable): raise TypeError( f"The belief about the right-hand-side 'b' must be a RandomVariable, but is {type(b)}." 
) diff --git a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py index 7b3ffa417..9ee46d524 100644 --- a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py +++ b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py @@ -80,6 +80,25 @@ def test_non_two_dimensional_raises_value_error(): LinearSystemBelief(A=A, Ainv=Ainv, x=x, b=b[:, None]) +def test_non_randvar_arguments_raises_type_error(): + A = np.eye(5) + Ainv = np.eye(5) + x = np.ones((5, 1)) + b = np.ones((5, 1)) + + with pytest.raises(TypeError): + LinearSystemBelief(x=x) + + with pytest.raises(TypeError): + LinearSystemBelief(Ainv=Ainv) + + with pytest.raises(TypeError): + LinearSystemBelief(A=A) + + with pytest.raises(TypeError): + LinearSystemBelief(b=b) + + def test_induced_solution_belief(rng: np.random.Generator): """Test whether a consistent belief over the solution is inferred from a belief over the inverse.""" From efc223171de2ec3fa4c7cad8fc163709cf662e23 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 15:33:04 -0500 Subject: [PATCH 09/14] tests fixed and unused induced belief removed --- .../linalg/solvers/beliefs/_linear_system_belief.py | 13 +------------ .../test_beliefs/test_linear_system_belief.py | 4 ++-- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 6c670a54f..ba380ebf0 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -69,7 +69,7 @@ def dim_mismatch_error(**kwargs): if A.shape[1] != x.shape[0]: raise dim_mismatch_error(A=A, x=x) - if x.ndim > 1: + if x.ndim > 1 and b is not None: if x.shape[1] != b.shape[1]: raise dim_mismatch_error(x=x, b=b) elif b is not None: @@ -151,8 +151,6 @@ def A(self) -> 
randvars.RandomVariable: @property def Ainv(self) -> Optional[randvars.RandomVariable]: """Belief about the (pseudo-)inverse of the system matrix.""" - if self._Ainv is None: - return self._induced_Ainv() return self._Ainv @property @@ -168,12 +166,3 @@ def _induced_x(self) -> randvars.RandomVariable: :math:`H` and :math:`b`. """ return self.Ainv @ self.b - - def _induced_Ainv(self) -> randvars.RandomVariable: - r"""Induced belief about the inverse from a belief about the solution. - - Computes a consistent belief about the inverse from a belief about the solution. - """ - return randvars.Constant( - linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0])) - ) diff --git a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py index 9ee46d524..25ad93b3f 100644 --- a/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py +++ b/tests/test_linalg/test_solvers/test_beliefs/test_linear_system_belief.py @@ -93,10 +93,10 @@ def test_non_randvar_arguments_raises_type_error(): LinearSystemBelief(Ainv=Ainv) with pytest.raises(TypeError): - LinearSystemBelief(A=A) + LinearSystemBelief(x=randvars.Constant(x), A=A) with pytest.raises(TypeError): - LinearSystemBelief(b=b) + LinearSystemBelief(x=randvars.Constant(x), b=b) def test_induced_solution_belief(rng: np.random.Generator): From 5e5b436cc207ad5a5ca4c6fbbfcfff203f07bd24 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 15:35:33 -0500 Subject: [PATCH 10/14] minor bug in error message fixed --- src/probnum/linalg/solvers/beliefs/_linear_system_belief.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index ba380ebf0..707ef281f 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ 
b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -79,7 +79,7 @@ def dim_mismatch_error(**kwargs): if Ainv is not None: if Ainv.ndim != 2: raise ValueError( - f"Belief over the inverse system matrix may have at most two dimensions, but has {A.ndim}." + f"Belief over the inverse system matrix may have at most two dimensions, but has {Ainv.ndim}." ) if A is not None: if A.shape != Ainv.shape: From 8e2a2fc7195547d925407540291b7a689310eb69 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 18:10:10 -0500 Subject: [PATCH 11/14] pylint fix --- src/probnum/linalg/solvers/beliefs/_linear_system_belief.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index 707ef281f..d4dc14425 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -7,7 +7,7 @@ from functools import cached_property from typing import Mapping, Optional -from probnum import linops, randvars +from probnum import randvars # pylint: disable="invalid-name" From fcad0b14cebd12ce23ae0f43f71c5c67ec7cc986 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 18:18:49 -0500 Subject: [PATCH 12/14] minor --- .../linalg/solvers/information_ops/_projected_residual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/information_ops/_projected_residual.py b/src/probnum/linalg/solvers/information_ops/_projected_residual.py index ed566ce15..4d5e9a0d2 100644 --- a/src/probnum/linalg/solvers/information_ops/_projected_residual.py +++ b/src/probnum/linalg/solvers/information_ops/_projected_residual.py @@ -22,4 +22,4 @@ def __call__( solver_state : Current state of the linear solver. 
""" - return solver_state.action @ solver_state.residual + return solver_state.action.T @ solver_state.residual From 779f0dcba82fa2f873359534bf1afcc9fd1270fa Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 3 Mar 2022 18:28:27 -0500 Subject: [PATCH 13/14] minor --- .../solution_based/_projected_residual_belief_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py index 7fe8dc0bd..df63fe25f 100644 --- a/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_projected_residual_belief_update.py @@ -42,7 +42,7 @@ def __call__( proj_resid = solver_state.observation # Compute gain and covariance update - action_A = solver_state.action @ solver_state.problem.A + action_A = solver_state.action.T @ solver_state.problem.A cov_xy = solver_state.belief.x.cov @ action_A.T gram = action_A @ cov_xy + self.noise_var gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 From 85ff994676cc3f51880de082f0febca933498ab3 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Fri, 4 Mar 2022 09:21:52 -0500 Subject: [PATCH 14/14] minor --- src/probnum/linalg/solvers/beliefs/_linear_system_belief.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py index d4dc14425..22eae5491 100644 --- a/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py +++ b/src/probnum/linalg/solvers/beliefs/_linear_system_belief.py @@ -165,4 +165,4 @@ def _induced_x(self) -> randvars.RandomVariable: to) the random variable :math:`x=Hb`. This assumes independence between :math:`H` and :math:`b`. 
""" - return self.Ainv @ self.b + return randvars.asrandvar(self.Ainv @ self.b)