From ac10e59196f39b51b65b1080294b064452bd3e7e Mon Sep 17 00:00:00 2001 From: Maren Mahsereci <42842079+mmahsereci@users.noreply.github.com> Date: Mon, 22 Nov 2021 14:07:36 +0100 Subject: [PATCH] BQ stopping criteria use base class (#570) --- src/probnum/quad/__init__.py | 34 +++---- src/probnum/quad/_bayesquad.py | 2 +- .../quad/{bq_methods => solvers}/__init__.py | 0 .../bayesian_quadrature.py | 63 +++++++------ .../belief_updates/__init__.py | 0 .../belief_updates/_belief_update.py | 2 +- .../quad/{bq_methods => solvers}/bq_state.py | 4 - .../quad/{ => solvers}/policies/__init__.py | 0 .../quad/{ => solvers}/policies/_policy.py | 2 +- .../solvers/stopping_criteria/__init__.py | 20 +++++ .../_bq_stopping_criterion.py | 30 +++++++ .../_integral_variance_tol.py | 23 +++++ .../solvers/stopping_criteria/_max_nevals.py | 23 +++++ .../stopping_criteria/_rel_mean_change.py | 37 ++++++++ src/probnum/quad/stop_criteria/__init__.py | 6 -- .../quad/stop_criteria/_stopping_criterion.py | 89 ------------------- 16 files changed, 192 insertions(+), 143 deletions(-) rename src/probnum/quad/{bq_methods => solvers}/__init__.py (100%) rename src/probnum/quad/{bq_methods => solvers}/bayesian_quadrature.py (88%) rename src/probnum/quad/{bq_methods => solvers}/belief_updates/__init__.py (100%) rename src/probnum/quad/{bq_methods => solvers}/belief_updates/_belief_update.py (98%) rename src/probnum/quad/{bq_methods => solvers}/bq_state.py (95%) rename src/probnum/quad/{ => solvers}/policies/__init__.py (100%) rename src/probnum/quad/{ => solvers}/policies/_policy.py (97%) create mode 100644 src/probnum/quad/solvers/stopping_criteria/__init__.py create mode 100644 src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py create mode 100644 src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py create mode 100644 src/probnum/quad/solvers/stopping_criteria/_max_nevals.py create mode 100644 src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py delete mode 100644 src/probnum/quad/stop_criteria/__init__.py delete mode 100644 src/probnum/quad/stop_criteria/_stopping_criterion.py diff --git a/src/probnum/quad/__init__.py b/src/probnum/quad/__init__.py index 29ff47d42..7e9ba4448 100644 --- a/src/probnum/quad/__init__.py +++ b/src/probnum/quad/__init__.py @@ -6,15 +6,16 @@ choosing points to evaluate the integrand based on said model. """ +from probnum.quad.solvers.policies import Policy, RandomPolicy +from probnum.quad.solvers.stopping_criteria import ( + BQStoppingCriterion, + IntegralVarianceTolerance, + MaxNevals, + RelativeMeanChange, +) + from ._bayesquad import bayesquad, bayesquad_from_data from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure -from .bq_methods import ( - BayesianQuadrature, - BQBeliefUpdate, - BQInfo, - BQStandardBeliefUpdate, - BQState, -) from .kernel_embeddings import ( KernelEmbedding, _kernel_mean_expquad_gauss, @@ -22,12 +23,12 @@ _kernel_variance_expquad_gauss, _kernel_variance_expquad_lebesgue, ) -from .policies import Policy, RandomPolicy -from .stop_criteria import ( - IntegralVarianceTolerance, - MaxNevals, - RelativeMeanChange, - StoppingCriterion, +from .solvers import ( + BayesianQuadrature, + BQBeliefUpdate, + BQInfo, + BQStandardBeliefUpdate, + BQState, ) # Public classes and functions. Order is reflected in documentation. @@ -39,13 +40,16 @@ "KernelEmbedding", "GaussianMeasure", "LebesgueMeasure", - "StoppingCriterion", + "BQStoppingCriterion", + "IntegralVarianceTolerance", + "MaxNevals", + "RelativeMeanChange", ] # Set correct module paths. Corrects links and module paths in documentation. BayesianQuadrature.__module__ = "probnum.quad" +BQStoppingCriterion.__module__ = "probnum.quad" IntegrationMeasure.__module__ = "probnum.quad" KernelEmbedding.__module__ = "probnum.quad" GaussianMeasure.__module__ = "probnum.quad" LebesgueMeasure.__module__ = "probnum.quad" -StoppingCriterion.__module__ = "probnum.quad" diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index 53772e8c6..f9d8e3863 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -17,7 +17,7 @@ from probnum.typing import FloatArgType, IntArgType from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure -from .bq_methods import BayesianQuadrature +from .solvers import BayesianQuadrature # pylint: disable=too-many-arguments, no-else-raise diff --git a/src/probnum/quad/bq_methods/__init__.py b/src/probnum/quad/solvers/__init__.py similarity index 100% rename from src/probnum/quad/bq_methods/__init__.py rename to src/probnum/quad/solvers/__init__.py diff --git a/src/probnum/quad/bq_methods/bayesian_quadrature.py b/src/probnum/quad/solvers/bayesian_quadrature.py similarity index 88% rename from src/probnum/quad/bq_methods/bayesian_quadrature.py rename to src/probnum/quad/solvers/bayesian_quadrature.py index 95d797921..bfe5042d7 100644 --- a/src/probnum/quad/bq_methods/bayesian_quadrature.py +++ b/src/probnum/quad/solvers/bayesian_quadrature.py @@ -1,22 +1,22 @@ """Probabilistic numerical methods for solving integrals.""" -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np +from probnum.quad.solvers.policies import Policy, RandomPolicy +from probnum.quad.solvers.stopping_criteria import ( + BQStoppingCriterion, + IntegralVarianceTolerance, + MaxNevals, + RelativeMeanChange, +) from probnum.randprocs.kernels import ExpQuad, Kernel from probnum.randvars import Normal from probnum.typing import FloatArgType, IntArgType from .._integration_measures import IntegrationMeasure, LebesgueMeasure from ..kernel_embeddings import KernelEmbedding -from ..policies import Policy, RandomPolicy -from ..stop_criteria import ( - IntegralVarianceTolerance, - MaxNevals, - RelativeMeanChange, - StoppingCriterion, -) from .belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate from .bq_state import BQState @@ -38,8 +38,8 @@ class BayesianQuadrature: The policy choosing nodes at which to evaluate the integrand. belief_update : The inference method. - stopping_criteria : - List of criteria that determine convergence. + stopping_criterion : + The criterion that determines convergence. """ # pylint: disable=too-many-arguments @@ -49,13 +49,13 @@ def __init__( measure: IntegrationMeasure, policy: Policy, belief_update: BQBeliefUpdate, - stopping_criteria: List[StoppingCriterion], + stopping_criterion: BQStoppingCriterion, ) -> None: self.kernel = kernel self.measure = measure self.policy = policy self.belief_update = belief_update - self.stopping_criteria = stopping_criteria + self.stopping_criterion = stopping_criterion @classmethod def from_problem( @@ -127,25 +127,38 @@ def from_problem( # Set stopping criteria # If multiple stopping criteria are given, BQ stops once the first criterion is fulfilled. - _stopping_criteria = [] + def _stopcrit_or(sc1, sc2): + if sc1 is None: + return sc2 + return sc1 | sc2 + + _stopping_criterion = None + if max_evals is not None: - _stopping_criteria.append(MaxNevals(max_evals)) + _stopping_criterion = _stopcrit_or( + _stopping_criterion, MaxNevals(max_evals) + ) if var_tol is not None: - _stopping_criteria.append(IntegralVarianceTolerance(var_tol)) + _stopping_criterion = _stopcrit_or( + _stopping_criterion, IntegralVarianceTolerance(var_tol) + ) if rel_tol is not None: - _stopping_criteria.append(RelativeMeanChange(rel_tol)) + _stopping_criterion = _stopcrit_or( + _stopping_criterion, RelativeMeanChange(rel_tol) + ) # If no stopping criteria are given, use some default values (these are arbitrary values) - if not _stopping_criteria: - _stopping_criteria.append(IntegralVarianceTolerance(var_tol=1e-6)) - _stopping_criteria.append(MaxNevals(max_evals=input_dim * 25)) + if _stopping_criterion is None: + _stopping_criterion = IntegralVarianceTolerance(var_tol=1e-6) | MaxNevals( + max_nevals=input_dim * 25 + ) return cls( kernel=kernel, measure=measure, policy=policy, belief_update=belief_update, - stopping_criteria=_stopping_criteria, + stopping_criterion=_stopping_criterion, ) def has_converged(self, bq_state: BQState) -> bool: @@ -163,12 +176,10 @@ def has_converged(self, bq_state: BQState) -> bool: Whether or not the solver has converged. """ - for stopping_criterion in self.stopping_criteria: - _has_converged = stopping_criterion(bq_state.integral_belief, bq_state) - if _has_converged: - bq_state.info.has_converged = True - bq_state.info.stopping_criterion = stopping_criterion.__class__.__name__ - return True + _has_converged = self.stopping_criterion(bq_state) + if _has_converged: + bq_state.info.has_converged = True + return True return False def bq_iterator( diff --git a/src/probnum/quad/bq_methods/belief_updates/__init__.py b/src/probnum/quad/solvers/belief_updates/__init__.py similarity index 100% rename from src/probnum/quad/bq_methods/belief_updates/__init__.py rename to src/probnum/quad/solvers/belief_updates/__init__.py diff --git a/src/probnum/quad/bq_methods/belief_updates/_belief_update.py b/src/probnum/quad/solvers/belief_updates/_belief_update.py similarity index 98% rename from src/probnum/quad/bq_methods/belief_updates/_belief_update.py rename to src/probnum/quad/solvers/belief_updates/_belief_update.py index 345ad10a2..f6cd0e0d3 100644 --- a/src/probnum/quad/bq_methods/belief_updates/_belief_update.py +++ b/src/probnum/quad/solvers/belief_updates/_belief_update.py @@ -6,7 +6,7 @@ import numpy as np from scipy.linalg import cho_factor, cho_solve -from probnum.quad.bq_methods.bq_state import BQState +from probnum.quad.solvers.bq_state import BQState from probnum.randvars import Normal # pylint: disable=too-few-public-methods, too-many-locals diff --git a/src/probnum/quad/bq_methods/bq_state.py b/src/probnum/quad/solvers/bq_state.py similarity index 95% rename from src/probnum/quad/bq_methods/bq_state.py rename to src/probnum/quad/solvers/bq_state.py index 74d4c4033..68b6e51a7 100644 --- a/src/probnum/quad/bq_methods/bq_state.py +++ b/src/probnum/quad/solvers/bq_state.py @@ -23,8 +23,6 @@ class BQInfo: Number of evaluations collected. has_converged : True if the BQ loop fulfils a stopping criterion, otherwise False. - stopping_criterion: - The stopping criterion used to determine convergence. """ def __init__( @@ -32,12 +30,10 @@ def __init__( iteration: int = 0, nevals: int = 0, has_converged: bool = False, - stopping_criterion: "probnum.quad.StoppingCriterion" = None, ): self.iteration = iteration self.nevals = nevals self.has_converged = has_converged - self.stopping_criterion = stopping_criterion def update_iteration(self, batch_size: int) -> None: """Update the quantities tracking iteration info. diff --git a/src/probnum/quad/policies/__init__.py b/src/probnum/quad/solvers/policies/__init__.py similarity index 100% rename from src/probnum/quad/policies/__init__.py rename to src/probnum/quad/solvers/policies/__init__.py diff --git a/src/probnum/quad/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py similarity index 97% rename from src/probnum/quad/policies/_policy.py rename to src/probnum/quad/solvers/policies/_policy.py index fe14a2fbc..4cc9d7b86 100644 --- a/src/probnum/quad/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -5,7 +5,7 @@ import numpy as np -from probnum.quad.bq_methods.bq_state import BQState +from probnum.quad.solvers.bq_state import BQState # pylint: disable=too-few-public-methods, fixme diff --git a/src/probnum/quad/solvers/stopping_criteria/__init__.py b/src/probnum/quad/solvers/stopping_criteria/__init__.py new file mode 100644 index 000000000..6e30c9de3 --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/__init__.py @@ -0,0 +1,20 @@ +"""Stopping criteria for Bayesian quadrature methods.""" + +from ._bq_stopping_criterion import BQStoppingCriterion +from ._integral_variance_tol import IntegralVarianceTolerance +from ._max_nevals import MaxNevals +from ._rel_mean_change import RelativeMeanChange + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "BQStoppingCriterion", + "IntegralVarianceTolerance", + "MaxNevals", + "RelativeMeanChange", +] + +# Set correct module paths. Corrects links and module paths in documentation. +BQStoppingCriterion.__module__ = "probnum.quad.solvers.stopping_criteria" +IntegralVarianceTolerance.__module__ = "probnum.quad.solvers.stopping_criteria" +MaxNevals.__module__ = "probnum.quad.solvers.stopping_criteria" +RelativeMeanChange.__module__ = "probnum.quad.solvers.stopping_criteria" diff --git a/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py new file mode 100644 index 000000000..57052a571 --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py @@ -0,0 +1,30 @@ +"""Base class for Bayesian quadrature stopping criteria.""" + +from probnum import StoppingCriterion +from probnum.quad.solvers.bq_state import BQState + +# pylint: disable=too-few-public-methods, fixme +# pylint: disable=arguments-differ + + +class BQStoppingCriterion(StoppingCriterion): + r"""Stopping criterion of a Bayesian quadrature method. + + Checks whether quantities tracked by the :class:`~probnum.quad.solvers.BQState` meet a desired terminal condition. + + See Also + -------- + IntegralVarianceTolerance : Stop based on the variance of the integral estimator. + RelativeMeanChange : Stop based on the absolute value of the integral variance. + MaxNevals : Stop based on a maximum number of iterations. + """ + + def __call__(self, bq_state: BQState) -> bool: + """Check whether tracked quantities meet a desired terminal condition. + + Parameters + ---------- + bq_state: + State of the BQ loop. + """ + raise NotImplementedError diff --git a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py new file mode 100644 index 000000000..bd863ac11 --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py @@ -0,0 +1,23 @@ +"""Stopping criterion based on the absolute value of the integral variance""" + +from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion +from probnum.typing import FloatArgType + +# pylint: disable=too-few-public-methods, fixme + + +class IntegralVarianceTolerance(BQStoppingCriterion): + """Stop once the integral variance is below some tolerance. + + Parameters + ---------- + var_tol: + Tolerance value of the variance. + """ + + def __init__(self, var_tol: FloatArgType): + self.var_tol = var_tol + + def __call__(self, bq_state: BQState) -> bool: + return bq_state.integral_belief.var <= self.var_tol diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py new file mode 100644 index 000000000..fb40e3f32 --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -0,0 +1,23 @@ +"""Stopping criterion based on a maximum number of integrand evaluations.""" + +from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion +from probnum.typing import IntArgType + +# pylint: disable=too-few-public-methods + + +class MaxNevals(BQStoppingCriterion): + """Stop once a maximum number of integrand evaluations is reached. + + Parameters + ---------- + max_nevals: + Maximum number of integrand evaluations. + """ + + def __init__(self, max_nevals: IntArgType): + self.max_nevals = max_nevals + + def __call__(self, bq_state: BQState) -> bool: + return bq_state.info.nevals >= self.max_nevals diff --git a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py new file mode 100644 index 000000000..f74fb6262 --- /dev/null +++ b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py @@ -0,0 +1,37 @@ +"""Stopping criterion based on the relative change of the successive integral estimators.""" + +import numpy as np + +from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion +from probnum.typing import FloatArgType + +# pylint: disable=too-few-public-methods + + +class RelativeMeanChange(BQStoppingCriterion): + """Stop once the relative change of consecutive integral estimates are smaller than + a tolerance. + + The stopping criterion is: :math:`|\\hat{F}_{c} - \\hat{F}_{p}|/ |\\hat{F}_{c}| \\leq r` + where :math:`\\hat{F}_{c}` and :math:`\\hat{F}_{p}` are the integral estimates of the current and previous iteration + respectively, and :math:`r` is the relative tolerance. + + Parameters + ---------- + rel_tol: + Relative error tolerance on consecutive integral mean values. + """ + + def __init__(self, rel_tol: FloatArgType): + self.rel_tol = rel_tol + + def __call__(self, bq_state: BQState) -> bool: + integral_belief = bq_state.integral_belief + return ( + np.abs( + (integral_belief.mean - bq_state.previous_integral_beliefs[-1].mean) + / integral_belief.mean + ) + <= self.rel_tol + ) diff --git a/src/probnum/quad/stop_criteria/__init__.py b/src/probnum/quad/stop_criteria/__init__.py deleted file mode 100644 index 411db2101..000000000 --- a/src/probnum/quad/stop_criteria/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ._stopping_criterion import ( - IntegralVarianceTolerance, - MaxNevals, - RelativeMeanChange, - StoppingCriterion, -) diff --git a/src/probnum/quad/stop_criteria/_stopping_criterion.py b/src/probnum/quad/stop_criteria/_stopping_criterion.py deleted file mode 100644 index 2d875316a..000000000 --- a/src/probnum/quad/stop_criteria/_stopping_criterion.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Stopping criteria for Bayesian quadrature.""" - - -import numpy as np - -from probnum.quad.bq_methods.bq_state import BQState -from probnum.randvars import Normal -from probnum.typing import FloatArgType, IntArgType - -# pylint: disable=too-few-public-methods - - -class StoppingCriterion: - """Base class for a stopping criterion.""" - - def __call__(self, integral_belief: Normal, bq_state: BQState) -> bool: - """Evaluate the stopping criterion. - - Parameters - ---------- - integral_belief : - Current Gaussian belief about the integral. - bq_state: - State of the BQ loop. - - Returns - ------- - has_converged: - Boolean whether a stopping criterion has been reached - """ - raise NotImplementedError - - -class IntegralVarianceTolerance(StoppingCriterion): - """Stop once the integral variance is below some tolerance. - - Parameters - ---------- - var_tol: - Tolerance value of the variance. - """ - - def __init__(self, var_tol: FloatArgType): - self.var_tol = var_tol - - def __call__(self, integral_belief: Normal, bq_state: BQState) -> bool: - return integral_belief.var <= self.var_tol - - -class RelativeMeanChange(StoppingCriterion): - """Stop once the relative change of consecutive integral estimates are smaller than - a tolerance. That is, the stopping criterion is. - - | current_integral_estimate - previous_integral_estimate) / - current_integral_estimate | <= rel_tol. - - Parameters - ---------- - rel_tol: - Relative error tolerance on consecutive integral mean values. - """ - - def __init__(self, rel_tol: FloatArgType): - self.rel_tol = rel_tol - - def __call__(self, integral_belief: Normal, bq_state: BQState) -> bool: - return ( - np.abs( - (integral_belief.mean - bq_state.previous_integral_beliefs[-1].mean) - / integral_belief.mean - ) - <= self.rel_tol - ) - - -class MaxNevals(StoppingCriterion): - """Stop once a maximum number of iterations is reached. - - Parameters - ---------- - max_evals: - Maximum number of function evaluations. - """ - - def __init__(self, max_evals: IntArgType): - self.max_evals = max_evals - - def __call__(self, integral_belief: Normal, bq_state: BQState) -> bool: - return bq_state.info.nevals >= self.max_evals