Skip to content

Commit

Permalink
pylint cleanup quad (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci authored Jan 19, 2022
1 parent 3ae7949 commit 58c3023
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 64 deletions.
27 changes: 14 additions & 13 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
"""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import numpy as np

from probnum.quad.solvers.bq_state import BQInfo
from probnum.randprocs.kernels import Kernel
from probnum.randvars import Normal
from probnum.typing import FloatLike, IntLike
Expand All @@ -20,7 +21,6 @@
from .solvers import BayesianQuadrature


# pylint: disable=too-many-arguments, no-else-raise
def bayesquad(
fun: Callable,
input_dim: int,
Expand All @@ -35,8 +35,9 @@ def bayesquad(
rel_tol: Optional[FloatLike] = None,
batch_size: Optional[IntLike] = 1,
rng: Optional[np.random.Generator] = np.random.default_rng(),
) -> Tuple[Normal, Dict]:
r"""Infer the solution of the uni- or multivariate integral :math:`\int_\Omega f(x) d \mu(x)`
) -> Tuple[Normal, BQInfo]:
r"""Infer the solution of the uni- or multivariate integral
:math:`\int_\Omega f(x) d \mu(x)`
on a hyper-rectangle :math:`\Omega = [a_1, b_1] \times \cdots \times [a_D, b_D]`.
Bayesian quadrature (BQ) infers integrals of the form
Expand All @@ -47,12 +48,12 @@ def bayesquad(
:math:`\Omega \subset \mathbb{R}^D` against a measure :math:`\mu: \mathbb{R}^D
\mapsto \mathbb{R}`.
Bayesian quadrature methods return a probability distribution over the solution :math:`F` with
uncertainty arising from finite computation (here a finite number of function evaluations).
They start out with a random process encoding the prior belief about the function :math:`f`
to be integrated. Conditioned on either existing or acquired function evaluations according to a
policy, they update the belief on :math:`f`, which is translated into a posterior measure over
the integral :math:`F`.
Bayesian quadrature methods return a probability distribution over the solution
:math:`F` with uncertainty arising from finite computation (here a finite number
of function evaluations). They start out with a random process encoding the prior
belief about the function :math:`f` to be integrated. Conditioned on either existing
or acquired function evaluations according to a policy, they update the belief on
:math:`f`, which is translated into a posterior measure over the integral :math:`F`.
See Briol et al. [1]_ for a review on Bayesian quadrature.
Parameters
Expand Down Expand Up @@ -132,7 +133,7 @@ def bayesquad(
if domain is not None:
if isinstance(measure, GaussianMeasure):
raise ValueError("GaussianMeasure cannot be used with finite bounds.")
elif isinstance(measure, LebesgueMeasure):
if isinstance(measure, LebesgueMeasure):
warnings.warn(
"Both domain and a LebesgueMeasure are specified. The domain "
"information will be ignored."
Expand Down Expand Up @@ -165,7 +166,7 @@ def bayesquad_from_data(
Tuple[Union[np.ndarray, FloatLike], Union[np.ndarray, FloatLike]]
] = None,
measure: Optional[IntegrationMeasure] = None,
) -> Tuple[Normal, Dict]:
) -> Tuple[Normal, BQInfo]:
r"""Infer the value of an integral from a given set of nodes and function
evaluations.
Expand Down Expand Up @@ -219,7 +220,7 @@ def bayesquad_from_data(
if domain is not None:
if isinstance(measure, GaussianMeasure):
raise ValueError("GaussianMeasure cannot be used with finite bounds.")
elif isinstance(measure, LebesgueMeasure):
if isinstance(measure, LebesgueMeasure):
warnings.warn(
"Both domain and a LebesgueMeasure are specified. The domain "
"information will be ignored."
Expand Down
11 changes: 5 additions & 6 deletions src/probnum/quad/kernel_embeddings/_expquad_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from probnum.quad._integration_measures import GaussianMeasure
from probnum.randprocs.kernels import ExpQuad

# pylint: disable=invalid-name


def _kernel_mean_expquad_gauss(
x: np.ndarray, kernel: ExpQuad, measure: GaussianMeasure
Expand All @@ -20,17 +18,18 @@ def _kernel_mean_expquad_gauss(
Parameters
----------
x :
*shape=(n_eval, input_dim)* -- n_eval locations where to evaluate the kernel mean.
*shape=(n_eval, input_dim)* -- n_eval locations where to evaluate the kernel
mean.
kernel :
Instance of an ExpQuad kernel.
measure :
Instance of a GaussianMeasure.
Returns
-------
k_mean :
kernel_mean :
*shape (n_eval,)* -- The kernel integrated w.r.t. its first argument,
evaluated at locations x.
evaluated at locations ``x``.
"""
input_dim = kernel.input_dim

Expand Down Expand Up @@ -66,7 +65,7 @@ def _kernel_variance_expquad_gauss(kernel: ExpQuad, measure: GaussianMeasure) ->
Returns
-------
k_var :
kernel_variance :
The kernel integrated w.r.t. both arguments.
"""
input_dim = kernel.input_dim
Expand Down
11 changes: 6 additions & 5 deletions src/probnum/quad/kernel_embeddings/_expquad_lebesgue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Kernel embedding of exponentiated quadratic kernel with Lebesgue integration
measure."""

# pylint: disable=no-name-in-module, invalid-name
# pylint: disable=no-name-in-module

import numpy as np
from scipy.special import erf
Expand All @@ -19,17 +19,18 @@ def _kernel_mean_expquad_lebesgue(
Parameters
----------
x :
*shape (n_eval, input_dim)* -- n_eval locations where to evaluate the kernel mean.
*shape (n_eval, input_dim)* -- n_eval locations where to evaluate the kernel
mean.
kernel :
Instance of an ExpQuad kernel.
measure :
Instance of a LebesgueMeasure.
Returns
-------
k_mean :
kernel_mean :
*shape=(n_eval,)* -- The kernel integrated w.r.t. its first argument,
evaluated at locations x.
evaluated at locations ``x``.
"""
input_dim = kernel.input_dim

Expand Down Expand Up @@ -59,7 +60,7 @@ def _kernel_variance_expquad_lebesgue(
Returns
-------
k_var :
kernel_variance :
The kernel integrated w.r.t. both arguments.
"""

Expand Down
31 changes: 23 additions & 8 deletions src/probnum/quad/kernel_embeddings/_kernel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class KernelEmbedding:
Instance of a kernel.
measure:
Instance of an integration measure.
Raises
------
ValueError
If the input dimension of the kernel does not match the input dimension of the
measure.
"""

def __init__(self, kernel: Kernel, measure: IntegrationMeasure) -> None:
Expand All @@ -45,20 +51,20 @@ def __init__(self, kernel: Kernel, measure: IntegrationMeasure) -> None:
kernel=self.kernel, measure=self.measure
)

# pylint: disable=invalid-name
def kernel_mean(self, x: np.ndarray) -> np.ndarray:
"""Kernel mean w.r.t. its first argument against the integration measure.
Parameters
----------
x :
*shape=(n_eval, input_dim)* -- n_eval locations where to evaluate the kernel mean.
*shape=(n_eval, input_dim)* -- n_eval locations where to evaluate the
kernel mean.
Returns
-------
k_mean :
kernel_mean :
*shape=(n_eval,)* -- The kernel integrated w.r.t. its first argument,
evaluated at locations x.
evaluated at locations ``x``.
"""
return self._kmean(x=x, kernel=self.kernel, measure=self.measure)

Expand All @@ -67,7 +73,7 @@ def kernel_variance(self) -> float:
Returns
-------
k_var :
kernel_variance :
The kernel integrated w.r.t. both arguments.
"""
return self._kvar(kernel=self.kernel, measure=self.measure)
Expand All @@ -87,15 +93,24 @@ def _get_kernel_embedding(
Returns
-------
An instance of _KernelEmbedding.
kernel_mean :
The kernel mean function.
kernel_variance :
The kernel variance function.
Raises
------
NotImplementedError
If the given kernel is unknown.
NotImplementedError
If the kernel embedding of the kernel-measure pair is unknown.
"""

# Exponentiated quadratic kernel
if isinstance(kernel, ExpQuad):
# pylint: disable=no-else-return
if isinstance(measure, GaussianMeasure):
return _kernel_mean_expquad_gauss, _kernel_variance_expquad_gauss
elif isinstance(measure, LebesgueMeasure):
if isinstance(measure, LebesgueMeasure):
return _kernel_mean_expquad_lebesgue, _kernel_variance_expquad_lebesgue
raise NotImplementedError

Expand Down
63 changes: 42 additions & 21 deletions src/probnum/quad/solvers/bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ def from_problem(
Batch size used in node acquisition.
rng :
The random number generator.
Returns
-------
BayesianQuadrature
An instance of this class.
Raises
------
ValueError
If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random
number generator (``rng``) is given.
NotImplementedError
If an unknown ``policy`` is given.
"""
# Set up integration measure
if measure is None:
Expand Down Expand Up @@ -125,8 +138,8 @@ def from_problem(
"Policies other than random sampling are not available at the moment."
)

# Set stopping criteria
# If multiple stopping criteria are given, BQ stops once the first criterion is fulfilled.
# Set stopping criteria: If multiple stopping criteria are given, BQ stops
# once the first criterion is fulfilled.
def _stopcrit_or(sc1, sc2):
if sc1 is None:
return sc2
Expand All @@ -147,7 +160,8 @@ def _stopcrit_or(sc1, sc2):
_stopping_criterion, RelativeMeanChange(rel_tol)
)

# If no stopping criteria are given, use some default values (these are arbitrary values)
# If no stopping criteria are given, use some default values
# (these are arbitrary values)
if _stopping_criterion is None:
_stopping_criterion = IntegralVarianceTolerance(var_tol=1e-6) | MaxNevals(
max_nevals=input_dim * 25
Expand All @@ -167,13 +181,13 @@ def has_converged(self, bq_state: BQState) -> bool:
Parameters
----------
bq_state:
State of the Bayesian quadrature methods. Contains all necessary information about the
problem and the computation.
State of the Bayesian quadrature methods. Contains all necessary
information about the problem and the computation.
Returns
-------
has_converged:
Whether or not the solver has converged.
has_converged :
Whether the solver has converged.
"""

_has_converged = self.stopping_criterion(bq_state)
Expand All @@ -192,7 +206,8 @@ def bq_iterator(
) -> Tuple[Normal, np.ndarray, np.ndarray, BQState]:
"""Generator that implements the iteration of the BQ method.
This function exposes the state of the BQ method one step at a time while running the loop.
This function exposes the state of the BQ method one step at a time while
running the loop.
Parameters
----------
Expand All @@ -208,23 +223,22 @@ def bq_iterator(
integral_belief:
Current belief about the integral.
bq_state:
State of the Bayesian quadrature methods. Contains all necessary information about the
problem and the computation.
State of the Bayesian quadrature methods. Contains all necessary information
about the problem and the computation.
Returns
-------
integral_belief:
Yields
------
new_integral_belief :
Updated belief about the integral.
new_nodes:
new_nodes :
*shape=(n_new_eval, input_dim)* -- The new location(s) at which
``new_fun_evals`` are available found during the iteration.
new_fun_evals:
new_fun_evals :
*shape=(n_new_eval,)* -- The function evaluations at the new locations
``new_nodes``.
bq_state:
new_bq_state :
Updated state of the Bayesian quadrature methods.
"""
# pylint: disable=missing-yield-doc

# Setup
if bq_state is None:
Expand Down Expand Up @@ -293,8 +307,9 @@ def integrate(
) -> Tuple[Normal, BQState]:
"""Integrate the function ``fun``.
``fun`` may be analytically given, or numerically in terms of ``fun_evals`` at fixed nodes.
This function calls the generator ``bq_iterator`` until the first stopping criterion is met.
``fun`` may be analytically given, or numerically in terms of ``fun_evals`` at
fixed nodes. This function calls the generator ``bq_iterator`` until the first
stopping criterion is met.
Parameters
----------
Expand All @@ -310,10 +325,16 @@ def integrate(
Returns
-------
integral_belief:
integral_belief :
Posterior belief about the integral.
bq_state:
bq_state :
Final state of the Bayesian quadrature method.
Raises
------
ValueError
If neither the integrand function (``fun``) nor integrand evaluations
(``fun_evals``) are given.
"""
if fun is None and fun_evals is None:
raise ValueError("You need to provide a function to be integrated!")
Expand Down
5 changes: 3 additions & 2 deletions src/probnum/quad/solvers/belief_updates/_belief_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def __call__(
Returns
-------
updated_belief :
Gaussian integral belief after conditioning on the new nodes and evaluations.
Gaussian integral belief after conditioning on the new nodes and
evaluations.
updated_state :
Updated version of ``bq_state`` that contains all updated quantities.
"""
Expand Down Expand Up @@ -110,7 +111,7 @@ def _solve_gram(gram: np.ndarray, rhs: np.ndarray) -> np.ndarray:
Returns
-------
x:
x :
The solution to the linear system :math:`K x = b`
"""
jitter = 1.0e-6
Expand Down
Loading

0 comments on commit 58c3023

Please sign in to comment.