-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Rename `Kernel` -> `CovarianceFunction` * Renaming in `test_randprocs` * Renaming in benchmarks * Pylint fixes * Bugfix * Update CODEOWNERS
- Loading branch information
1 parent
e17a258
commit bddbd23
Showing
41 changed files
with
643 additions
and
598 deletions.
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Benchmarks for covariance functions.""" | ||
|
||
import numpy as np | ||
|
||
from probnum.randprocs import covfuncs | ||
|
||
# Module level variables | ||
COVFUNC_NAMES = [ | ||
"white_noise", | ||
"linear", | ||
"polynomial", | ||
"exp_quad", | ||
"rat_quad", | ||
"matern12", | ||
"matern32", | ||
"matern52", | ||
"matern72", | ||
] | ||
|
||
N_DATAPOINTS = [10, 100, 1000] | ||
|
||
|
||
def get_covfunc(covfunc_name, input_shape): | ||
"""Return a covariance function for a given name.""" | ||
if covfunc_name == "white_noise": | ||
k = covfuncs.WhiteNoise(input_shape=input_shape) | ||
elif covfunc_name == "linear": | ||
k = covfuncs.Linear(input_shape=input_shape) | ||
elif covfunc_name == "polynomial": | ||
k = covfuncs.Polynomial(input_shape=input_shape) | ||
elif covfunc_name == "exp_quad": | ||
k = covfuncs.ExpQuad(input_shape=input_shape) | ||
elif covfunc_name == "rat_quad": | ||
k = covfuncs.RatQuad(input_shape=input_shape) | ||
elif covfunc_name == "matern12": | ||
k = covfuncs.Matern(input_shape=input_shape, nu=0.5) | ||
elif covfunc_name == "matern32": | ||
k = covfuncs.Matern(input_shape=input_shape, nu=1.5) | ||
elif covfunc_name == "matern52": | ||
k = covfuncs.Matern(input_shape=input_shape, nu=2.5) | ||
elif covfunc_name == "matern72": | ||
k = covfuncs.Matern(input_shape=input_shape, nu=3.5) | ||
else: | ||
raise ValueError(f"Covariance function '{covfunc_name}' not recognized.") | ||
|
||
return k | ||
|
||
|
||
class CovarianceFunctions: | ||
"""Benchmark evaluation of a covariance function at a set of inputs.""" | ||
|
||
param_names = ["covfunc", "n_datapoints"] | ||
params = [COVFUNC_NAMES, N_DATAPOINTS] | ||
|
||
def setup(self, covfunc, n_datapoints): | ||
rng = np.random.default_rng(42) | ||
self.input_dim = 100 | ||
self.data = rng.normal(size=(n_datapoints, self.input_dim)) | ||
self.covfunc = get_covfunc(covfunc_name=covfunc, input_shape=self.input_dim) | ||
|
||
def time_covfunc_call(self, covfunc, n_datapoints): | ||
self.covfunc(self.data, None) | ||
|
||
def time_covfunc_matrix(self, covfunc, n_datapoints): | ||
"""Times sampling from this distribution.""" | ||
self.covfunc.matrix(self.data) | ||
|
||
def peakmem_covfunc_matrix(self, covfunc, n_datapoints): | ||
"""Peak memory of sampling process.""" | ||
self.covfunc.matrix(self.data) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,5 +9,5 @@ probnum.randprocs | |
.. toctree:: | ||
:hidden: | ||
|
||
randprocs/covfuncs | ||
randprocs/markov | ||
randprocs/kernels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
************************** | ||
probnum.randprocs.covfuncs | ||
************************** | ||
|
||
.. automodapi:: probnum.randprocs.covfuncs | ||
:no-heading: | ||
:headings: "=" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Covariance functions. | ||
Covariance functions describe the spatial or temporal variation of a random process. | ||
If evaluated at two sets of points, a covariance function computes the covariance of the | ||
values of the random process at these locations. | ||
Covariance functions support basic algebraic operations, including scaling, addition | ||
and multiplication. | ||
""" | ||
|
||
from ._covariance_function import CovarianceFunction, IsotropicMixin | ||
from ._exponentiated_quadratic import ExpQuad | ||
from ._linear import Linear | ||
from ._matern import Matern | ||
from ._polynomial import Polynomial | ||
from ._product_matern import ProductMatern | ||
from ._rational_quadratic import RatQuad | ||
from ._white_noise import WhiteNoise | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"CovarianceFunction", | ||
"IsotropicMixin", | ||
"WhiteNoise", | ||
"Linear", | ||
"Polynomial", | ||
"ExpQuad", | ||
"RatQuad", | ||
"Matern", | ||
"ProductMatern", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
CovarianceFunction.__module__ = "probnum.randprocs.covfuncs" | ||
IsotropicMixin.__module__ = "probnum.randprocs.covfuncs" | ||
|
||
WhiteNoise.__module__ = "probnum.randprocs.covfuncs" | ||
Linear.__module__ = "probnum.randprocs.covfuncs" | ||
Polynomial.__module__ = "probnum.randprocs.covfuncs" | ||
ExpQuad.__module__ = "probnum.randprocs.covfuncs" | ||
RatQuad.__module__ = "probnum.randprocs.covfuncs" | ||
Matern.__module__ = "probnum.randprocs.covfuncs" | ||
ProductMatern.__module__ = "probnum.randprocs.covfuncs" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Covariance function arithmetic.""" | ||
from ._arithmetic_fallbacks import SumCovarianceFunction, _mul_fallback | ||
from ._covariance_function import BinaryOperandType, CovarianceFunction | ||
|
||
|
||
# pylint: disable=missing-param-doc | ||
def add(op1: BinaryOperandType, op2: BinaryOperandType) -> CovarianceFunction: | ||
"""Covariance function summation.""" | ||
return SumCovarianceFunction(op1, op2) | ||
|
||
|
||
def mul(op1: BinaryOperandType, op2: BinaryOperandType) -> CovarianceFunction: | ||
"""Covariance function multiplication.""" | ||
return _mul_fallback(op1, op2) | ||
|
||
|
||
# pylint: enable=missing-param-doc |
Oops, something went wrong.