Skip to content

Commit

Permalink
ImmediateStop stopping criterion for quad (#658)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci authored Mar 14, 2022
1 parent 98bd6d2 commit 8488cd8
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/probnum/quad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from probnum.quad.solvers.policies import Policy, RandomPolicy
from probnum.quad.solvers.stopping_criteria import (
BQStoppingCriterion,
ImmediateStop,
IntegralVarianceTolerance,
MaxNevals,
RelativeMeanChange,
Expand All @@ -31,18 +32,21 @@
"bayesquad_from_data",
"BayesianQuadrature",
"IntegrationMeasure",
"ImmediateStop",
"KernelEmbedding",
"GaussianMeasure",
"LebesgueMeasure",
"BQStoppingCriterion",
"IntegralVarianceTolerance",
"MaxNevals",
"RandomPolicy",
"RelativeMeanChange",
]

# Set correct module paths. Corrects links and module paths in documentation.
BayesianQuadrature.__module__ = "probnum.quad"
BQStoppingCriterion.__module__ = "probnum.quad"
ImmediateStop.__module__ = "probnum.quad"
IntegrationMeasure.__module__ = "probnum.quad"
KernelEmbedding.__module__ = "probnum.quad"
GaussianMeasure.__module__ = "probnum.quad"
Expand Down
3 changes: 3 additions & 0 deletions src/probnum/quad/solvers/stopping_criteria/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
"""Stopping criteria for Bayesian quadrature methods."""

from ._bq_stopping_criterion import BQStoppingCriterion
from ._immediate_stop import ImmediateStop
from ._integral_variance_tol import IntegralVarianceTolerance
from ._max_nevals import MaxNevals
from ._rel_mean_change import RelativeMeanChange

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"BQStoppingCriterion",
"ImmediateStop",
"IntegralVarianceTolerance",
"MaxNevals",
"RelativeMeanChange",
]

# Set correct module paths. Corrects links and module paths in documentation.
BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria"
ImmediateStop.__module__ = "probnum.quad.solvers.stopping_criteria"
IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria"
MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria"
RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria"
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ class BQStoppingCriterion(StoppingCriterion):
IntegralVarianceTolerance : Stop based on the variance of the integral estimator.
RelativeMeanChange : Stop based on the absolute value of the integral variance.
MaxNevals : Stop based on a maximum number of iterations.
ImmediateStop : Dummy stopping criterion that always stops.
"""

def __call__(self, bq_state: BQState) -> bool:
"""Check whether tracked quantities meet a desired terminal condition.
Parameters
----------
bq_state:
State of the BQ loop.
bq_state
State of the BQ belief.
Returns
-------
stopping_decision :
Whether the stopping condition is met.
"""
raise NotImplementedError
15 changes: 15 additions & 0 deletions src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Stopping criterion that stops immediately."""

from probnum.quad.solvers.bq_state import BQState
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion

# pylint: disable=too-few-public-methods


class ImmediateStop(BQStoppingCriterion):
"""Dummy stopping criterion that always stops. This is useful for fixed datasets
when no policy or acquisition loop is required or given.
"""

def __call__(self, bq_state: BQState) -> bool:
return True
86 changes: 86 additions & 0 deletions tests/test_quad/test_stopping_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for BQ stopping criteria."""

import numpy as np
import pytest

from probnum.quad import (
BQStoppingCriterion,
ImmediateStop,
IntegralVarianceTolerance,
LebesgueMeasure,
MaxNevals,
RelativeMeanChange,
)
from probnum.quad.solvers.bq_state import BQState
from probnum.randprocs.kernels import ExpQuad
from probnum.randvars import Normal

_nevals = 5
_rel_tol = 1e-5
_var_tol = 1e-5


@pytest.fixture()
def input_dim():
return 2


@pytest.fixture(
params=[
pytest.param(sc, id=sc[0].__name__)
for sc in [
(MaxNevals, {"max_nevals": _nevals}),
(IntegralVarianceTolerance, {"var_tol": _var_tol}),
(RelativeMeanChange, {"rel_tol": _rel_tol}),
]
],
name="stopping_criterion",
)
def fixture_stopping_criterion(request) -> BQStoppingCriterion:
"""BQ stopping criterion."""
return request.param[0](**request.param[1])


@pytest.fixture()
def bq_state_stops(input_dim) -> BQState:
"""BQ state that triggers stopping in all stopping criteria."""
integral_mean = 1.0
integral_mean_previous = integral_mean * (1 - _rel_tol)
return BQState(
measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)),
kernel=ExpQuad(input_shape=(input_dim,)),
integral_belief=Normal(integral_mean, 0.1 * _var_tol),
previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),),
nodes=np.ones((_nevals, input_dim)),
fun_evals=np.ones(_nevals),
)


@pytest.fixture()
def bq_state_does_not_stop(input_dim) -> BQState:
"""BQ state that does not trigger stopping in all stopping criteria."""
integral_mean = 1.0
integral_mean_previous = 2 * integral_mean * (1 - _rel_tol)
nevals = _nevals - 2
return BQState(
measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)),
kernel=ExpQuad(input_shape=(input_dim,)),
integral_belief=Normal(integral_mean, 10 * _var_tol),
previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),),
nodes=np.ones((nevals, input_dim)),
fun_evals=np.ones(nevals),
)


def test_immediate_stop_values(bq_state_stops, bq_state_does_not_stop):
# Immediate stop shall always stop
sc = ImmediateStop()
assert sc(bq_state_stops)
assert sc(bq_state_does_not_stop)


def test_stopping_criterion_values(
stopping_criterion, bq_state_stops, bq_state_does_not_stop
):
assert stopping_criterion(bq_state_stops)
assert not stopping_criterion(bq_state_does_not_stop)

0 comments on commit 8488cd8

Please sign in to comment.