-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
98bd6d2
commit 8488cd8
Showing
5 changed files
with
116 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,23 @@ | ||
"""Stopping criteria for Bayesian quadrature methods.""" | ||
|
||
from ._bq_stopping_criterion import BQStoppingCriterion | ||
from ._immediate_stop import ImmediateStop | ||
from ._integral_variance_tol import IntegralVarianceTolerance | ||
from ._max_nevals import MaxNevals | ||
from ._rel_mean_change import RelativeMeanChange | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"BQStoppingCriterion", | ||
"ImmediateStop", | ||
"IntegralVarianceTolerance", | ||
"MaxNevals", | ||
"RelativeMeanChange", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
ImmediateStop.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Stopping criterion that stops immediately.""" | ||
|
||
from probnum.quad.solvers.bq_state import BQState | ||
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion | ||
|
||
# pylint: disable=too-few-public-methods | ||
|
||
|
||
class ImmediateStop(BQStoppingCriterion): | ||
"""Dummy stopping criterion that always stops. This is useful for fixed datasets | ||
when no policy or acquisition loop is required or given. | ||
""" | ||
|
||
def __call__(self, bq_state: BQState) -> bool: | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""Tests for BQ stopping criteria.""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from probnum.quad import ( | ||
BQStoppingCriterion, | ||
ImmediateStop, | ||
IntegralVarianceTolerance, | ||
LebesgueMeasure, | ||
MaxNevals, | ||
RelativeMeanChange, | ||
) | ||
from probnum.quad.solvers.bq_state import BQState | ||
from probnum.randprocs.kernels import ExpQuad | ||
from probnum.randvars import Normal | ||
|
||
_nevals = 5 | ||
_rel_tol = 1e-5 | ||
_var_tol = 1e-5 | ||
|
||
|
||
@pytest.fixture() | ||
def input_dim(): | ||
return 2 | ||
|
||
|
||
@pytest.fixture( | ||
params=[ | ||
pytest.param(sc, id=sc[0].__name__) | ||
for sc in [ | ||
(MaxNevals, {"max_nevals": _nevals}), | ||
(IntegralVarianceTolerance, {"var_tol": _var_tol}), | ||
(RelativeMeanChange, {"rel_tol": _rel_tol}), | ||
] | ||
], | ||
name="stopping_criterion", | ||
) | ||
def fixture_stopping_criterion(request) -> BQStoppingCriterion: | ||
"""BQ stopping criterion.""" | ||
return request.param[0](**request.param[1]) | ||
|
||
|
||
@pytest.fixture() | ||
def bq_state_stops(input_dim) -> BQState: | ||
"""BQ state that triggers stopping in all stopping criteria.""" | ||
integral_mean = 1.0 | ||
integral_mean_previous = integral_mean * (1 - _rel_tol) | ||
return BQState( | ||
measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), | ||
kernel=ExpQuad(input_shape=(input_dim,)), | ||
integral_belief=Normal(integral_mean, 0.1 * _var_tol), | ||
previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),), | ||
nodes=np.ones((_nevals, input_dim)), | ||
fun_evals=np.ones(_nevals), | ||
) | ||
|
||
|
||
@pytest.fixture() | ||
def bq_state_does_not_stop(input_dim) -> BQState: | ||
"""BQ state that does not trigger stopping in all stopping criteria.""" | ||
integral_mean = 1.0 | ||
integral_mean_previous = 2 * integral_mean * (1 - _rel_tol) | ||
nevals = _nevals - 2 | ||
return BQState( | ||
measure=LebesgueMeasure(input_dim=input_dim, domain=(0, 1)), | ||
kernel=ExpQuad(input_shape=(input_dim,)), | ||
integral_belief=Normal(integral_mean, 10 * _var_tol), | ||
previous_integral_beliefs=(Normal(integral_mean_previous, _var_tol),), | ||
nodes=np.ones((nevals, input_dim)), | ||
fun_evals=np.ones(nevals), | ||
) | ||
|
||
|
||
def test_immediate_stop_values(bq_state_stops, bq_state_does_not_stop): | ||
# Immediate stop shall always stop | ||
sc = ImmediateStop() | ||
assert sc(bq_state_stops) | ||
assert sc(bq_state_does_not_stop) | ||
|
||
|
||
def test_stopping_criterion_values( | ||
stopping_criterion, bq_state_stops, bq_state_does_not_stop | ||
): | ||
assert stopping_criterion(bq_state_stops) | ||
assert not stopping_criterion(bq_state_does_not_stop) |