Active dimension support (including for combination kernels) #105

Open
wants to merge 3 commits into base: khurram/rff_additive_kernels
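
For context, a minimal usage sketch of the behaviour this PR adds: sub-kernels of a combination kernel acting on different active_dims, fed to a random Fourier features layer. The kernel choices and constructor arguments below are illustrative, not taken from this branch.

import gpflow
from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

# Two sub-kernels, each acting on a subset of the input columns.
k1 = gpflow.kernels.SquaredExponential(active_dims=[0, 1])  # columns 0 and 1
k2 = gpflow.kernels.Matern52(active_dims=[2])               # column 2 only
kernel = k1 + k2  # additive combination kernel

# With the changes below, the layer slices its inputs per sub-kernel before
# computing bases, so k1 sees an [N, 2] slice and k2 an [N, 1] slice of [N, 3] inputs.
features = RandomFourierFeatures(kernel, n_components=64, dtype="float64")
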
Changes from all commits
22 changes: 16 additions & 6 deletions gpflux/layers/basis_functions/fourier_features/base.py
@@ -16,7 +16,8 @@
""" Shared functionality for stationary kernel basis functions. """

from abc import ABC, abstractmethod
from typing import Mapping
from itertools import cycle
from typing import Mapping, Optional

import tensorflow as tf

@@ -74,12 +75,19 @@ def call(self, inputs: TensorType) -> tf.Tensor:
:return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]'' in the multioutput case.
"""
if self.is_batched:
X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
X = tf.stack(X, 0) # [1, N, D] or [P, N, D]
bases = [
# restrict inputs to the appropriate active_dims for each sub_kernel
self._compute_bases(tf.divide(k.slice(inputs, None)[0], k.lengthscales), i)
# SharedIndependent repeatedly uses the same sub_kernel
for i, k in zip(range(self.batch_size), cycle(self.sub_kernels))
]
bases = tf.stack(bases, axis=0) # [P, N, M]
else:
X = tf.divide(inputs, self.kernel.lengthscales) # [N, D]
# restrict inputs to the kernel's active_dims
X = tf.divide(self.kernel.slice(inputs, None)[0], self.kernel.lengthscales) # [N, D]
bases = self._compute_bases(X, None) # [N, M]

const = self._compute_constant() # [] or [P, 1, 1]
bases = self._compute_bases(X) # [N, M] or [P, N, M]
output = const * bases

if self.is_batched and not self.is_multioutput:
@@ -139,8 +147,10 @@ def _compute_constant(self) -> tf.Tensor:
pass

@abstractmethod
def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
"""
Compute basis functions.

For batched layers (self.is_batched), batch indicates which sub-kernel to target.
"""
pass
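
The ``k.slice(inputs, None)[0]`` calls above rely on GPflow's standard active_dims slicing. A minimal sketch of that behaviour, with illustrative values:

import numpy as np
import gpflow

k = gpflow.kernels.SquaredExponential(active_dims=[2])
X = np.random.randn(5, 3)
X_active, _ = k.slice(X, None)  # keeps only column 2, giving shape [5, 1]
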
@@ -19,7 +19,7 @@
"""

import warnings
from typing import Mapping, Tuple, Type
from typing import Mapping, Optional, Tuple, Type

import tensorflow as tf

@@ -75,7 +75,7 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
input_dim = input_shape[-1]
return 2 * self.n_components ** input_dim

def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
"""
Compute basis functions.

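The output dimension computed in the hunk above grows exponentially with the input dimension; a quick worked instance of the formula:

n_components, input_dim = 10, 3
output_dim = 2 * n_components ** input_dim  # 2 * 10**3 = 2000 features
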
95 changes: 58 additions & 37 deletions gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Mapping, Optional, Tuple, Type
from itertools import cycle
from typing import Callable, Mapping, Optional, Tuple, Type

import numpy as np
import tensorflow as tf

import gpflow
from gpflow.base import DType, TensorType
from gpflow.kernels import Kernel

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.layers.basis_functions.fourier_features.utils import (
@@ -116,18 +118,33 @@ def build(self, input_shape: ShapeType) -> None:
self._weights_build(input_dim, n_components=self.n_components)
super(RandomFourierFeaturesBase, self).build(input_shape)

def _active_input_dim(self, input_dim: int, kernel: Kernel) -> int:
dummy_X = tf.zeros((0, input_dim), dtype=tf.float64)
return kernel.slice(dummy_X, None)[0].shape[-1]

def _weights_build(self, input_dim: int, n_components: int) -> None:
# for batched layers we store a list of weights, as each may have a different
# active input dimension
if self.is_batched:
shape = (self.batch_size, n_components, input_dim) # [P, M, D]
self.W = [
self.add_weight(
name="weights",
trainable=False,
shape=(n_components, self._active_input_dim(input_dim, k)),
dtype=self.dtype,
initializer=self._weights_init(k),
)
# SharedIndependent repeatedly uses the same sub_kernel
for _, k in zip(range(self.batch_size), cycle(self.sub_kernels))
]
else:
shape = (n_components, input_dim) # type: ignore
self.W = self.add_weight(
name="weights",
trainable=False,
shape=shape,
dtype=self.dtype,
initializer=self._weights_init,
)
self.W = self.add_weight(
name="weights",
trainable=False,
shape=(n_components, self._active_input_dim(input_dim, self.kernel)),
dtype=self.dtype,
initializer=self._weights_init(self.kernel),
)

def _weights_init_individual(
self,
@@ -142,20 +159,11 @@ def _weights_init_individual(
nu = 2.0 * p + 1.0 # degrees of freedom
return _sample_students_t(nu, shape, dtype)

def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
if self.is_batched:
if isinstance(self.kernel, gpflow.kernels.SharedIndependent):
weights_list = [
self._weights_init_individual(self.sub_kernels[0], shape[1:], dtype)
for _ in range(self.batch_size)
]
else:
weights_list = [
self._weights_init_individual(k, shape[1:], dtype) for k in self.sub_kernels
]
return tf.stack(weights_list, 0) # [P, M, D]
else:
return self._weights_init_individual(self.kernel, shape, dtype) # [M, D]
def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]:
def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
return self._weights_init_individual(kernel, shape, dtype) # [M, D]

return _initializer

@staticmethod
def rff_constant(variance: TensorType, output_dim: int) -> tf.Tensor:
@@ -207,13 +215,13 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
dim *= self.batch_size
return dim

def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
"""
Compute basis functions.

:return: A tensor with the shape ``[N, 2M]`` or ``[P, N, 2M]``.
"""
return _bases_concat(inputs, self.W)
return _bases_concat(inputs, self.W if batch is None else self.W[batch])

def _compute_constant(self) -> tf.Tensor:
"""
@@ -271,17 +279,26 @@ def build(self, input_shape: ShapeType) -> None:
super(RandomFourierFeaturesCosine, self).build(input_shape)

def _bias_build(self, n_components: int) -> None:
# for batched layers we store a list of biases, to match the weights structure
if self.is_batched:
shape = (self.batch_size, 1, n_components)
self.b = [
self.add_weight(
name="bias",
trainable=False,
shape=(1, n_components),
dtype=self.dtype,
initializer=self._bias_init,
)
for _ in range(self.batch_size)
]
else:
shape = (1, n_components) # type: ignore
self.b = self.add_weight(
name="bias",
trainable=False,
shape=shape,
dtype=self.dtype,
initializer=self._bias_init,
)
self.b = self.add_weight(
name="bias",
trainable=False,
shape=(1, n_components),
dtype=self.dtype,
initializer=self._bias_init,
)

def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype)
@@ -294,13 +311,17 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
dim *= self.batch_size
return dim

def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
"""
Compute basis functions.

:return: A tensor with the shape ``[N, M]`` or ``[P, N, M]``.
"""
return _bases_cosine(inputs, self.W, self.b)
return _bases_cosine(
inputs,
self.W if batch is None else self.W[batch],
self.b if batch is None else self.b[batch],
)

def _compute_constant(self) -> tf.Tensor:
"""
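The ``zip(range(self.batch_size), cycle(self.sub_kernels))`` idiom used in ``call`` and ``_weights_build`` covers both multioutput cases with one loop. A small standalone sketch (the kernel names are placeholders):

from itertools import cycle

sub_kernels = ["k_shared"]  # SharedIndependent exposes a single sub-kernel
batch_size = 3
print(list(zip(range(batch_size), cycle(sub_kernels))))
# [(0, 'k_shared'), (1, 'k_shared'), (2, 'k_shared')]
# With SeparateIndependent, sub_kernels has batch_size entries and zip pairs them one-to-one.
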
25 changes: 15 additions & 10 deletions gpflux/layers/basis_functions/fourier_features/random/orthogonal.py
@@ -14,13 +14,14 @@
# limitations under the License.
#

from typing import Mapping, Optional, Tuple, Type
from typing import Callable, Mapping, Optional, Tuple, Type

import numpy as np
import tensorflow as tf

import gpflow
from gpflow.base import DType, TensorType
from gpflow.kernels import Kernel

from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures
from gpflux.types import ShapeType
@@ -73,15 +74,19 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M
assert isinstance(kernel, ORF_SUPPORTED_KERNELS), "Unsupported Kernel"
super(OrthogonalRandomFeatures, self).__init__(kernel, n_components, **kwargs)

def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
n_components, input_dim = shape # M, D
n_reps = _ceil_divide(n_components, input_dim) # K, smallest integer s.t. K*D >= M
def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]:
def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType:

W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype)
Q, _ = tf.linalg.qr(W) # throw away R; shape [K, D, D]
n_components, input_dim = shape # M, D
n_reps = _ceil_divide(n_components, input_dim) # K, smallest integer s.t. K*D >= M

s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype) # shape [K, D]
U = tf.expand_dims(s, axis=-1) * Q # equiv: S @ Q where S = diag(s); shape [K, D, D]
V = tf.reshape(U, shape=(-1, input_dim)) # shape [K*D, D]
W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype)
Q, _ = tf.linalg.qr(W) # throw away R; shape [K, D, D]

return V[: self.n_components] # shape [M, D] (throw away K*D - M rows)
s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype) # shape [K, D]
U = tf.expand_dims(s, axis=-1) * Q # equiv: S @ Q where S = diag(s); shape [K, D, D]
V = tf.reshape(U, shape=(-1, input_dim)) # shape [K*D, D]

return V[: self.n_components] # shape [M, D] (throw away K*D - M rows)

return _initializer
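
The refactored ``_weights_init`` is now a factory that returns an initializer rather than being one. A minimal sketch of the pattern, assuming only that Keras ``add_weight`` accepts any ``(shape, dtype)`` callable as ``initializer``; ``spectral_sample`` is a hypothetical stand-in for ``_weights_init_individual``:

def make_weights_init(kernel, spectral_sample):
    # Bind the kernel in a closure so that, when Keras calls the initializer,
    # the weights are drawn from that kernel's spectral density.
    def _initializer(shape, dtype=None):
        return spectral_sample(kernel, shape, dtype)
    return _initializer

# e.g. layer.add_weight(name="weights", shape=(M, D), trainable=False,
#                       initializer=make_weights_init(kernel, spectral_sample))
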
3 changes: 1 addition & 2 deletions tests_requirements.txt
@@ -9,13 +9,12 @@ pytest
pytest-cov
pytest-random-order
pytest-mock
tqdm

# For mypy stubs:
types-Deprecated
numpy

tqdm

# Notebook tests:
jupytext
nbformat