-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Gauss-Legendre Quadrature Fourier features #64
base: develop
Are you sure you want to change the base?
Changes from 17 commits
a8483bc
6e10cd0
61992ee
8c745b0
55d989a
5b47864
bb09fe8
52e4381
3f05fd8
0dcf4c6
031dc9a
e465f35
58da313
0cf22ff
abd375a
8d04e69
4173ff0
8d2e9a7
39e432e
383df2a
e9cbc77
6613736
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,98 @@ | ||||||||||||
# | ||||||||||||
# Copyright (c) 2021 The GPflux Contributors. | ||||||||||||
# | ||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||
# You may obtain a copy of the License at | ||||||||||||
# | ||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||
# | ||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||
# limitations under the License. | ||||||||||||
# | ||||||||||||
|
||||||||||||
from abc import ABC, abstractmethod | ||||||||||||
|
||||||||||||
import numpy as np | ||||||||||||
import tensorflow as tf | ||||||||||||
|
||||||||||||
from gpflow.base import TensorType | ||||||||||||
|
||||||||||||
from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase | ||||||||||||
from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat | ||||||||||||
from gpflux.types import ShapeType | ||||||||||||
|
||||||||||||
|
||||||||||||
class QuadratureFourierFeaturesBase(FourierFeaturesBase): | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please give a type and short description of
Suggested change
|
||||||||||||
def _compute_output_dim(self, input_shape: ShapeType) -> int: | ||||||||||||
input_dim = input_shape[-1] | ||||||||||||
return 2 * self.n_components ** input_dim | ||||||||||||
|
||||||||||||
def _compute_bases(self, inputs: TensorType) -> tf.Tensor: | ||||||||||||
""" | ||||||||||||
Compute basis functions. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: the word "compute" made me first think the function would return a functor with the basis functions.
Suggested change
|
||||||||||||
|
||||||||||||
:return: A tensor with the shape ``(N, 2L^D)``. | ||||||||||||
""" | ||||||||||||
return _bases_concat(inputs, self.abscissa) | ||||||||||||
|
||||||||||||
def _compute_constant(self) -> tf.Tensor: | ||||||||||||
""" | ||||||||||||
Compute normalizing constant for basis functions. | ||||||||||||
|
||||||||||||
:return: A tensor with the shape ``(2L^D,)`` | ||||||||||||
""" | ||||||||||||
return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2]) | ||||||||||||
|
||||||||||||
|
||||||||||||
class Transform(ABC): | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Transform" is quite an overloaded word in ML software packages. Is it possible to be more specific?
Suggested change
|
||||||||||||
r""" | ||||||||||||
This class encapsulates functions :math:`g(x) = u` and :math:`h(x)` such that | ||||||||||||
.. math:: | ||||||||||||
\int_{g(a)}^{g(b)} z(u) f(u) du | ||||||||||||
= \int_a^b w(x) h(x) f(g(x)) dx | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. have you double-checked that the maths renders well in the docs? |
||||||||||||
for some integrand :math:`f(u)` and weight function :math:`z(u)`. | ||||||||||||
""" | ||||||||||||
|
||||||||||||
@abstractmethod | ||||||||||||
def __call__(self, x: TensorType) -> tf.Tensor: | ||||||||||||
pass | ||||||||||||
|
||||||||||||
@abstractmethod | ||||||||||||
def multiplier(self, x: TensorType) -> tf.Tensor: | ||||||||||||
pass | ||||||||||||
|
||||||||||||
|
||||||||||||
class TanTransform(Transform): | ||||||||||||
r""" | ||||||||||||
This class encapsulates functions :math:`g(x) = u` and :math:`h(x) = du/dx` | ||||||||||||
such that | ||||||||||||
.. math:: | ||||||||||||
\int_{-\infty}^{\infty} f(u) du | ||||||||||||
= \int_{-1}^{1} f(g(x)) h(x) dx | ||||||||||||
""" | ||||||||||||
CONST = 0.5 * np.pi | ||||||||||||
|
||||||||||||
def __call__(self, x: TensorType) -> tf.Tensor: | ||||||||||||
return tf.tan(TanTransform.CONST * x) | ||||||||||||
|
||||||||||||
def multiplier(self, x: TensorType) -> tf.Tensor: | ||||||||||||
return TanTransform.CONST / tf.square(tf.cos(TanTransform.CONST * x)) | ||||||||||||
|
||||||||||||
|
||||||||||||
class NormalWeightTransform(Transform): | ||||||||||||
r""" | ||||||||||||
This class encapsulates functions :math:`g(x) = u` and :math:`h(x)` such that | ||||||||||||
.. math:: | ||||||||||||
\int_{-\infty}^{\infty} \mathcal{N}(u|0,1) f(u) du | ||||||||||||
= \int_{-infty}^{infty} e^{-x^2} f(g(x)) h(x) dx | ||||||||||||
""" | ||||||||||||
|
||||||||||||
def __call__(self, x: TensorType) -> tf.Tensor: | ||||||||||||
return tf.sqrt(2.0) * x | ||||||||||||
|
||||||||||||
def multiplier(self, x: TensorType) -> tf.Tensor: | ||||||||||||
return tf.rsqrt(np.pi) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,39 +19,39 @@ | |||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
import warnings | ||||||||||||||||||
from typing import Mapping, Tuple, Type | ||||||||||||||||||
from typing import Mapping | ||||||||||||||||||
|
||||||||||||||||||
import numpy as np | ||||||||||||||||||
import tensorflow as tf | ||||||||||||||||||
from scipy.stats import multivariate_normal, multivariate_t | ||||||||||||||||||
|
||||||||||||||||||
import gpflow | ||||||||||||||||||
from gpflow.base import TensorType | ||||||||||||||||||
from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights | ||||||||||||||||||
|
||||||||||||||||||
from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase | ||||||||||||||||||
from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat | ||||||||||||||||||
from gpflux.types import ShapeType | ||||||||||||||||||
|
||||||||||||||||||
""" | ||||||||||||||||||
Kernels supported by :class:`QuadratureFourierFeatures`. | ||||||||||||||||||
# from gpflow.config import default_float | ||||||||||||||||||
from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights, repeat_as_list, reshape_Z_dZ | ||||||||||||||||||
|
||||||||||||||||||
Currently we only support the :class:`gpflow.kernels.SquaredExponential` kernel. | ||||||||||||||||||
For Matern kernels please use :class:`RandomFourierFeatures` | ||||||||||||||||||
or :class:`RandomFourierFeaturesCosine`. | ||||||||||||||||||
""" | ||||||||||||||||||
QFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = ( | ||||||||||||||||||
gpflow.kernels.SquaredExponential, | ||||||||||||||||||
from gpflux.layers.basis_functions.fourier_features.quadrature.base import ( | ||||||||||||||||||
QuadratureFourierFeaturesBase, | ||||||||||||||||||
TanTransform, | ||||||||||||||||||
) | ||||||||||||||||||
from gpflux.layers.basis_functions.fourier_features.utils import _matern_dof | ||||||||||||||||||
from gpflux.types import ShapeType | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class QuadratureFourierFeatures(FourierFeaturesBase): | ||||||||||||||||||
class GaussianQuadratureFourierFeatures(QuadratureFourierFeaturesBase): | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping): | ||||||||||||||||||
assert isinstance(kernel, QFF_SUPPORTED_KERNELS), "Unsupported Kernel" | ||||||||||||||||||
super(GaussianQuadratureFourierFeatures, self).__init__(kernel, n_components, **kwargs) | ||||||||||||||||||
if tf.reduce_any(tf.less(kernel.lengthscales, 1e-1)): | ||||||||||||||||||
warnings.warn( | ||||||||||||||||||
"Quadrature Fourier feature approximation of kernels " | ||||||||||||||||||
"with small lengthscale lead to unexpected behaviors!" | ||||||||||||||||||
"Fourier feature approximation of kernels with small " | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add what you mean by "small". |
||||||||||||||||||
"lengthscales using Gaussian quadrature can have " | ||||||||||||||||||
"unexpected behaviors!" | ||||||||||||||||||
) | ||||||||||||||||||
super(QuadratureFourierFeatures, self).__init__(kernel, n_components, **kwargs) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class GaussHermiteQuadratureFourierFeatures(GaussianQuadratureFourierFeatures): | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to add a reference to which maths this class implements? |
||||||||||||||||||
|
||||||||||||||||||
SUPPORTED_KERNELS = (gpflow.kernels.SquaredExponential,) | ||||||||||||||||||
|
||||||||||||||||||
def build(self, input_shape: ShapeType) -> None: | ||||||||||||||||||
""" | ||||||||||||||||||
|
@@ -60,33 +60,72 @@ def build(self, input_shape: ShapeType) -> None: | |||||||||||||||||
<https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_. | ||||||||||||||||||
""" | ||||||||||||||||||
input_dim = input_shape[-1] | ||||||||||||||||||
abscissa_value, omegas_value = ndgh_points_and_weights( | ||||||||||||||||||
# (L^D, D), (L^D, 1) | ||||||||||||||||||
abscissa_value, factors_value = ndgh_points_and_weights( | ||||||||||||||||||
dim=input_dim, n_gh=self.n_components | ||||||||||||||||||
) | ||||||||||||||||||
omegas_value = tf.squeeze(omegas_value, axis=-1) | ||||||||||||||||||
factors_value = tf.squeeze(factors_value, axis=-1) # (L^D,) | ||||||||||||||||||
|
||||||||||||||||||
# Quadrature node points | ||||||||||||||||||
self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) # (M^D, D) | ||||||||||||||||||
# Gauss-Hermite weights | ||||||||||||||||||
self.factors = tf.Variable(initial_value=omegas_value, trainable=False) # (M^D,) | ||||||||||||||||||
super(QuadratureFourierFeatures, self).build(input_shape) | ||||||||||||||||||
# Gauss-Christoffel nodes (L^D, D) | ||||||||||||||||||
self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) | ||||||||||||||||||
# Gauss-Christoffel weights (L^D,) | ||||||||||||||||||
self.factors = tf.Variable(initial_value=factors_value, trainable=False) | ||||||||||||||||||
super(GaussHermiteQuadratureFourierFeatures, self).build(input_shape) | ||||||||||||||||||
|
||||||||||||||||||
def _compute_output_dim(self, input_shape: ShapeType) -> int: | ||||||||||||||||||
input_dim = input_shape[-1] | ||||||||||||||||||
return 2 * self.n_components ** input_dim | ||||||||||||||||||
|
||||||||||||||||||
def _compute_bases(self, inputs: TensorType) -> tf.Tensor: | ||||||||||||||||||
""" | ||||||||||||||||||
Compute basis functions. | ||||||||||||||||||
class GaussLegendreQuadratureFourierFeatures(GaussianQuadratureFourierFeatures): | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to add a reference to which maths this class implements? |
||||||||||||||||||
|
||||||||||||||||||
:return: A tensor with the shape ``[N, 2M^D]``. | ||||||||||||||||||
""" | ||||||||||||||||||
return _bases_concat(inputs, self.abscissa) | ||||||||||||||||||
SUPPORTED_KERNELS = ( | ||||||||||||||||||
gpflow.kernels.SquaredExponential, | ||||||||||||||||||
gpflow.kernels.Matern12, | ||||||||||||||||||
gpflow.kernels.Matern32, | ||||||||||||||||||
gpflow.kernels.Matern52, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
def _compute_constant(self) -> tf.Tensor: | ||||||||||||||||||
def build(self, input_shape: ShapeType) -> None: | ||||||||||||||||||
""" | ||||||||||||||||||
Compute normalizing constant for basis functions. | ||||||||||||||||||
|
||||||||||||||||||
:return: A tensor with the shape ``[2M^D,]`` | ||||||||||||||||||
Creates the variables of the layer. | ||||||||||||||||||
See `tf.keras.layers.Layer.build() | ||||||||||||||||||
<https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_. | ||||||||||||||||||
""" | ||||||||||||||||||
return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2]) | ||||||||||||||||||
input_dim = input_shape[-1] | ||||||||||||||||||
|
||||||||||||||||||
if isinstance(self.kernel, gpflow.kernels.SquaredExponential): | ||||||||||||||||||
dist = multivariate_normal(mean=np.zeros(input_dim)) | ||||||||||||||||||
else: | ||||||||||||||||||
nu = _matern_dof(self.kernel) # degrees of freedom | ||||||||||||||||||
dist = multivariate_t(loc=np.zeros(input_dim), df=nu) | ||||||||||||||||||
|
||||||||||||||||||
# raw 1-dimensional quadrature nodes and weights (L,) (L,) | ||||||||||||||||||
abscissa_value_flat, factors_value_flat = np.polynomial.legendre.leggauss( | ||||||||||||||||||
deg=self.n_components | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
# transformed 1-dimensional quadrature nodes and weights | ||||||||||||||||||
transform = TanTransform() | ||||||||||||||||||
factors_value_flat *= transform.multiplier(abscissa_value_flat) # (L,) | ||||||||||||||||||
abscissa_value_flat = transform(abscissa_value_flat) # (L,) | ||||||||||||||||||
|
||||||||||||||||||
# transformed D-dimensional quadrature nodes and weights | ||||||||||||||||||
abscissa_value_rep = repeat_as_list(abscissa_value_flat, n=input_dim) # (L, ..., L) | ||||||||||||||||||
factors_value_rep = repeat_as_list(factors_value_flat, n=input_dim) # (L, ..., L) | ||||||||||||||||||
# (L^D, D), (L^D, 1) | ||||||||||||||||||
abscissa_value, factors_value = reshape_Z_dZ(abscissa_value_rep, factors_value_rep) | ||||||||||||||||||
|
||||||||||||||||||
factors_value = tf.squeeze(factors_value, axis=-1) # (L^D,) | ||||||||||||||||||
factors_value *= dist.pdf(abscissa_value) # (L^D,) | ||||||||||||||||||
|
||||||||||||||||||
# Gauss-Christoffel nodes (L^D, D) | ||||||||||||||||||
self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) | ||||||||||||||||||
# Gauss-Christoffel weights (L^D,) | ||||||||||||||||||
self.factors = tf.Variable(initial_value=factors_value, trainable=False) | ||||||||||||||||||
|
||||||||||||||||||
super(GaussLegendreQuadratureFourierFeatures, self).build(input_shape) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class QuadratureFourierFeatures(GaussHermiteQuadratureFourierFeatures): | ||||||||||||||||||
""" | ||||||||||||||||||
Alias for `GaussHermiteQuadratureFourierFeatures`. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
pass | ||||||||||||||||||
Comment on lines
+175
to
+180
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe explain why we have this alias in place? Is this the most common and recommended quadrature rule? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: You could make this message more informative by specifying which kernel the user used.