diff --git a/src/probnum/quad/__init__.py b/src/probnum/quad/__init__.py index 603e86462..ca81148a0 100644 --- a/src/probnum/quad/__init__.py +++ b/src/probnum/quad/__init__.py @@ -9,6 +9,7 @@ from probnum.quad.solvers.policies import Policy, RandomPolicy from probnum.quad.solvers.stopping_criteria import ( BQStoppingCriterion, + ImmediateStop, IntegralVarianceTolerance, MaxNevals, RelativeMeanChange, @@ -31,18 +32,21 @@ "bayesquad_from_data", "BayesianQuadrature", "IntegrationMeasure", + "ImmediateStop", "KernelEmbedding", "GaussianMeasure", "LebesgueMeasure", "BQStoppingCriterion", "IntegralVarianceTolerance", "MaxNevals", + "RandomPolicy", "RelativeMeanChange", ] # Set correct module paths. Corrects links and module paths in documentation. BayesianQuadrature.__module__ = "probnum.quad" BQStoppingCriterion.__module__ = "probnum.quad" +ImmediateStop.__module__ = "probnum.quad" IntegrationMeasure.__module__ = "probnum.quad" KernelEmbedding.__module__ = "probnum.quad" GaussianMeasure.__module__ = "probnum.quad" diff --git a/src/probnum/quad/solvers/stopping_criteria/__init__.py b/src/probnum/quad/solvers/stopping_criteria/__init__.py index 6e30c9de3..a8e8682dd 100644 --- a/src/probnum/quad/solvers/stopping_criteria/__init__.py +++ b/src/probnum/quad/solvers/stopping_criteria/__init__.py @@ -1,6 +1,7 @@ """Stopping criteria for Bayesian quadrature methods.""" from ._bq_stopping_criterion import BQStoppingCriterion +from ._immediate_stop import ImmediateStop from ._integral_variance_tol import IntegralVarianceTolerance from ._max_nevals import MaxNevals from ._rel_mean_change import RelativeMeanChange @@ -8,6 +9,7 @@ # Public classes and functions. Order is reflected in documentation. __all__ = [ "BQStoppingCriterion", + "ImmediateStop", "IntegralVarianceTolerance", "MaxNevals", "RelativeMeanChange", @@ -15,6 +17,7 @@ # Set correct module paths. Corrects links and module paths in documentation. BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria" +ImmediateStop.__module__ = "probnum.quad.solvers.stopping_criteria" IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria" MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria" RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria" diff --git a/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py index 57c6e65c6..a79ea91b0 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py +++ b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py @@ -18,6 +18,7 @@ class BQStoppingCriterion(StoppingCriterion): IntegralVarianceTolerance : Stop based on the variance of the integral estimator. RelativeMeanChange : Stop based on the absolute value of the integral variance. MaxNevals : Stop based on a maximum number of iterations. + ImmediateStop : Dummy stopping criterion that always stops. """ def __call__(self, bq_state: BQState) -> bool: @@ -25,7 +26,12 @@ def __call__(self, bq_state: BQState) -> bool: Parameters ---------- - bq_state: - State of the BQ loop. + bq_state + State of the BQ belief. + + Returns + ------- + stopping_decision : + Whether the stopping condition is met. """ raise NotImplementedError diff --git a/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py b/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py new file mode 100644 index 000000000..71734d9bd --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py @@ -0,0 +1,15 @@ +"""Stopping criterion that stops immediately.""" + +from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion + +# pylint: disable=too-few-public-methods + + +class ImmediateStop(BQStoppingCriterion): + """Dummy stopping criterion that always stops. This is useful for fixed datasets + when no policy or acquisition loop is required or given. + """ + + def __call__(self, bq_state: BQState) -> bool: + return True diff --git a/tests/test_quad/test_stopping_criterion.py b/tests/test_quad/test_stopping_criterion.py new file mode 100644 index 000000000..523b80816 --- /dev/null +++ b/tests/test_quad/test_stopping_criterion.py @@ -0,0 +1,86 @@ +"""Tests for BQ stopping criteria.""" + +import numpy as np +import pytest + +from probnum.quad import ( + BQStoppingCriterion, + ImmediateStop, + IntegralVarianceTolerance, + LebesgueMeasure, + MaxNevals, + RelativeMeanChange, +) +from probnum.quad.solvers.bq_state import BQState +from probnum.randprocs.kernels import ExpQuad +from probnum.randvars import Normal + +_nevals = 5 +_rel_tol = 1e-5 +_var_tol = 1e-5 + + +@pytest.fixture() +def input_dim(): + return 2 + + +@pytest.fixture( + params=[ + pytest.param(sc, id=sc[0].__name__) + for sc in [ + (MaxNevals, {"max_nevals": _nevals}), + (IntegralVarianceTolerance, {"var_tol": _var_tol}), + (RelativeMeanChange, {"rel_tol": _rel_tol}), + ] + ], + name="stopping_criterion", +) +def fixture_stopping_criterion(request) -> BQStoppingCriterion: + """BQ stopping criterion.""" + return request.param[0](**request.param[1]) + + +@pytest.fixture() +def bq_state_stops(input_dim) -> BQState: + """BQ state that triggers stopping in all stopping criteria.""" + integral_mean = 1.0 + integral_mean_previous = integral_mean * (1 - _rel_tol) + return BQState( + measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(input_dim,)), + integral_belief=Normal(integral_mean, 0.1 * _var_tol), + previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),), + nodes=np.ones((_nevals, input_dim)), + fun_evals=np.ones(_nevals), + ) + + +@pytest.fixture() +def bq_state_does_not_stop(input_dim) -> BQState: + """BQ state that does not trigger stopping in all stopping criteria.""" + integral_mean = 1.0 + integral_mean_previous = 2 * integral_mean * (1 - _rel_tol) + nevals = _nevals - 2 + return BQState( + measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(input_dim,)), + integral_belief=Normal(integral_mean, 10 * _var_tol), + previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),), + nodes=np.ones((nevals, input_dim)), + fun_evals=np.ones(nevals), + ) + + +def test_immediate_stop_values(bq_state_stops, bq_state_does_not_stop): + # Immediate stop shall always stop + sc = ImmediateStop() + assert sc(bq_state_stops) + assert sc(bq_state_does_not_stop) + + +def test_stopping_criterion_values( + stopping_criterion, bq_state_stops, bq_state_does_not_stop +): + assert stopping_criterion(bq_state_stops) + assert not stopping_criterion(bq_state_does_not_stop)