Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to the probabilistic linear solver implementation #656

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/probnum/linalg/solvers/_probabilistic_linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ class ProbabilisticLinearSolver(

>>> pls = ProbabilisticLinearSolver(
... policy=policies.ConjugateGradientPolicy(),
... information_op=information_ops.ProjectedRHSInformationOp(),
... belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(),
... information_op=information_ops.ProjectedResidualInformationOp(),
... belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(),
... stopping_criterion=(
... stopping_criteria.MaxIterationsStoppingCriterion(100)
... | stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5)
Expand Down Expand Up @@ -152,7 +152,7 @@ def solve_iterator(
solver_state
State of the probabilistic linear solver.
"""
solver_state = LinearSolverState(problem=problem, prior=prior, rng=rng)
solver_state = LinearSolverState(problem=problem, prior=prior)

while True:

Expand All @@ -163,7 +163,7 @@ def solve_iterator(
break

# Compute action via policy
solver_state.action = self.policy(solver_state=solver_state)
solver_state.action = self.policy(solver_state=solver_state, rng=rng)

# Make observation via information operator
solver_state.observation = self.information_op(solver_state=solver_state)
Expand Down Expand Up @@ -234,8 +234,8 @@ def __init__(
):
super().__init__(
policy=policies.ConjugateGradientPolicy(),
information_op=information_ops.ProjectedRHSInformationOp(),
belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(),
information_op=information_ops.ProjectedResidualInformationOp(),
belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(),
stopping_criterion=stopping_criterion,
)

Expand Down Expand Up @@ -267,8 +267,8 @@ def __init__(
):
super().__init__(
policy=policies.RandomUnitVectorPolicy(),
information_op=information_ops.ProjectedRHSInformationOp(),
belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(),
information_op=information_ops.ProjectedResidualInformationOp(),
belief_update=belief_updates.solution_based.ProjectedResidualBeliefUpdate(),
stopping_criterion=stopping_criterion,
)

Expand Down
26 changes: 16 additions & 10 deletions src/probnum/linalg/solvers/_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""State of a probabilistic linear solver."""

from collections import defaultdict
import dataclasses
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, DefaultDict, List, Optional, Tuple

import numpy as np

Expand All @@ -24,21 +24,17 @@ class LinearSolverState:
Linear system to be solved.
prior
Prior belief over the quantities of interest of the linear system.
rng
Random number generator.
"""

def __init__(
self,
problem: problems.LinearSystem,
prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief",
rng: Optional[np.random.Generator] = None,
):
self.rng: Optional[np.random.Generator] = rng
self.problem: problems.LinearSystem = problem
self._problem: problems.LinearSystem = problem

# Belief
self.prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior
self._prior: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior
self._belief: "probnum.linalg.solvers.beliefs.LinearSystemBelief" = prior

# Caches
Expand All @@ -47,14 +43,24 @@ def __init__(
self._residuals: List[np.ndarray] = [
self.problem.b - self.problem.A @ self.belief.x.mean,
]
self.cache: Dict[str, Any] = {}
self.cache: DefaultDict[str, Any] = defaultdict(list)

# Solver info
self._step: int = 0

def __repr__(self) -> str:
return f"{self.__class__.__name__}(step={self.step})"

@property
def problem(self) -> problems.LinearSystem:
"""Linear system to be solved."""
return self._problem

@property
def prior(self) -> "probnum.linalg.solvers.beliefs.LinearSystemBelief":
"""Prior belief over the quantities of interest of the linear system."""
return self._prior

@property
def step(self) -> int:
"""Current step of the solver."""
Expand Down Expand Up @@ -110,7 +116,7 @@ def observations(self) -> Tuple[Any, ...]:

@property
def residual(self) -> np.ndarray:
r"""Cached residual :math:`b - Ax_i` for the current solution estimate :math:`x_i`."""
r"""Residual :math:`r_{i} = b - Ax_{i}`."""
if self._residuals[self.step] is None:
self._residuals[self.step] = (
self.problem.b - self.problem.A @ self.belief.x.mean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Solution-based belief updates for the quantities of interest of a linear system."""

from ._solution_based_proj_rhs_belief_update import (
SolutionBasedProjectedRHSBeliefUpdate,
)
from ._projected_residual_belief_update import ProjectedResidualBeliefUpdate

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"SolutionBasedProjectedRHSBeliefUpdate",
"ProjectedResidualBeliefUpdate",
]

# Set correct module paths. Corrects links and module paths in documentation.
SolutionBasedProjectedRHSBeliefUpdate.__module__ = (
ProjectedResidualBeliefUpdate.__module__ = (
"probnum.linalg.solvers.belief_updates.solution_based"
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from .._linear_solver_belief_update import LinearSolverBeliefUpdate


class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate):
r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information.
class ProjectedResidualBeliefUpdate(LinearSolverBeliefUpdate):
r"""Gaussian belief update given projected residual information.

Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s^\top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`s^\top r_i = s^\top (b - Ax_i) = s^\top A (x - x_i)`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, such that

.. math ::
\begin{align}
x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\
x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top r_i,\\
\Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top A \Sigma_i,
\end{align}

Expand All @@ -28,11 +28,6 @@ class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate):
----------
noise_var :
Variance of the scalar observation noise.

References
----------
.. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian
Analysis*, 2019, 14, 937-1012
"""

def __init__(self, noise_var: FloatLike = 0.0) -> None:
Expand All @@ -44,14 +39,12 @@ def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

# Compute projected residual
action_A = solver_state.action @ solver_state.problem.A
pred = action_A @ solver_state.belief.x.mean
proj_resid = solver_state.observation - pred
proj_resid = solver_state.observation

# Compute gain and covariance update
action_A = solver_state.action.T @ solver_state.problem.A
cov_xy = solver_state.belief.x.cov @ action_A.T
gram = action_A @ cov_xy + self._noise_var
gram = action_A @ cov_xy + self.noise_var
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
gain = cov_xy * gram_pinv
cov_update = np.outer(gain, cov_xy)
Expand All @@ -68,3 +61,8 @@ def __call__(
return LinearSystemBelief(
x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b
)

@property
def noise_var(self) -> float:
"""Observation noise."""
return self._noise_var
36 changes: 21 additions & 15 deletions src/probnum/linalg/solvers/beliefs/_linear_system_belief.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from typing import Mapping, Optional

from probnum import linops, randvars
from probnum import randvars

# pylint: disable="invalid-name"

Expand Down Expand Up @@ -69,7 +69,7 @@ def dim_mismatch_error(**kwargs):
if A.shape[1] != x.shape[0]:
raise dim_mismatch_error(A=A, x=x)

if x.ndim > 1:
if x.ndim > 1 and b is not None:
if x.shape[1] != b.shape[1]:
raise dim_mismatch_error(x=x, b=b)
elif b is not None:
Expand All @@ -79,7 +79,7 @@ def dim_mismatch_error(**kwargs):
if Ainv is not None:
if Ainv.ndim != 2:
raise ValueError(
f"Belief over the inverse system matrix may have at most two dimensions, but has {A.ndim}."
f"Belief over the inverse system matrix may have at most two dimensions, but has {Ainv.ndim}."
)
if A is not None:
if A.shape != Ainv.shape:
Expand All @@ -100,6 +100,23 @@ def dim_mismatch_error(**kwargs):
f"Belief over right-hand-side may have either one or two dimensions but has {b.ndim}."
)

if x is not None and not isinstance(x, randvars.RandomVariable):
raise TypeError(
f"The belief about the solution 'x' must be a RandomVariable, but is {type(x)}."
)
if A is not None and not isinstance(A, randvars.RandomVariable):
raise TypeError(
f"The belief about the matrix 'A' must be a RandomVariable, but is {type(A)}."
)
if Ainv is not None and not isinstance(Ainv, randvars.RandomVariable):
raise TypeError(
f"The belief about the inverse matrix 'Ainv' must be a RandomVariable, but is {type(Ainv)}."
)
if b is not None and not isinstance(b, randvars.RandomVariable):
raise TypeError(
f"The belief about the right-hand-side 'b' must be a RandomVariable, but is {type(b)}."
)

self._x = x
self._A = A
self._Ainv = Ainv
Expand Down Expand Up @@ -134,8 +151,6 @@ def A(self) -> randvars.RandomVariable:
@property
def Ainv(self) -> Optional[randvars.RandomVariable]:
"""Belief about the (pseudo-)inverse of the system matrix."""
if self._Ainv is None:
return self._induced_Ainv()
return self._Ainv

@property
Expand All @@ -150,13 +165,4 @@ def _induced_x(self) -> randvars.RandomVariable:
to) the random variable :math:`x=Hb`. This assumes independence between
:math:`H` and :math:`b`.
"""
return self.Ainv @ self.b

def _induced_Ainv(self) -> randvars.RandomVariable:
r"""Induced belief about the inverse from a belief about the solution.

Computes a consistent belief about the inverse from a belief about the solution.
"""
return randvars.Constant(
linops.Scaling(factors=0.0, shape=(self._x.shape[0], self._x.shape[0]))
)
return randvars.asrandvar(self.Ainv @ self.b)
6 changes: 3 additions & 3 deletions src/probnum/linalg/solvers/information_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

from ._linear_solver_information_op import LinearSolverInformationOp
from ._matvec import MatVecInformationOp
from ._projected_rhs import ProjectedRHSInformationOp
from ._projected_residual import ProjectedResidualInformationOp

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"LinearSolverInformationOp",
"MatVecInformationOp",
"ProjectedRHSInformationOp",
"ProjectedResidualInformationOp",
]

# Set correct module paths. Corrects links and module paths in documentation.
LinearSolverInformationOp.__module__ = "probnum.linalg.solvers.information_ops"
MatVecInformationOp.__module__ = "probnum.linalg.solvers.information_ops"
ProjectedRHSInformationOp.__module__ = "probnum.linalg.solvers.information_ops"
ProjectedResidualInformationOp.__module__ = "probnum.linalg.solvers.information_ops"
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LinearSolverInformationOp(abc.ABC):
See Also
--------
MatVecInformationOp: Collect information via matrix-vector multiplication.
ProjectedRHSInformationOp: Collect information via a projection of the current residual.
ProjectedResidualInformationOp: Collect information via a projection of the current residual.
"""

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@
from ._linear_solver_information_op import LinearSolverInformationOp


class ProjectedRHSInformationOp(LinearSolverInformationOp):
r"""Projected right hand side :math:`s_i \mapsto b^\top s_i = (Ax)^\top s_i` of the linear system.
class ProjectedResidualInformationOp(LinearSolverInformationOp):
r"""Projected residual information operator.

Obtain information about a linear system by projecting the right hand side :math:`b=Ax` onto a given action :math:`s_i` resulting in :math:`y_i = s_i^\top b`.
Obtain information about a linear system by projecting the residual :math:`b-Ax_{i-1}` onto a given action :math:`s_i` resulting in :math:`s_i \mapsto s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1}) = s_i^\top A (x - x_{i-1})`.
"""

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> np.ndarray:
r"""Projected right hand side :math:`s_i^\top b = s_i^\top Ax` of the linear system.
r"""Projected residual :math:`s_i^\top r_{i-1} = s_i^\top (b - A x_{i-1})` of the linear system.

Parameters
----------
solver_state :
Current state of the linear solver.
"""
return solver_state.action @ solver_state.problem.b
return solver_state.action.T @ solver_state.residual
26 changes: 14 additions & 12 deletions src/probnum/linalg/solvers/policies/_conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,27 @@ def __init__(
self._reorthogonalization_fn_action = reorthogonalization_fn_action

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
self,
solver_state: "probnum.linalg.solvers.LinearSolverState",
rng: Optional[np.random.Generator] = None,
) -> np.ndarray:

residual = solver_state.residual

if self._reorthogonalization_fn_residual is not None and solver_state.step == 0:
solver_state.cache["reorthogonalized_residuals"] = [solver_state.residual]
if solver_state.step == 0:
if self._reorthogonalization_fn_residual is not None:
solver_state.cache["reorthogonalized_residuals"].append(
solver_state.residual
)

if solver_state.step > 0:
return residual
else:
# Reorthogonalization of the residual
if self._reorthogonalization_fn_residual is not None:
residual, prev_residual = self._reorthogonalized_residuals(
residual, prev_residual = self._reorthogonalized_residual(
solver_state=solver_state
)
else:
residual = solver_state.residual
prev_residual = solver_state.residuals[solver_state.step - 1]

# A-conjugacy correction (in exact arithmetic)
Expand All @@ -67,16 +72,13 @@ def __call__(

# Reorthogonalization of the resulting action
if self._reorthogonalization_fn_action is not None:
return self._reorthogonalized_action(
action = self._reorthogonalized_action(
action=action, solver_state=solver_state
)

else:
action = residual

return action
return action

def _reorthogonalized_residuals(
def _reorthogonalized_residual(
self,
solver_state: "probnum.linalg.solvers.LinearSolverState",
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
Loading