-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
BQ stopping criteria use base class (#570)
- Loading branch information
1 parent
db8921e
commit ac10e59
Showing
16 changed files
with
192 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
"""Stopping criteria for Bayesian quadrature methods.""" | ||
|
||
from ._bq_stopping_criterion import BQStoppingCriterion | ||
from ._integral_variance_tol import IntegralVarianceTolerance | ||
from ._max_nevals import MaxNevals | ||
from ._rel_mean_change import RelativeMeanChange | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"BQStoppingCriterion", | ||
"IntegralVarianceTolerance", | ||
"MaxNevals", | ||
"RelativeMeanChange", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria" | ||
RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria" |
30 changes: 30 additions & 0 deletions
30
src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
"""Base class for Bayesian quadrature stopping criteria.""" | ||
|
||
from probnum import StoppingCriterion | ||
from probnum.quad.solvers.bq_state import BQState | ||
|
||
# pylint: disable=too-few-public-methods, fixme | ||
# pylint: disable=arguments-differ | ||
|
||
|
||
class BQStoppingCriterion(StoppingCriterion): | ||
r"""Stopping criterion of a Bayesian quadrature method. | ||
Checks whether quantities tracked by the :class:`~probnum.quad.solvers.BQState` meet a desired terminal condition. | ||
See Also | ||
-------- | ||
IntegralVarianceTolerance : Stop based on the variance of the integral estimator. | ||
RelativeMeanChange : Stop based on the absolute value of the integral variance. | ||
MaxNevals : Stop based on a maximum number of iterations. | ||
""" | ||
|
||
def __call__(self, bq_state: BQState) -> bool: | ||
"""Check whether tracked quantities meet a desired terminal condition. | ||
Parameters | ||
---------- | ||
bq_state: | ||
State of the BQ loop. | ||
""" | ||
raise NotImplementedError |
23 changes: 23 additions & 0 deletions
23
src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Stopping criterion based on the absolute value of the integral variance""" | ||
|
||
from probnum.quad.solvers.bq_state import BQState | ||
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion | ||
from probnum.typing import FloatArgType | ||
|
||
# pylint: disable=too-few-public-methods, fixme | ||
|
||
|
||
class IntegralVarianceTolerance(BQStoppingCriterion): | ||
"""Stop once the integral variance is below some tolerance. | ||
Parameters | ||
---------- | ||
var_tol: | ||
Tolerance value of the variance. | ||
""" | ||
|
||
def __init__(self, var_tol: FloatArgType): | ||
self.var_tol = var_tol | ||
|
||
def __call__(self, bq_state: BQState) -> bool: | ||
return bq_state.integral_belief.var <= self.var_tol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Stopping criterion based on a maximum number of integrand evaluations.""" | ||
|
||
from probnum.quad.solvers.bq_state import BQState | ||
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion | ||
from probnum.typing import IntArgType | ||
|
||
# pylint: disable=too-few-public-methods | ||
|
||
|
||
class MaxNevals(BQStoppingCriterion): | ||
"""Stop once a maximum number of integrand evaluations is reached. | ||
Parameters | ||
---------- | ||
max_nevals: | ||
Maximum number of integrand evaluations. | ||
""" | ||
|
||
def __init__(self, max_nevals: IntArgType): | ||
self.max_nevals = max_nevals | ||
|
||
def __call__(self, bq_state: BQState) -> bool: | ||
return bq_state.info.nevals >= self.max_nevals |
37 changes: 37 additions & 0 deletions
37
src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Stopping criterion based on the relative change of the successive integral estimators.""" | ||
|
||
import numpy as np | ||
|
||
from probnum.quad.solvers.bq_state import BQState | ||
from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion | ||
from probnum.typing import FloatArgType | ||
|
||
# pylint: disable=too-few-public-methods | ||
|
||
|
||
class RelativeMeanChange(BQStoppingCriterion): | ||
"""Stop once the relative change of consecutive integral estimates are smaller than | ||
a tolerance. | ||
The stopping criterion is: :math:`|\\hat{F}_{c} - \\hat{F}_{p}|/ |\\hat{F}_{c}| \\leq r` | ||
where :math:`\\hat{F}_{c}` and :math:`\\hat{F}_{p}` are the integral estimates of the current and previous iteration | ||
respectively, and :math:`r` is the relative tolerance. | ||
Parameters | ||
---------- | ||
rel_tol: | ||
Relative error tolerance on consecutive integral mean values. | ||
""" | ||
|
||
def __init__(self, rel_tol: FloatArgType): | ||
self.rel_tol = rel_tol | ||
|
||
def __call__(self, bq_state: BQState) -> bool: | ||
integral_belief = bq_state.integral_belief | ||
return ( | ||
np.abs( | ||
(integral_belief.mean - bq_state.previous_integral_beliefs[-1].mean) | ||
/ integral_belief.mean | ||
) | ||
<= self.rel_tol | ||
) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.