Skip to content

Commit

Permalink
BQ stopping criteria use base class (#570)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci authored Nov 22, 2021
1 parent db8921e commit ac10e59
Show file tree
Hide file tree
Showing 16 changed files with 192 additions and 143 deletions.
34 changes: 19 additions & 15 deletions src/probnum/quad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@
choosing points to evaluate the integrand based on said model.
"""

from probnum.quad.solvers.policies import Policy, RandomPolicy
from probnum.quad.solvers.stopping_criteria import (
BQStoppingCriterion,
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
)

from ._bayesquad import bayesquad, bayesquad_from_data
from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure
from .bq_methods import (
BayesianQuadrature,
BQBeliefUpdate,
BQInfo,
BQStandardBeliefUpdate,
BQState,
)
from .kernel_embeddings import (
KernelEmbedding,
_kernel_mean_expquad_gauss,
_kernel_mean_expquad_lebesgue,
_kernel_variance_expquad_gauss,
_kernel_variance_expquad_lebesgue,
)
from .policies import Policy, RandomPolicy
from .stop_criteria import (
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
StoppingCriterion,
from .solvers import (
BayesianQuadrature,
BQBeliefUpdate,
BQInfo,
BQStandardBeliefUpdate,
BQState,
)

# Public classes and functions. Order is reflected in documentation.
Expand All @@ -39,13 +40,16 @@
"KernelEmbedding",
"GaussianMeasure",
"LebesgueMeasure",
"StoppingCriterion",
"BQStoppingCriterion",
"IntegralVarianceTolerance",
"MaxNevals",
"RelativeMeanChange",
]

# Set correct module paths. Corrects links and module paths in documentation.
BayesianQuadrature.__module__ = "probnum.quad"
BQStoppingCriterion.__module__ = "probnum.quad"
IntegrationMeasure.__module__ = "probnum.quad"
KernelEmbedding.__module__ = "probnum.quad"
GaussianMeasure.__module__ = "probnum.quad"
LebesgueMeasure.__module__ = "probnum.quad"
StoppingCriterion.__module__ = "probnum.quad"
2 changes: 1 addition & 1 deletion src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from probnum.typing import FloatArgType, IntArgType

from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure
from .bq_methods import BayesianQuadrature
from .solvers import BayesianQuadrature


# pylint: disable=too-many-arguments, no-else-raise
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
"""Probabilistic numerical methods for solving integrals."""

from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import numpy as np

from probnum.quad.solvers.policies import Policy, RandomPolicy
from probnum.quad.solvers.stopping_criteria import (
BQStoppingCriterion,
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
)
from probnum.randprocs.kernels import ExpQuad, Kernel
from probnum.randvars import Normal
from probnum.typing import FloatArgType, IntArgType

from .._integration_measures import IntegrationMeasure, LebesgueMeasure
from ..kernel_embeddings import KernelEmbedding
from ..policies import Policy, RandomPolicy
from ..stop_criteria import (
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
StoppingCriterion,
)
from .belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate
from .bq_state import BQState

Expand All @@ -38,8 +38,8 @@ class BayesianQuadrature:
The policy choosing nodes at which to evaluate the integrand.
belief_update :
The inference method.
stopping_criteria :
List of criteria that determine convergence.
stopping_criterion :
The criterion that determines convergence.
"""
# pylint: disable=too-many-arguments

Expand All @@ -49,13 +49,13 @@ def __init__(
measure: IntegrationMeasure,
policy: Policy,
belief_update: BQBeliefUpdate,
stopping_criteria: List[StoppingCriterion],
stopping_criterion: BQStoppingCriterion,
) -> None:
self.kernel = kernel
self.measure = measure
self.policy = policy
self.belief_update = belief_update
self.stopping_criteria = stopping_criteria
self.stopping_criterion = stopping_criterion

@classmethod
def from_problem(
Expand Down Expand Up @@ -127,25 +127,38 @@ def from_problem(

# Set stopping criteria
# If multiple stopping criteria are given, BQ stops once the first criterion is fulfilled.
_stopping_criteria = []
def _stopcrit_or(sc1, sc2):
if sc1 is None:
return sc2
return sc1 | sc2

_stopping_criterion = None

if max_evals is not None:
_stopping_criteria.append(MaxNevals(max_evals))
_stopping_criterion = _stopcrit_or(
_stopping_criterion, MaxNevals(max_evals)
)
if var_tol is not None:
_stopping_criteria.append(IntegralVarianceTolerance(var_tol))
_stopping_criterion = _stopcrit_or(
_stopping_criterion, IntegralVarianceTolerance(var_tol)
)
if rel_tol is not None:
_stopping_criteria.append(RelativeMeanChange(rel_tol))
_stopping_criterion = _stopcrit_or(
_stopping_criterion, RelativeMeanChange(rel_tol)
)

# If no stopping criteria are given, use some default values (these are arbitrary values)
if not _stopping_criteria:
_stopping_criteria.append(IntegralVarianceTolerance(var_tol=1e-6))
_stopping_criteria.append(MaxNevals(max_evals=input_dim * 25))
if _stopping_criterion is None:
_stopping_criterion = IntegralVarianceTolerance(var_tol=1e-6) | MaxNevals(
max_nevals=input_dim * 25
)

return cls(
kernel=kernel,
measure=measure,
policy=policy,
belief_update=belief_update,
stopping_criteria=_stopping_criteria,
stopping_criterion=_stopping_criterion,
)

def has_converged(self, bq_state: BQState) -> bool:
Expand All @@ -163,12 +176,10 @@ def has_converged(self, bq_state: BQState) -> bool:
Whether or not the solver has converged.
"""

for stopping_criterion in self.stopping_criteria:
_has_converged = stopping_criterion(bq_state.integral_belief, bq_state)
if _has_converged:
bq_state.info.has_converged = True
bq_state.info.stopping_criterion = stopping_criterion.__class__.__name__
return True
_has_converged = self.stopping_criterion(bq_state)
if _has_converged:
bq_state.info.has_converged = True
return True
return False

def bq_iterator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from scipy.linalg import cho_factor, cho_solve

from probnum.quad.bq_methods.bq_state import BQState
from probnum.quad.solvers.bq_state import BQState
from probnum.randvars import Normal

# pylint: disable=too-few-public-methods, too-many-locals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,17 @@ class BQInfo:
Number of evaluations collected.
has_converged :
True if the BQ loop fulfils a stopping criterion, otherwise False.
stopping_criterion:
The stopping criterion used to determine convergence.
"""

def __init__(
self,
iteration: int = 0,
nevals: int = 0,
has_converged: bool = False,
stopping_criterion: "probnum.quad.StoppingCriterion" = None,
):
self.iteration = iteration
self.nevals = nevals
self.has_converged = has_converged
self.stopping_criterion = stopping_criterion

def update_iteration(self, batch_size: int) -> None:
"""Update the quantities tracking iteration info.
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from probnum.quad.bq_methods.bq_state import BQState
from probnum.quad.solvers.bq_state import BQState

# pylint: disable=too-few-public-methods, fixme

Expand Down
20 changes: 20 additions & 0 deletions src/probnum/quad/solvers/stopping_criteria/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Stopping criteria for Bayesian quadrature methods."""

from ._bq_stopping_criterion import BQStoppingCriterion
from ._integral_variance_tol import IntegralVarianceTolerance
from ._max_nevals import MaxNevals
from ._rel_mean_change import RelativeMeanChange

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"BQStoppingCriterion",
"IntegralVarianceTolerance",
"MaxNevals",
"RelativeMeanChange",
]

# Set correct module paths. Corrects links and module paths in documentation.
BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria"
IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria"
MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria"
RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Base class for Bayesian quadrature stopping criteria."""

from probnum import StoppingCriterion
from probnum.quad.solvers.bq_state import BQState

# pylint: disable=too-few-public-methods, fixme
# pylint: disable=arguments-differ


class BQStoppingCriterion(StoppingCriterion):
r"""Stopping criterion of a Bayesian quadrature method.
Checks whether quantities tracked by the :class:`~probnum.quad.solvers.BQState` meet a desired terminal condition.
See Also
--------
IntegralVarianceTolerance : Stop based on the variance of the integral estimator.
RelativeMeanChange : Stop based on the absolute value of the integral variance.
MaxNevals : Stop based on a maximum number of iterations.
"""

def __call__(self, bq_state: BQState) -> bool:
"""Check whether tracked quantities meet a desired terminal condition.
Parameters
----------
bq_state:
State of the BQ loop.
"""
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Stopping criterion based on the absolute value of the integral variance"""

from probnum.quad.solvers.bq_state import BQState
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion
from probnum.typing import FloatArgType

# pylint: disable=too-few-public-methods, fixme


class IntegralVarianceTolerance(BQStoppingCriterion):
"""Stop once the integral variance is below some tolerance.
Parameters
----------
var_tol:
Tolerance value of the variance.
"""

def __init__(self, var_tol: FloatArgType):
self.var_tol = var_tol

def __call__(self, bq_state: BQState) -> bool:
return bq_state.integral_belief.var <= self.var_tol
23 changes: 23 additions & 0 deletions src/probnum/quad/solvers/stopping_criteria/_max_nevals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Stopping criterion based on a maximum number of integrand evaluations."""

from probnum.quad.solvers.bq_state import BQState
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion
from probnum.typing import IntArgType

# pylint: disable=too-few-public-methods


class MaxNevals(BQStoppingCriterion):
"""Stop once a maximum number of integrand evaluations is reached.
Parameters
----------
max_nevals:
Maximum number of integrand evaluations.
"""

def __init__(self, max_nevals: IntArgType):
self.max_nevals = max_nevals

def __call__(self, bq_state: BQState) -> bool:
return bq_state.info.nevals >= self.max_nevals
37 changes: 37 additions & 0 deletions src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Stopping criterion based on the relative change of the successive integral estimators."""

import numpy as np

from probnum.quad.solvers.bq_state import BQState
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion
from probnum.typing import FloatArgType

# pylint: disable=too-few-public-methods


class RelativeMeanChange(BQStoppingCriterion):
"""Stop once the relative change of consecutive integral estimates are smaller than
a tolerance.
The stopping criterion is: :math:`|\\hat{F}_{c} - \\hat{F}_{p}|/ |\\hat{F}_{c}| \\leq r`
where :math:`\\hat{F}_{c}` and :math:`\\hat{F}_{p}` are the integral estimates of the current and previous iteration
respectively, and :math:`r` is the relative tolerance.
Parameters
----------
rel_tol:
Relative error tolerance on consecutive integral mean values.
"""

def __init__(self, rel_tol: FloatArgType):
self.rel_tol = rel_tol

def __call__(self, bq_state: BQState) -> bool:
integral_belief = bq_state.integral_belief
return (
np.abs(
(integral_belief.mean - bq_state.previous_integral_beliefs[-1].mean)
/ integral_belief.mean
)
<= self.rel_tol
)
6 changes: 0 additions & 6 deletions src/probnum/quad/stop_criteria/__init__.py

This file was deleted.

Loading

0 comments on commit ac10e59

Please sign in to comment.