Make quad policies stateless #744

Merged
15 changes: 8 additions & 7 deletions src/probnum/quad/_bayesquad.py
@@ -33,7 +33,7 @@ def bayesquad(
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: Optional[np.random.Generator] = np.random.default_rng(),
rng: Optional[np.random.Generator] = None,
jitter: FloatLike = 1.0e-8,
) -> Tuple[Normal, BQIterInfo]:
r"""Infer the solution of the uni- or multivariate integral
@@ -100,7 +100,7 @@ def bayesquad(
Number of new observations at each update. Defaults to 1.
rng
Random number generator. Used by Bayesian Monte Carlo and other random sampling
policies. Optional. Default is `np.random.default_rng()`.
policies.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
@@ -145,9 +145,9 @@ def bayesquad(

>>> input_dim = 1
>>> domain = (0, 1)
>>> def f(x):
>>> def fun(x):
... return x.reshape(-1, )
>>> F, info = bayesquad(fun=f, input_dim=input_dim, domain=domain)
>>> F, info = bayesquad(fun, input_dim, domain=domain, rng=np.random.default_rng(0))
>>> print(F.mean)
0.5
"""
@@ -167,12 +167,13 @@ def bayesquad(
var_tol=var_tol,
rel_tol=rel_tol,
batch_size=batch_size,
rng=rng,
jitter=jitter,
)

# Integrate
integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None)
integral_belief, _, info = bq_method.integrate(
fun=fun, nodes=None, fun_evals=None, rng=rng
)

return integral_belief, info

@@ -261,7 +262,7 @@ def bayesquad_from_data(

# Integrate
integral_belief, _, info = bq_method.integrate(
fun=None, nodes=nodes, fun_evals=fun_evals
fun=None, nodes=nodes, fun_evals=fun_evals, rng=None
)

return integral_belief, info
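
With no generator created as a default argument any more, the caller's seed fully determines the sampled nodes, while `bayesquad_from_data` simply threads `rng=None` since its nodes are fixed. A minimal reproducibility sketch against the `bayesquad` signature shown above (the toy integrand is illustrative):

```python
import numpy as np
from probnum.quad import bayesquad

def fun(x):
    return x.reshape(-1,)

# The rng is now passed explicitly; reusing the same seed reproduces the
# sampled nodes and hence the integral estimate.
F1, _ = bayesquad(fun, input_dim=1, domain=(0, 1), rng=np.random.default_rng(0))
F2, _ = bayesquad(fun, input_dim=1, domain=(0, 1), rng=np.random.default_rng(0))
assert np.isclose(F1.mean, F2.mean)
```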
43 changes: 23 additions & 20 deletions src/probnum/quad/solvers/_bayesian_quadrature.py
@@ -83,7 +83,6 @@ def from_problem(
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: np.random.Generator = None,
jitter: FloatLike = 1.0e-8,
) -> "BayesianQuadrature":

@@ -112,8 +111,6 @@ def from_problem(
Relative tolerance as stopping criterion.
batch_size
Batch size used in node acquisition. Defaults to 1.
rng
The random number generator.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
@@ -127,9 +124,6 @@ def from_problem(
------
ValueError
If neither a ``domain`` nor a ``measure`` are given.
ValueError
If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random
number generator (``rng``) is given.
NotImplementedError
If an unknown ``policy`` is given.
"""
@@ -153,15 +147,9 @@ def from_problem(
# require an acquisition loop. The error handling is done in ``integrate``.
pass
elif policy == "bmc":
if rng is None:
errormsg = (
"Policy 'bmc' relies on random sampling, "
"thus requires a random number generator ('rng')."
)
raise ValueError(errormsg)
policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng)
policy = RandomPolicy(batch_size, measure.sample)
elif policy == "vdc":
policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size)
policy = VanDerCorputPolicy(batch_size, measure)
else:
raise NotImplementedError(f"The given policy ({policy}) is unknown.")

@@ -215,6 +203,7 @@ def bq_iterator(
bq_state: BQState,
info: Optional[BQIterInfo],
fun: Optional[Callable],
rng: Optional[np.random.Generator],
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Generator that implements the iteration of the BQ method.

@@ -231,6 +220,8 @@ def bq_iterator(
fun
Function to be integrated. It needs to accept a shape=(n_eval, input_dim)
``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``.
rng
The random number generator used for random methods.

Yields
------
@@ -258,7 +249,7 @@ def bq_iterator(
break

# Select new nodes via policy
new_nodes = self.policy(bq_state=bq_state)
new_nodes = self.policy(bq_state, rng)

# Evaluate the integrand at new nodes
new_fun_evals = fun(new_nodes)
@@ -278,6 +269,7 @@ def integrate(
fun: Optional[Callable],
nodes: Optional[np.ndarray],
fun_evals: Optional[np.ndarray],
rng: Optional[np.random.Generator] = None,
) -> Tuple[Normal, BQState, BQIterInfo]:
"""Integrates the function ``fun``.

Expand All @@ -297,6 +289,8 @@ def integrate(
fun_evals
*shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available
from the start.
rng
The random number generator used for random methods.

Returns
-------
@@ -308,14 +302,17 @@ def integrate(
Raises
------
ValueError
If neither the integrand function (``fun``) nor integrand evaluations
(``fun_evals``) are given.
If neither the integrand function ``fun`` nor integrand evaluations
``fun_evals`` are given.
ValueError
If ``nodes`` are not given and no policy is present.
If neither ``nodes`` nor ``policy`` is given.
ValueError
If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their
shapes do not match.
ValueError
If ``rng`` is not given but ``policy`` requires it.
"""

# no policy given: Integrate on fixed dataset.
if self.policy is None:
# nodes must be provided if no policy is given.
@@ -325,13 +322,19 @@ def integrate(
# Use fun_evals and disregard fun if both are given
if fun is not None and fun_evals is not None:
warnings.warn(
"No policy available: 'fun_eval' are used instead of 'fun'."
"No policy available: 'fun_evals' are used instead of 'fun'."
)
fun = None

# override stopping condition as no policy is given.
self.stopping_criterion = ImmediateStop()

elif self.policy.requires_rng and rng is None:
raise ValueError(
f"The policy '{self.policy.__class__.__name__}' requires a random "
f"number generator (rng) to be given."
)

# Check if integrand function is provided
if fun is None and fun_evals is None:
raise ValueError(
@@ -375,7 +378,7 @@ def integrate(
)

info = None
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun):
for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng):
pass

return bq_state.integral_belief, bq_state, info
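
One consequence of removing the eager `rng` check from `from_problem` is that the error now surfaces at integration time via `requires_rng`. A small sketch of that failure mode, assuming the default `policy` of `bayesquad` is the random `'bmc'` policy:

```python
from probnum.quad import bayesquad

def fun(x):
    return x.reshape(-1,)

try:
    bayesquad(fun, input_dim=1, domain=(0, 1))  # rng is None by default
except ValueError as exc:
    # "The policy 'RandomPolicy' requires a random number generator (rng) ..."
    print(exc)
```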
21 changes: 18 additions & 3 deletions src/probnum/quad/solvers/policies/_policy.py
@@ -1,10 +1,14 @@
"""Abstract base class for BQ policies."""

from __future__ import annotations

import abc
from typing import Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

# pylint: disable=too-few-public-methods, fixme

@@ -18,17 +22,28 @@ class Policy(abc.ABC):
Size of batch of nodes when calling the policy once.
"""

def __init__(self, batch_size: int) -> None:
self.batch_size = batch_size
def __init__(self, batch_size: IntLike) -> None:
self.batch_size = int(batch_size)

@property
@abc.abstractmethod
def __call__(self, bq_state: BQState) -> np.ndarray:
def requires_rng(self) -> bool:
"""Whether the policy requires a random number generator when called."""
raise NotImplementedError

@abc.abstractmethod
def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
"""Find nodes according to the policy.

Parameters
----------
bq_state
State of the BQ belief.
rng
A random number generator.

Returns
-------
nodes :
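
Concretely, a subclass now only has to declare whether it needs a generator and accept it in `__call__`; no state is stored on the policy. A rough sketch of a custom deterministic policy against this interface (the class name and node layout are made up for illustration; the import paths follow the modules shown in this diff):

```python
from typing import Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.quad.solvers.policies._policy import Policy


class EquispacedPolicy(Policy):
    """Hypothetical policy placing each new batch on an equispaced grid in [0, 1]."""

    @property
    def requires_rng(self) -> bool:
        return False  # deterministic: the rng argument is simply ignored

    def __call__(
        self, bq_state: BQState, rng: Optional[np.random.Generator]
    ) -> np.ndarray:
        # Space the new batch evenly, independent of previously chosen nodes.
        grid = np.linspace(0.0, 1.0, self.batch_size + 2)[1:-1]
        return grid.reshape(-1, 1)
```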
27 changes: 16 additions & 11 deletions src/probnum/quad/solvers/policies/_random_policy.py
@@ -1,10 +1,13 @@
"""Random policy for Bayesian Monte Carlo."""

from typing import Callable
from __future__ import annotations

from typing import Callable, Optional

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

@@ -16,25 +19,27 @@ class RandomPolicy(Policy):

Parameters
----------
batch_size
Size of batch of nodes when calling the policy once.
sample_func
The sample function. Needs to have the following interface:
`sample_func(batch_size: int, rng: np.random.Generator)` and return an array of
shape (batch_size, n_dim).
batch_size
Size of batch of nodes when calling the policy once.
rng
A random number generator.
shape (batch_size, input_dim).
"""

def __init__(
self,
batch_size: IntLike,
sample_func: Callable,
batch_size: int,
rng: np.random.Generator = np.random.default_rng(),
) -> None:
super().__init__(batch_size=batch_size)
self.sample_func = sample_func
self.rng = rng

def __call__(self, bq_state: BQState) -> np.ndarray:
return self.sample_func(self.batch_size, rng=self.rng)
@property
def requires_rng(self) -> bool:
return True

def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
return self.sample_func(self.batch_size, rng=rng)
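
Because the generator is no longer stored on the instance, a `RandomPolicy` can be built once and driven by whichever generator the caller supplies per call. A usage sketch with a hand-rolled `sample_func` matching the documented interface (the uniform sampler and the direct module import are assumptions for illustration):

```python
import numpy as np

from probnum.quad.solvers.policies._random_policy import RandomPolicy

# A toy sampling function with the documented interface:
# sample_func(batch_size, rng) -> array of shape (batch_size, input_dim).
def sample_uniform(batch_size, rng):
    return rng.uniform(size=(batch_size, 1))

policy = RandomPolicy(batch_size=4, sample_func=sample_uniform)
assert policy.requires_rng  # random sampling needs a generator at call time

# bq_state is unused by RandomPolicy, so None suffices for this illustration.
nodes = policy(None, np.random.default_rng(0))
print(nodes.shape)  # (4, 1)
```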
17 changes: 13 additions & 4 deletions src/probnum/quad/solvers/policies/_van_der_corput_policy.py
@@ -1,11 +1,14 @@
"""Van der Corput points for integration on 1D intervals."""

from __future__ import annotations

from typing import Optional

import numpy as np

from probnum.quad.integration_measures import IntegrationMeasure
from probnum.quad.solvers._bq_state import BQState
from probnum.typing import IntLike

from ._policy import Policy

@@ -22,17 +25,17 @@ class VanDerCorputPolicy(Policy):

Parameters
----------
measure
The integration measure with finite domain.
batch_size
Size of batch of nodes when calling the policy once.
measure
The integration measure with finite domain.

References
--------
.. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence
"""

def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None:
super().__init__(batch_size=batch_size)

if int(measure.input_dim) > 1:
@@ -46,7 +49,13 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None:
self.domain_a = domain_a
self.domain_b = domain_b

def __call__(self, bq_state: BQState) -> np.ndarray:
@property
def requires_rng(self) -> bool:
return False

def __call__(
self, bq_state: BQState, rng: Optional[np.random.Generator]
) -> np.ndarray:
n_nodes = bq_state.nodes.shape[0]
vdc_seq = VanDerCorputPolicy.van_der_corput_sequence(
n_nodes + 1, n_nodes + 1 + self.batch_size
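
Being deterministic, this policy reports `requires_rng = False` and ignores the `rng` argument. For reference, a stand-alone sketch of the base-2 van der Corput construction (my own helper, not the library's `van_der_corput_sequence`); the stored `domain_a`/`domain_b` suggest the resulting points in (0, 1) are then rescaled to the measure's interval:

```python
# Base-2 van der Corput sequence: the n-th point is the binary digit
# reversal of n mapped into (0, 1).
def van_der_corput(start, stop):
    points = []
    for n in range(start, stop):
        q, denom, x = n, 1.0, 0.0
        while q > 0:
            q, r = divmod(q, 2)
            denom *= 2.0
            x += r / denom
        points.append(x)
    return points

print(van_der_corput(1, 5))  # [0.5, 0.25, 0.75, 0.125]
```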
Changes to the MaxNevals stopping criterion (file path not shown in this view)
@@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion):
"""

def __init__(self, max_nevals: IntLike):
self.max_nevals = max_nevals
self.max_nevals = int(max_nevals)

def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool:
return info.nevals >= self.max_nevals
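
The `int` cast mirrors the `IntLike` coercions added elsewhere in this PR, so NumPy integers behave like plain Python ints. A tiny check (the import path for `MaxNevals` is an assumption):

```python
import numpy as np

from probnum.quad.solvers.stopping_criteria import MaxNevals  # assumed public path

criterion = MaxNevals(max_nevals=np.int64(10))
assert criterion.max_nevals == 10 and type(criterion.max_nevals) is int
```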