diff --git a/.github/workflows/quality-check.yaml b/.github/workflows/quality-check.yaml
index e5bd665e..0971f8f7 100644
--- a/.github/workflows/quality-check.yaml
+++ b/.github/workflows/quality-check.yaml
@@ -21,7 +21,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.7, 3.8]
         tensorflow: ["~=2.5.0"]
     name: Python-${{ matrix.python-version }} tensorflow${{ matrix.tensorflow }}
     env:
diff --git a/Makefile b/Makefile
index ffab011e..f0ef2583 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,9 @@ LINT_FILE_IGNORES = "$(LIB_NAME)/__init__.py:F401,F403 \
                    $(LIB_NAME)/initializers/__init__.py:F401 \
                    $(LIB_NAME)/layers/__init__.py:F401 \
                    $(LIB_NAME)/layers/basis_functions/__init__.py:F401 \
+                   $(LIB_NAME)/layers/basis_functions/fourier_features/__init__.py:F401 \
+                   $(LIB_NAME)/layers/basis_functions/fourier_features/random/__init__.py:F401 \
+                   $(LIB_NAME)/layers/basis_functions/fourier_features/quadrature/__init__.py:F401 \
                    $(LIB_NAME)/models/__init__.py:F401 \
                    $(LIB_NAME)/optimization/__init__.py:F401 \
                    $(LIB_NAME)/sampling/__init__.py:F401 \
diff --git a/gpflux/layers/basis_functions/fourier_features/__init__.py b/gpflux/layers/basis_functions/fourier_features/__init__.py
index 42c09d47..6f191b50 100644
--- a/gpflux/layers/basis_functions/fourier_features/__init__.py
+++ b/gpflux/layers/basis_functions/fourier_features/__init__.py
@@ -18,16 +18,16 @@
 :class:`gpflux.sampling.KernelWithFeatureDecomposition`
 """
 
-from gpflux.layers.basis_functions.fourier_features.quadrature import QuadratureFourierFeatures
+from gpflux.layers.basis_functions.fourier_features.quadrature import (
+    GaussHermiteQuadratureFourierFeatures,
+    GaussLegendreQuadratureFourierFeatures,
+    QuadratureFourierFeatures,
+    SimpsonQuadratureFourierFeatures,
+)
 from gpflux.layers.basis_functions.fourier_features.random import (
     OrthogonalRandomFeatures,
+    QuasiRandomFourierFeatures,
     RandomFourierFeatures,
     RandomFourierFeaturesCosine,
+    ReweightedQuasiRandomFourierFeatures,
 )
-
-__all__ = [
-    "QuadratureFourierFeatures",
-    "OrthogonalRandomFeatures",
-    "RandomFourierFeatures",
-    "RandomFourierFeaturesCosine",
-]
diff --git a/gpflux/layers/basis_functions/fourier_features/base.py b/gpflux/layers/basis_functions/fourier_features/base.py
index 80461c80..cf9952b6 100644
--- a/gpflux/layers/basis_functions/fourier_features/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/base.py
@@ -16,7 +16,7 @@
 """ Shared functionality for stationary kernel basis functions. """
 
 from abc import ABC, abstractmethod
-from typing import Mapping
+from typing import Mapping, Tuple, Type
 
 import tensorflow as tf
 
@@ -27,6 +27,9 @@ class FourierFeaturesBase(ABC, tf.keras.layers.Layer):
+
+    SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...]
+
     def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
         """
         :param kernel: kernel to approximate using a set of Fourier bases.
@@ -34,6 +37,9 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M
             quadrature nodes, etc.) used to numerically approximate the kernel.
""" super(FourierFeaturesBase, self).__init__(**kwargs) + assert isinstance( + kernel, self.SUPPORTED_KERNELS + ), f"Only the following kernels are supported: {self.SUPPORTED_KERNELS}" self.kernel = kernel self.n_components = n_components if kwargs.get("input_dim", None): diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/__init__.py b/gpflux/layers/basis_functions/fourier_features/quadrature/__init__.py index bd3e54fb..3bb0a3f5 100644 --- a/gpflux/layers/basis_functions/fourier_features/quadrature/__init__.py +++ b/gpflux/layers/basis_functions/fourier_features/quadrature/__init__.py @@ -15,7 +15,10 @@ # """ A kernel's features and coefficients using quadrature Fourier features (QFF). """ from gpflux.layers.basis_functions.fourier_features.quadrature.gaussian import ( + GaussHermiteQuadratureFourierFeatures, + GaussLegendreQuadratureFourierFeatures, QuadratureFourierFeatures, ) - -__all__ = ["QuadratureFourierFeatures"] +from gpflux.layers.basis_functions.fourier_features.quadrature.newton_cotes import ( + SimpsonQuadratureFourierFeatures, +) diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/base.py b/gpflux/layers/basis_functions/fourier_features/quadrature/base.py new file mode 100644 index 00000000..5b7f5ef1 --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/quadrature/base.py @@ -0,0 +1,98 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod + +import numpy as np +import tensorflow as tf + +from gpflow.base import TensorType + +from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase +from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat +from gpflux.types import ShapeType + + +class QuadratureFourierFeaturesBase(FourierFeaturesBase): + def _compute_output_dim(self, input_shape: ShapeType) -> int: + input_dim = input_shape[-1] + return 2 * self.n_components ** input_dim + + def _compute_bases(self, inputs: TensorType) -> tf.Tensor: + """ + Compute basis functions. + + :return: A tensor with the shape ``(N, 2L^D)``. + """ + return _bases_concat(inputs, self.abscissa) + + def _compute_constant(self) -> tf.Tensor: + """ + Compute normalizing constant for basis functions. + + :return: A tensor with the shape ``(2L^D,)`` + """ + return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2]) + + +class Transform(ABC): + r""" + This class encapsulates functions :math:`g(x) = u` and :math:`h(x)` such that + .. math:: + \int_{g(a)}^{g(b)} z(u) f(u) du + = \int_a^b w(x) h(x) f(g(x)) dx + for some integrand :math:`f(u)` and weight function :math:`z(u)`. + """ + + @abstractmethod + def __call__(self, x: TensorType) -> tf.Tensor: + pass + + @abstractmethod + def multiplier(self, x: TensorType) -> tf.Tensor: + pass + + +class TanTransform(Transform): + r""" + This class encapsulates functions :math:`g(x) = u` and :math:`h(x) = du/dx` + such that + .. 
diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
index 2391bb13..ee408d24 100644
--- a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
+++ b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
@@ -19,39 +19,43 @@
 """
 
 import warnings
-from typing import Mapping, Tuple, Type
+from typing import Mapping
 
+import numpy as np
 import tensorflow as tf
+import tensorflow_probability as tfp
 
-import gpflow
-from gpflow.base import TensorType
-from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights
+from scipy.stats import multivariate_normal, multivariate_t
 
-from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
-from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat
-from gpflux.types import ShapeType
+import gpflow
 
-"""
-Kernels supported by :class:`QuadratureFourierFeatures`.
+# from gpflow.config import default_float
+from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights, repeat_as_list, reshape_Z_dZ
 
-Currently we only support the :class:`gpflow.kernels.SquaredExponential` kernel.
-For Matern kernels please use :class:`RandomFourierFeatures`
-or :class:`RandomFourierFeaturesCosine`.
-"""
-QFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
-    gpflow.kernels.SquaredExponential,
+from gpflux.layers.basis_functions.fourier_features.quadrature.base import (
+    QuadratureFourierFeaturesBase,
+    TanTransform,
 )
+from gpflux.layers.basis_functions.fourier_features.utils import _matern_dof
+from gpflux.types import ShapeType
 
+tfd = tfp.distributions
 
-class QuadratureFourierFeatures(FourierFeaturesBase):
+
+class GaussianQuadratureFourierFeatures(QuadratureFourierFeaturesBase):
     def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
-        assert isinstance(kernel, QFF_SUPPORTED_KERNELS), "Unsupported Kernel"
+        super(GaussianQuadratureFourierFeatures, self).__init__(kernel, n_components, **kwargs)
         if tf.reduce_any(tf.less(kernel.lengthscales, 1e-1)):
             warnings.warn(
-                "Quadrature Fourier feature approximation of kernels "
-                "with small lengthscale lead to unexpected behaviors!"
+                "Fourier feature approximation of kernels with small "
+                "lengthscales using Gaussian quadrature can have "
+                "unexpected behaviors!"
             )
-        super(QuadratureFourierFeatures, self).__init__(kernel, n_components, **kwargs)
+
+
+class GaussHermiteQuadratureFourierFeatures(GaussianQuadratureFourierFeatures):
+
+    SUPPORTED_KERNELS = (gpflow.kernels.SquaredExponential,)
 
     def build(self, input_shape: ShapeType) -> None:
         """
@@ -60,33 +64,117 @@ def build(self, input_shape: ShapeType) -> None:
         `_.
""" input_dim = input_shape[-1] - abscissa_value, omegas_value = ndgh_points_and_weights( + # (L^D, D), (L^D, 1) + abscissa_value, factors_value = ndgh_points_and_weights( dim=input_dim, n_gh=self.n_components ) - omegas_value = tf.squeeze(omegas_value, axis=-1) + factors_value = tf.squeeze(factors_value, axis=-1) # (L^D,) - # Quadrature node points - self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) # (M^D, D) - # Gauss-Hermite weights - self.factors = tf.Variable(initial_value=omegas_value, trainable=False) # (M^D,) - super(QuadratureFourierFeatures, self).build(input_shape) + # Gauss-Christoffel nodes (L^D, D) + self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) + # Gauss-Christoffel weights (L^D,) + self.factors = tf.Variable(initial_value=factors_value, trainable=False) + super(GaussHermiteQuadratureFourierFeatures, self).build(input_shape) - def _compute_output_dim(self, input_shape: ShapeType) -> int: - input_dim = input_shape[-1] - return 2 * self.n_components ** input_dim - def _compute_bases(self, inputs: TensorType) -> tf.Tensor: - """ - Compute basis functions. +class ReweightedGaussHermiteQuadratureFourierFeatures(GaussHermiteQuadratureFourierFeatures): - :return: A tensor with the shape ``[N, 2M^D]``. - """ - return _bases_concat(inputs, self.abscissa) + SUPPORTED_KERNELS = ( + gpflow.kernels.SquaredExponential, + gpflow.kernels.Matern12, + gpflow.kernels.Matern32, + gpflow.kernels.Matern52, + ) def _compute_constant(self) -> tf.Tensor: """ Compute normalizing constant for basis functions. - :return: A tensor with the shape ``[2M^D,]`` + :return: A tensor with the shape ``[]`` (i.e. a scalar). + """ + return ( + tf.tile(tf.sqrt(self.importance_weight), multiples=[2]) + * super(ReweightedGaussHermiteQuadratureFourierFeatures, self)._compute_constant() + ) + + def build(self, input_shape: ShapeType) -> None: + """ + Creates the variables of the layer. + See `tf.keras.layers.Layer.build() + `_. + """ + super(ReweightedGaussHermiteQuadratureFourierFeatures, self).build(input_shape) + + input_dim = input_shape[-1] + importance_weight_value = tf.ones(self.abscissa.shape[0], dtype=self.dtype) + + if not isinstance(self.kernel, gpflow.kernels.SquaredExponential): + nu = _matern_dof(self.kernel) # degrees of freedom + q = tfd.MultivariateNormalDiag(loc=tf.zeros(input_dim, dtype=self.dtype)) + p = tfd.MultivariateStudentTLinearOperator( + df=nu, + loc=tf.zeros(input_dim, dtype=self.dtype), + scale=tf.linalg.LinearOperatorLowerTriangular(tf.eye(input_dim, dtype=self.dtype)), + ) + importance_weight_value = tf.exp(p.log_prob(self.abscissa) - q.log_prob(self.abscissa)) + + self.importance_weight = tf.Variable(initial_value=importance_weight_value, + trainable=False) + + +class GaussLegendreQuadratureFourierFeatures(GaussianQuadratureFourierFeatures): + + SUPPORTED_KERNELS = ( + gpflow.kernels.SquaredExponential, + gpflow.kernels.Matern12, + gpflow.kernels.Matern32, + gpflow.kernels.Matern52, + ) + + def build(self, input_shape: ShapeType) -> None: + """ + Creates the variables of the layer. + See `tf.keras.layers.Layer.build() + `_. 
""" - return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2]) + input_dim = input_shape[-1] + + if isinstance(self.kernel, gpflow.kernels.SquaredExponential): + dist = multivariate_normal(mean=np.zeros(input_dim)) + else: + nu = _matern_dof(self.kernel) # degrees of freedom + dist = multivariate_t(loc=np.zeros(input_dim), df=nu) + + # raw 1-dimensional quadrature nodes and weights (L,) (L,) + abscissa_value_flat, factors_value_flat = np.polynomial.legendre.leggauss( + deg=self.n_components + ) + + # transformed 1-dimensional quadrature nodes and weights + transform = TanTransform() + factors_value_flat *= transform.multiplier(abscissa_value_flat) # (L,) + abscissa_value_flat = transform(abscissa_value_flat) # (L,) + + # transformed D-dimensional quadrature nodes and weights + abscissa_value_rep = repeat_as_list(abscissa_value_flat, n=input_dim) # (L, ..., L) + factors_value_rep = repeat_as_list(factors_value_flat, n=input_dim) # (L, ..., L) + # (L^D, D), (L^D, 1) + abscissa_value, factors_value = reshape_Z_dZ(abscissa_value_rep, factors_value_rep) + + factors_value = tf.squeeze(factors_value, axis=-1) # (L^D,) + factors_value *= dist.pdf(abscissa_value) # (L^D,) + + # Gauss-Christoffel nodes (L^D, D) + self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) + # Gauss-Christoffel weights (L^D,) + self.factors = tf.Variable(initial_value=factors_value, trainable=False) + + super(GaussLegendreQuadratureFourierFeatures, self).build(input_shape) + + +class QuadratureFourierFeatures(GaussHermiteQuadratureFourierFeatures): + """ + Alias for `GaussHermiteQuadratureFourierFeatures`. + """ + + pass diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/newton_cotes.py b/gpflux/layers/basis_functions/fourier_features/quadrature/newton_cotes.py new file mode 100644 index 00000000..a9143067 --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/quadrature/newton_cotes.py @@ -0,0 +1,100 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Kernel decompositon into features and coefficients based on Newton-Cotes quadrature. 
+""" +import numpy as np +import tensorflow as tf +from scipy.stats import multivariate_normal, multivariate_t + +import gpflow +from gpflow.quadrature.gauss_hermite import repeat_as_list, reshape_Z_dZ + +from gpflux.layers.basis_functions.fourier_features.quadrature.base import ( + QuadratureFourierFeaturesBase, + TanTransform, +) +from gpflux.layers.basis_functions.fourier_features.utils import _matern_dof +from gpflux.types import ShapeType + + +class SimpsonQuadratureFourierFeatures(QuadratureFourierFeaturesBase): + + SUPPORTED_KERNELS = ( + gpflow.kernels.SquaredExponential, + gpflow.kernels.Matern12, + gpflow.kernels.Matern32, + gpflow.kernels.Matern52, + ) + + def _compute_output_dim(self, input_shape: ShapeType) -> int: + input_dim = input_shape[-1] + n_abscissa = 2 * self.n_components + 1 + return 2 * n_abscissa ** input_dim + + def build(self, input_shape: ShapeType) -> None: + """ + Creates the variables of the layer. + See `tf.keras.layers.Layer.build() + `_. + """ + input_dim = input_shape[-1] + + if isinstance(self.kernel, gpflow.kernels.SquaredExponential): + dist = multivariate_normal(mean=np.zeros(input_dim)) + else: + nu = _matern_dof(self.kernel) # degrees of freedom + dist = multivariate_t(loc=np.zeros(input_dim), df=nu) + + stop = 1.0 + start = -1.0 + + # `n_components` denotes half the desired number of intervals + n_abscissa = 2 * self.n_components + 1 + width = np.true_divide(stop - start, n_abscissa - 1) + + # raw 1-dimensional quadrature nodes (L,) + abscissa_value_flat = np.linspace(start, stop, n_abscissa) + + alpha = np.atleast_2d(4.0) + beta = np.atleast_2d(2.0) + a = np.hstack([beta, alpha]) + + factors_value_flat = np.append(np.tile(a, reps=self.n_components), beta, axis=-1) + factors_value_flat *= width + factors_value_flat /= 3.0 + factors_value_flat[..., [0, -1]] /= 2.0 # halve first and last weight + + # transformed 1-dimensional quadrature nodes and weights + transform = TanTransform() + factors_value_flat *= transform.multiplier(abscissa_value_flat) # (L,) + abscissa_value_flat = transform(abscissa_value_flat) # (L,) + + # transformed D-dimensional quadrature nodes and weights + abscissa_value_rep = repeat_as_list(abscissa_value_flat, n=input_dim) # (L, ..., L) + factors_value_rep = repeat_as_list(factors_value_flat, n=input_dim) # (L, ..., L) + # (L^D, D), (L^D, 1) + abscissa_value, factors_value = reshape_Z_dZ(abscissa_value_rep, factors_value_rep) + + factors_value = tf.squeeze(factors_value, axis=-1) # (L^D,) + factors_value *= dist.pdf(abscissa_value) # (L^D,) + + # Quadrature nodes (L^D, D) + self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False) + # Quadrature weights (L^D,) + self.factors = tf.Variable(initial_value=factors_value, trainable=False) + + super(SimpsonQuadratureFourierFeatures, self).build(input_shape) diff --git a/gpflux/layers/basis_functions/fourier_features/random/__init__.py b/gpflux/layers/basis_functions/fourier_features/random/__init__.py index ee8337be..0e2bfaa5 100644 --- a/gpflux/layers/basis_functions/fourier_features/random/__init__.py +++ b/gpflux/layers/basis_functions/fourier_features/random/__init__.py @@ -22,9 +22,7 @@ from gpflux.layers.basis_functions.fourier_features.random.orthogonal import ( OrthogonalRandomFeatures, ) - -__all__ = [ - "OrthogonalRandomFeatures", - "RandomFourierFeatures", - "RandomFourierFeaturesCosine", -] +from gpflux.layers.basis_functions.fourier_features.random.quasi import ( + QuasiRandomFourierFeatures, + ReweightedQuasiRandomFourierFeatures, +) diff --git 
diff --git a/gpflux/layers/basis_functions/fourier_features/random/__init__.py b/gpflux/layers/basis_functions/fourier_features/random/__init__.py
index ee8337be..0e2bfaa5 100644
--- a/gpflux/layers/basis_functions/fourier_features/random/__init__.py
+++ b/gpflux/layers/basis_functions/fourier_features/random/__init__.py
@@ -22,9 +22,7 @@
 from gpflux.layers.basis_functions.fourier_features.random.orthogonal import (
     OrthogonalRandomFeatures,
 )
-
-__all__ = [
-    "OrthogonalRandomFeatures",
-    "RandomFourierFeatures",
-    "RandomFourierFeaturesCosine",
-]
+from gpflux.layers.basis_functions.fourier_features.random.quasi import (
+    QuasiRandomFourierFeatures,
+    ReweightedQuasiRandomFourierFeatures,
+)
diff --git a/gpflux/layers/basis_functions/fourier_features/random/base.py b/gpflux/layers/basis_functions/fourier_features/random/base.py
index 82ae4aa9..eacd5f5e 100644
--- a/gpflux/layers/basis_functions/fourier_features/random/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Mapping, Optional, Tuple, Type
+from typing import Optional, Tuple, Type
 
 import numpy as np
 import tensorflow as tf
@@ -25,23 +25,10 @@ from gpflux.layers.basis_functions.fourier_features.utils import (
     _bases_concat,
     _bases_cosine,
-    _matern_number,
+    _matern_dof,
 )
 from gpflux.types import ShapeType
 
-"""
-Kernels supported by :class:`RandomFourierFeatures`.
-
-You can build RFF for shift-invariant stationary kernels from which you can
-sample frequencies from their power spectrum, following Bochner's theorem.
-"""
-RFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
-    gpflow.kernels.SquaredExponential,
-    gpflow.kernels.Matern12,
-    gpflow.kernels.Matern32,
-    gpflow.kernels.Matern52,
-)
-
 
 def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
     """
@@ -73,9 +60,13 @@ def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
 
 
 class RandomFourierFeaturesBase(FourierFeaturesBase):
-    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
-        assert isinstance(kernel, RFF_SUPPORTED_KERNELS), "Unsupported Kernel"
-        super(RandomFourierFeaturesBase, self).__init__(kernel, n_components, **kwargs)
+
+    SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
+        gpflow.kernels.SquaredExponential,
+        gpflow.kernels.Matern12,
+        gpflow.kernels.Matern32,
+        gpflow.kernels.Matern52,
+    )
 
     def build(self, input_shape: ShapeType) -> None:
         """
@@ -101,8 +92,7 @@ def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> Ten
         if isinstance(self.kernel, gpflow.kernels.SquaredExponential):
             return tf.random.normal(shape, dtype=dtype)
         else:
-            p = _matern_number(self.kernel)
-            nu = 2.0 * p + 1.0  # degrees of freedom
+            nu = _matern_dof(self.kernel)  # degrees of freedom
             return _sample_students_t(nu, shape, dtype)
 
     @staticmethod
diff --git a/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py b/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py
index 395da743..13f78d94 100644
--- a/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py
+++ b/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 
-from typing import Mapping, Optional, Tuple, Type
+from typing import Optional
 
 import numpy as np
 import tensorflow as tf
@@ -25,18 +25,6 @@ from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures
 from gpflux.types import ShapeType
 
-"""
-Kernels supported by :class:`OrthogonalRandomFeatures`.
-
-This random matrix sampling scheme only applies to the :class:`gpflow.kernels.SquaredExponential`
-kernel.
-For Matern kernels please use :class:`RandomFourierFeatures`
-or :class:`RandomFourierFeaturesCosine`.
-"""
-ORF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
-    gpflow.kernels.SquaredExponential,
-)
-
 
 def _sample_chi_squared(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
     """
@@ -69,9 +57,7 @@ class OrthogonalRandomFeatures(RandomFourierFeatures):
     efficient and accurate kernel approximations than :class:`RandomFourierFeatures`.
     """
 
-    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
-        assert isinstance(kernel, ORF_SUPPORTED_KERNELS), "Unsupported Kernel"
-        super(OrthogonalRandomFeatures, self).__init__(kernel, n_components, **kwargs)
+    SUPPORTED_KERNELS = (gpflow.kernels.SquaredExponential,)
 
     def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
         n_components, input_dim = shape  # M, D
diff --git a/gpflux/layers/basis_functions/fourier_features/random/quasi.py b/gpflux/layers/basis_functions/fourier_features/random/quasi.py
new file mode 100644
index 00000000..bc2cdcfa
--- /dev/null
+++ b/gpflux/layers/basis_functions/fourier_features/random/quasi.py
@@ -0,0 +1,90 @@
+#
+# Copyright (c) 2021 The GPflux Contributors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+from scipy.stats import multivariate_normal, multivariate_t
+from scipy.stats.qmc import MultivariateNormalQMC
+
+import gpflow
+from gpflow.base import DType, TensorType
+
+from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures
+from gpflux.layers.basis_functions.fourier_features.utils import _matern_dof
+from gpflux.types import ShapeType
+
+tfd = tfp.distributions
+
+
+class QuasiRandomFourierFeatures(RandomFourierFeatures):
+
+    """
+    Quasi-random Fourier features (QRFF) :cite:p:`yang2014quasi` for more
+    efficient and accurate kernel approximations than random Fourier features.
+    """
+
+    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
+        n_components, input_dim = shape  # M, D
+        sampler = MultivariateNormalQMC(mean=np.zeros(input_dim))
+        return sampler.random(n=n_components)  # shape [M, D]
+
+
+class ReweightedQuasiRandomFourierFeatures(QuasiRandomFourierFeatures):
+
+    SUPPORTED_KERNELS = (
+        gpflow.kernels.SquaredExponential,
+        gpflow.kernels.Matern12,
+        gpflow.kernels.Matern32,
+        gpflow.kernels.Matern52,
+    )
+
+    def _compute_constant(self) -> tf.Tensor:
+        """
+        Compute normalizing constant for basis functions.
+
+        :return: A tensor with the shape ``(2M,)``.
+        """
+        return (
+            tf.tile(tf.sqrt(self.importance_weight), multiples=[2])
+            * super(ReweightedQuasiRandomFourierFeatures, self)._compute_constant()
+        )
+
+    def build(self, input_shape: ShapeType) -> None:
+        """
+        Creates the variables of the layer.
+        See `tf.keras.layers.Layer.build()
+        `_.
+        """
+        super(ReweightedQuasiRandomFourierFeatures, self).build(input_shape)
+
+        input_dim = input_shape[-1]
+        importance_weight_value = tf.ones(self.n_components, dtype=self.dtype)
+
+        if not isinstance(self.kernel, gpflow.kernels.SquaredExponential):
+            nu = _matern_dof(self.kernel)  # degrees of freedom
+            q = tfd.MultivariateNormalDiag(loc=tf.zeros(input_dim, dtype=self.dtype))
+            p = tfd.MultivariateStudentTLinearOperator(
+                df=nu,
+                loc=tf.zeros(input_dim, dtype=self.dtype),
+                scale=tf.linalg.LinearOperatorLowerTriangular(tf.eye(input_dim, dtype=self.dtype)),
+            )
+            importance_weight_value = tf.exp(p.log_prob(self.W) - q.log_prob(self.W))
+
+        self.importance_weight = tf.Variable(initial_value=importance_weight_value,
+                                             trainable=False)
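The reweighting in ReweightedQuasiRandomFourierFeatures (and in its Gauss-Hermite counterpart earlier in this patch) is standard importance weighting: frequencies are generated under a Gaussian proposal and then rescaled by the ratio of the Matern spectral density, a multivariate Student-t, to that Gaussian. A standalone sketch of the weight computation, using SciPy rather than the tensorflow-probability distributions in the patch (illustrative only):

# Sketch: importance weights p(omega_i) / q(omega_i) for a Matern-5/2 kernel,
# whose spectral measure is a Student-t with nu = 2 * 2 + 1 = 5 degrees of freedom.
import numpy as np
from scipy.stats import multivariate_normal, multivariate_t

input_dim, n_components = 2, 8
rng = np.random.default_rng(0)
omega = rng.standard_normal((n_components, input_dim))   # stand-in for QMC/quadrature nodes

q = multivariate_normal(mean=np.zeros(input_dim))        # proposal the frequencies came from
p = multivariate_t(loc=np.zeros(input_dim), df=5.0)      # Matern-5/2 spectral measure
importance_weight = np.exp(p.logpdf(omega) - q.logpdf(omega))
print(importance_weight.shape)                            # (8,): one weight per frequency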
+ """ + super(ReweightedQuasiRandomFourierFeatures, self).build(input_shape) + + input_dim = input_shape[-1] + importance_weight_value = tf.ones(self.n_components, dtype=self.dtype) + + if not isinstance(self.kernel, gpflow.kernels.SquaredExponential): + nu = _matern_dof(self.kernel) # degrees of freedom + q = tfd.MultivariateNormalDiag(loc=tf.zeros(input_dim, dtype=self.dtype)) + p = tfd.MultivariateStudentTLinearOperator( + df=nu, + loc=tf.zeros(input_dim, dtype=self.dtype), + scale=tf.linalg.LinearOperatorLowerTriangular(tf.eye(input_dim, dtype=self.dtype)), + ) + importance_weight_value = tf.exp(p.log_prob(self.W) - q.log_prob(self.W)) + + self.importance_weight = tf.Variable(initial_value=importance_weight_value, + trainable=False) diff --git a/gpflux/layers/basis_functions/fourier_features/utils.py b/gpflux/layers/basis_functions/fourier_features/utils.py index 71372e9e..9c88873c 100644 --- a/gpflux/layers/basis_functions/fourier_features/utils.py +++ b/gpflux/layers/basis_functions/fourier_features/utils.py @@ -34,6 +34,14 @@ def _matern_number(kernel: gpflow.kernels.Kernel) -> int: return p +def _matern_dof(kernel: gpflow.kernels.Kernel) -> float: + """ + Degrees of freedom corresponding to a kernel from the Matern family. + """ + p = _matern_number(kernel) + return 2.0 * p + 1.0 # degrees of freedom + + def _bases_cosine(X: TensorType, W: TensorType, b: TensorType) -> TensorType: """ Feature map for random Fourier features (RFF) as originally prescribed diff --git a/setup.py b/setup.py index 9108d9ff..e8d16d6f 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ "deprecated", "gpflow>=2.1", "numpy", - "scipy", + "scipy>=1.6.0", "tensorflow>=2.5.0,<2.6.0", "tensorflow-probability>=0.12.0,<0.14.0", ] diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_quadrature.py b/tests/gpflux/layers/basis_functions/fourier_features/test_quadrature.py index 2784cf53..e15618bd 100644 --- a/tests/gpflux/layers/basis_functions/fourier_features/test_quadrature.py +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_quadrature.py @@ -24,7 +24,6 @@ from gpflow.utilities.ops import difference_matrix from gpflux.layers.basis_functions.fourier_features.quadrature import QuadratureFourierFeatures -from gpflux.layers.basis_functions.fourier_features.quadrature.gaussian import QFF_SUPPORTED_KERNELS @pytest.fixture(name="n_dims", params=[1, 2, 3]) @@ -52,7 +51,7 @@ def _batch_size_fixture(request): return request.param -@pytest.fixture(name="kernel_cls", params=list(QFF_SUPPORTED_KERNELS)) +@pytest.fixture(name="kernel_cls", params=list(QuadratureFourierFeatures.SUPPORTED_KERNELS)) def _kernel_cls_fixture(request): return request.param @@ -61,7 +60,7 @@ def test_throw_for_unsupported_kernel(): kernel = gpflow.kernels.Constant() with pytest.raises(AssertionError) as excinfo: QuadratureFourierFeatures(kernel, n_components=1) - assert "Unsupported Kernel" in str(excinfo.value) + assert "Only the following kernels are supported" in str(excinfo.value) def test_quadrature_fourier_features_can_approximate_kernel_multidim( diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py index 3211a0d5..1c0e8226 100644 --- a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py @@ -26,7 +26,6 @@ RandomFourierFeatures, RandomFourierFeaturesCosine, ) -from gpflux.layers.basis_functions.fourier_features.random.base 
-from gpflux.layers.basis_functions.fourier_features.random.base import RFF_SUPPORTED_KERNELS
 
 
 @pytest.fixture(name="n_dims", params=[1, 2, 3, 5, 10, 20])
@@ -54,7 +53,7 @@ def _n_features_fixture(request):
     return request.param
 
 
-@pytest.fixture(name="kernel_cls", params=list(RFF_SUPPORTED_KERNELS))
+@pytest.fixture(name="kernel_cls", params=list(RandomFourierFeatures.SUPPORTED_KERNELS))
 def _kernel_cls_fixture(request):
     return request.param
 
@@ -79,7 +78,7 @@ def test_throw_for_unsupported_kernel(basis_func_cls):
     kernel = gpflow.kernels.Constant()
     with pytest.raises(AssertionError) as excinfo:
         basis_func_cls(kernel, n_components=1)
-    assert "Unsupported Kernel" in str(excinfo.value)
+    assert "Only the following kernels are supported" in str(excinfo.value)
 
 
 def test_random_fourier_features_can_approximate_kernel_multidim(
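For a rough smoke test of the new layers, the same pattern as the existing tests applies: build the feature map, take inner products of the features, and compare against the exact Gram matrix. A usage sketch (illustrative only, not part of the patch; approximation quality depends on n_components and the kernel):

# Sketch: approximate a Matern-3/2 kernel matrix with the new Simpson-rule features.
import numpy as np
import tensorflow as tf
import gpflow

from gpflux.layers.basis_functions.fourier_features import SimpsonQuadratureFourierFeatures

kernel = gpflow.kernels.Matern32()
feature_fn = SimpsonQuadratureFourierFeatures(kernel, n_components=64, dtype=tf.float64)

X = np.random.default_rng(0).uniform(-1.0, 1.0, size=(5, 1))
u = feature_fn(X)                              # shape [5, 2 * (2 * 64 + 1)]
K_approx = tf.matmul(u, u, transpose_b=True)   # feature-space inner products
K_exact = kernel(X)
print(np.max(np.abs(K_approx - K_exact)))      # small if the quadrature approximation is good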