
Commit

WIP
gpleiss committed Jul 12, 2024
1 parent dc61fb1 commit daf43a3
Showing 17 changed files with 226 additions and 35 deletions.
6 changes: 4 additions & 2 deletions gpytorch/kernels/cosine_kernel.py
@@ -56,8 +56,6 @@ class CosineKernel(Kernel):
>>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10)
"""

is_stationary = True

def __init__(
self,
period_length_prior: Optional[Prior] = None,
@@ -85,6 +83,10 @@ def __init__(

self.register_constraint("raw_period_length", period_length_constraint)

@property
def is_stationary(self):
return True

@property
def period_length(self):
return self.raw_period_length_constraint.transform(self.raw_period_length)
4 changes: 1 addition & 3 deletions gpytorch/kernels/cylindrical_kernel.py
@@ -4,7 +4,6 @@

import torch

from .. import settings
from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel
@@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal
else:
angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))

with settings.lazily_evaluate_kernels(False):
radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params)
return radial_kernel.mul(angular_kernel)

def kuma(self, x: torch.Tensor) -> torch.Tensor:
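The change above swaps a settings.lazily_evaluate_kernels(False) context for a direct call to the base kernel's forward(); both force eager evaluation of the base kernel, the new form just skips the __call__ machinery. A minimal sketch of that equivalence, with RBFKernel standing in for the radial base kernel (illustrative only, not part of this commit):

import torch
import gpytorch
from gpytorch.kernels import RBFKernel

base = RBFKernel()
x1, x2 = torch.randn(8, 2), torch.randn(8, 2)

# Old pattern: temporarily disable lazy evaluation so __call__ returns an evaluated matrix.
with gpytorch.settings.lazily_evaluate_kernels(False):
    K_eager = base(x1, x2).to_dense()

# New pattern: call forward() directly, which always computes the covariance immediately.
K_direct = base.forward(x1, x2)

print(torch.allclose(K_eager, K_direct))  # True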
7 changes: 7 additions & 0 deletions gpytorch/kernels/grid_interpolation_kernel.py
@@ -115,6 +115,13 @@ def __init__(
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

@property
def _lazily_evaluate(self) -> bool:
# GridInterpolationKernel should not lazily evaluate; there is little to gain (the inducing point kernel
# matrix always needs to be evaluated, regardless of the size of x1 and x2), and the
# InterpolatedLinearOperator structure is needed for fast predictions.
return False

@property
def _tight_grid_bounds(self):
grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
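For context on the comment above: GridInterpolationKernel (SKI/KISS-GP) approximates the data covariance as W K_UU W^T, where W contains sparse interpolation weights onto a fixed grid, so the grid kernel matrix K_UU is always formed in full and fast predictions rely on that interpolation structure. A rough sketch of the structure in plain torch, with a dense stand-in for W and assumed lengthscales (not GPyTorch's implementation):

import torch

n, m = 200, 30                                 # data points, grid points
grid = torch.linspace(0, 1, m).unsqueeze(-1)
x = torch.rand(n, 1)

# The m x m grid kernel must always be evaluated, regardless of n.
K_uu = torch.exp(-0.5 * torch.cdist(grid, grid).pow(2) / 0.1 ** 2)

# Dense stand-in for the (normally sparse) interpolation matrix W.
W = torch.softmax(-torch.cdist(x, grid) / 0.01, dim=-1)

K_ski = W @ K_uu @ W.T                         # n x n approximation built from the m x m kernel
print(K_ski.shape)                             # torch.Size([200, 200])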
13 changes: 9 additions & 4 deletions gpytorch/kernels/grid_kernel.py
@@ -41,10 +41,6 @@ class GridKernel(Kernel):
http://www.cs.cmu.edu/~andrewgw/manet.pdf
"""

# TODO: update doc

is_stationary = True

def __init__(
self,
base_kernel: Kernel,
@@ -76,6 +72,15 @@ def __init__(
# Also create the full_grid buffer
self.update_grid(grid)

@property
def _lazily_evaluate(self) -> bool:
# Toeplitz structure is very efficient; no need to lazily evaluate
return False

@property
def is_stationary(self) -> bool:
return True

def _clear_cache(self):
if hasattr(self, "_cached_kernel_mat"):
del self._cached_kernel_mat
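Background for the Toeplitz comment above: a stationary kernel evaluated on an evenly spaced grid yields a symmetric Toeplitz matrix, so the full matrix is determined by its first column and admits O(n log n) matrix-vector products. A small illustrative check in plain torch (assumed RBF form and lengthscale, not GPyTorch code):

import torch

grid = torch.linspace(0, 1, 5).unsqueeze(-1)
K = torch.exp(-0.5 * torch.cdist(grid, grid).pow(2) / 0.2 ** 2)  # stationary kernel on the grid

# Toeplitz: entry (i, j) depends only on |i - j|, so the first column determines the matrix.
idx = torch.arange(5)
K_from_first_column = K[:, 0][(idx[:, None] - idx[None, :]).abs()]
print(torch.allclose(K, K_from_first_column))  # True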
6 changes: 6 additions & 0 deletions gpytorch/kernels/index_kernel.py
@@ -76,6 +76,12 @@ def __init__(

self.register_constraint("raw_var", var_constraint)

@property
def _lazily_evaluate(self) -> bool:
# IndexKernel does not need lazy evaluation, since the complete B B^T + D_v matrix is always
# computed regardless of x1 and x2.
return False

@property
def var(self):
return self.raw_var_constraint.transform(self.raw_var)
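To make the comment above concrete: IndexKernel's task covariance B B^T + D_v is a fixed num_tasks x num_tasks matrix, and evaluating the kernel at particular task indices just selects entries from it, so deferring evaluation buys nothing. An illustrative sketch in plain torch (hypothetical sizes; the variances are assumed already positive):

import torch

num_tasks, rank = 4, 2
B = torch.randn(num_tasks, rank)               # low-rank factor (covar_factor)
v = torch.rand(num_tasks)                      # per-task variances
task_covar = B @ B.T + torch.diag(v)           # B B^T + D_v, computed once

i1 = torch.tensor([0, 1, 3])                   # task indices for x1
i2 = torch.tensor([2, 2, 0])                   # task indices for x2
block = task_covar[i1.unsqueeze(-1), i2.unsqueeze(-2)]
print(block.shape)                             # torch.Size([3, 3]) cross-covariance block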
6 changes: 6 additions & 0 deletions gpytorch/kernels/inducing_point_kernel.py
@@ -47,6 +47,12 @@ def _clear_cache(self):
if hasattr(self, "_cached_kernel_inv_root"):
del self._cached_kernel_inv_root

@property
def _lazily_evaluate(self) -> bool:
# InducingPointKernel should not lazily evaluate; to use the Woodbury formula,
# we want the Kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
return False

@property
def _inducing_mat(self):
if not self.training and hasattr(self, "_cached_kernel_mat"):
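For reference, the Woodbury reasoning mentioned above: when the covariance has low-rank-plus-noise form A A^T + s I with A of size n x m and m << n, solves cost O(n m^2) via the matrix inversion lemma rather than O(n^3), which is why the low-rank operator structure must be preserved. A quick numerical check in plain torch (illustrative sizes, not GPyTorch code):

import torch

n, m, s = 500, 20, 0.1
A = torch.randn(n, m, dtype=torch.float64)
b = torch.randn(n, dtype=torch.float64)

# Woodbury identity: (A A^T + s I_n)^{-1} b = (b - A (A^T A + s I_m)^{-1} A^T b) / s
inner = torch.linalg.solve(A.T @ A + s * torch.eye(m, dtype=torch.float64), A.T @ b)
x_woodbury = (b - A @ inner) / s

x_direct = torch.linalg.solve(A @ A.T + s * torch.eye(n, dtype=torch.float64), b)
print(torch.allclose(x_woodbury, x_direct))    # True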
1 change: 0 additions & 1 deletion gpytorch/kernels/keops/rbf_kernel.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python3

# from linear_operator.operators import KeOpsLinearOperator
from linear_operator.operators import KernelLinearOperator

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
161 changes: 151 additions & 10 deletions gpytorch/kernels/kernel.py
@@ -4,12 +4,13 @@

import warnings
from abc import abstractmethod
from collections import defaultdict, OrderedDict
from copy import deepcopy
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from linear_operator import to_dense, to_linear_operator
from linear_operator.operators import LinearOperator, ZeroLinearOperator
from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator
from torch import Tensor
from torch.nn import ModuleList

@@ -81,6 +82,44 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
return self._postprocess(res) if postprocess else res


class _autograd_kernel_hack:
"""
Helper class.
When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients.
(Any Tensor that `covar_func` closes over will not backpropagate gradients.)
Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters.
This context manager temporarily replaces a kernel's (and its submodules') parameter assignments with an
external set of references to these parameters.
The external set of references will be passed in by KernelLinearOperator.
This way, when calling self.forward, no parameter references are closed over, and so all parameters
will receive the appropriate gradients.
"""

def __init__(
self,
kernel: Kernel,
params: Dict[str, torch.nn.Parameter],
module_params: Dict[torch.nn.Module, Iterable[str]],
):
self.temp_module_param_dicts = defaultdict(OrderedDict)
for module, param_names in module_params.items():
self.temp_module_param_dicts[module] = OrderedDict(
(param_name.rsplit(".", 1)[-1], params[param_name]) for param_name in param_names
)
self.orig_model_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts)

def __enter__(self):
for module, temp_param_dict in self.temp_module_param_dicts.items():
object.__setattr__(module, "_parameters", temp_param_dict)

def __exit__(self, type, value, traceback):
for module, orig_param_dict in self.orig_model_param_dicts.items():
object.__setattr__(module, "_parameters", orig_param_dict)
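As a minimal, self-contained illustration of the trick this helper relies on (plain nn.Linear, not GPyTorch code): temporarily swapping a module's _parameters dict makes attribute lookups such as self.weight resolve to externally supplied tensors, so gradients flow to those tensors rather than to the parameters that forward would otherwise close over.

import torch
from torch import nn

lin = nn.Linear(2, 2, bias=False)
external_weight = lin.weight.detach().clone().requires_grad_(True)

orig_params = lin._parameters
object.__setattr__(lin, "_parameters", {"weight": external_weight, "bias": None})
try:
    out = lin(torch.randn(3, 2)).sum()         # forward sees the external tensor
finally:
    object.__setattr__(lin, "_parameters", orig_params)

out.backward()
print(external_weight.grad is not None)        # True: the external tensor receives the gradient
print(lin.weight.grad)                         # None: the original parameter was bypassed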


class Kernel(Module):
r"""
Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor`
@@ -212,6 +251,45 @@ def __init__(
# TODO: Remove this on next official PyTorch release.
self.__pdist_supports_batch = True

@property
def _lazily_evaluate(self) -> bool:
r"""
Determines whether or not the kernel is lazily evaluated.
If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated
over x1 and x2.
If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function.
The kernel function will only be evaluated when either
- A mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or
- An indexing operation is performed on the kernel matrix to select specific covariance entries.
In general, _lazily_evaluate should return True (this option is more efficient), unless lazy evaluation
offers no gains and there is specific structure that will be lost with lazy evaluation
(e.g. low-rank/Nystrom approximations).
"""
return True

def _kernel_linear_operator_covar_func(
self,
x1: Tensor,
x2: Tensor,
non_param_kwargs: Dict[str, Any],
module_params: Dict[torch.nn.Module, Iterable[str]],
**params: torch.nn.Parameter,
) -> Union[Tensor, LinearOperator]:
# This is the `covar_func` that is passed into KernelLinearOperator.
# It calls self.forward, but does so in a way that ensures no parameters are closed over
# (by using the _autograd_kernel_hack context manager)
if any(param.requires_grad for param in params.values()):
with _autograd_kernel_hack(self, params, module_params):
return self.forward(x1, x2, **non_param_kwargs)
else:
return self.forward(x1, x2, **non_param_kwargs)

def _lengthscale_param(self, m: Kernel) -> Tensor:
# Used by the lengthscale_prior
return m.lengthscale
@@ -501,8 +579,63 @@ def __call__(
return res

else:
if settings.lazily_evaluate_kernels.on():
res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params)
if settings.lazily_evaluate_kernels.on() and self._lazily_evaluate:
num_outputs_per_input = self.num_outputs_per_input(x1_, x2_)
if isinstance(num_outputs_per_input, int):
num_outputs_per_input = (num_outputs_per_input, num_outputs_per_input)

def _get_parameter_parent_module_and_batch_shape(module):
num_module_batch_dimension = len(module.batch_shape) if isinstance(module, Kernel) else 0
for name, param in module._parameters.items():
yield name, (param, module, param.dim() - num_module_batch_dimension)

# The following collects one tuple for each parameter of this kernel and of its sub-modules:
# (param_name, (param_val, param_parent_module, param_num_nonbatch_dimensions))
named_parameters_parent_modules_and_batch_dimensions = tuple(
self._named_members(
_get_parameter_parent_module_and_batch_shape,
prefix="",
recurse=True,
)
)

if len(named_parameters_parent_modules_and_batch_dimensions):
# Information we need for the KernelLinearOperator, as well as the autograd hack:
# - the names/values of all parameters
# - the parent module associated with each parameter
# - the number of non-batch dimensions associated with each parameter
# We get this information from the tuples collected in the previous step
params = dict()
module_params = defaultdict(list)
num_nonbatch_dimensions = dict()
for name, (
param,
parent_module,
num_nonbatch_dimension,
) in named_parameters_parent_modules_and_batch_dimensions:
params[name] = param
module_params[parent_module].append(name)
num_nonbatch_dimensions[name] = num_nonbatch_dimension

# Construct the KernelLinearOperator
res = KernelLinearOperator(
x1_,
x2_,
covar_func=self._kernel_linear_operator_covar_func,
num_outputs_per_input=num_outputs_per_input,
num_nonbatch_dimensions=num_nonbatch_dimensions,
module_params=module_params, # params for _kernel_linear_operator_covar_func
non_param_kwargs=dict(**params), # params for forward
**params,
)
else:
res = KernelLinearOperator(
x1_,
x2_,
covar_func=self.forward,
num_outputs_per_input=num_outputs_per_input,
non_param_kwargs=dict(**params), # params for forward
)
else:
res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params))
return res
@@ -575,13 +708,17 @@ class AdditiveKernel(Kernel):
:param kernels: Kernels to add together.
"""

def __init__(self, *kernels: Iterable[Kernel]):
super(AdditiveKernel, self).__init__()
self.kernels = ModuleList(kernels)

@property
def is_stationary(self) -> bool:
return all(k.is_stationary for k in self.kernels)

def __init__(self, *kernels: Iterable[Kernel]):
super(AdditiveKernel, self).__init__()
self.kernels = ModuleList(kernels)
@property
def _lazily_evaluate(self) -> bool:
return all(k._lazily_evaluate for k in self.kernels)

def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
res = ZeroLinearOperator() if not diag else 0
@@ -617,13 +754,17 @@ class ProductKernel(Kernel):
:param kernels: Kernels to multiply together.
"""

def __init__(self, *kernels: Iterable[Kernel]):
super(ProductKernel, self).__init__()
self.kernels = ModuleList(kernels)

@property
def is_stationary(self) -> bool:
return all(k.is_stationary for k in self.kernels)

def __init__(self, *kernels: Iterable[Kernel]):
super(ProductKernel, self).__init__()
self.kernels = ModuleList(kernels)
@property
def _lazily_evaluate(self) -> bool:
return False

def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
x1_eq_x2 = torch.equal(x1, x2)
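A short sketch of what the lazy-evaluation branch of __call__ above looks like from the user's side. The calls below use the released GPyTorch API; on this branch the returned object would be a KernelLinearOperator rather than a LazyEvaluatedKernelTensor, but either way the full kernel matrix is only evaluated when entries are requested or a downstream operation (solve, logdet, etc.) needs it:

import torch
import gpytorch
from linear_operator import to_dense

kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
x = torch.randn(2000, 3)

with gpytorch.settings.lazily_evaluate_kernels(True):
    K = kernel(x, x)                   # no 2000 x 2000 evaluation happens here
    print(type(K).__name__, K.shape)   # a LinearOperator subclass, torch.Size([2000, 2000])

    block = to_dense(K[:10, :10])      # only a 10 x 10 block is actually computed
    print(block.shape)                 # torch.Size([10, 10])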
6 changes: 6 additions & 0 deletions gpytorch/kernels/linear_kernel.py
@@ -72,6 +72,12 @@ def __init__(

self.register_constraint("raw_variance", variance_constraint)

@property
def _lazily_evaluate(self) -> bool:
# LinearKernel should not lazily evaluate; to use the Woodbury formula,
# we want the Kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
return False

@property
def variance(self) -> Tensor:
return self.raw_variance_constraint.transform(self.raw_variance)
8 changes: 8 additions & 0 deletions gpytorch/kernels/multi_device_kernel.py
@@ -42,10 +42,18 @@ def __init__(
self.__cached_x1 = torch.empty(1)
self.__cached_x2 = torch.empty(1)

@property
def _lazily_evaluate(self) -> bool:
return self.base_kernel._lazily_evaluate

@property
def base_kernel(self):
return self.module

@property
def is_stationary(self):
return self.base_kernel.is_stationary

def forward(self, x1, x2, diag=False, **kwargs):
if diag:
return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device)
6 changes: 6 additions & 0 deletions gpytorch/kernels/rff_kernel.py
@@ -98,6 +98,12 @@ def __init__(self, num_samples: int, num_dims: Optional[int] = None, **kwargs):
if num_dims is not None:
self._init_weights(num_dims, num_samples)

@property
def _lazily_evaluate(self) -> bool:
# RFF kernels should not lazily evaluate; to use the Woodbury formula,
# we want the Kernel to return a LowRankLinearOperator, not a KernelLinearOperator.
return False

def _init_weights(
self, num_dims: Optional[int] = None, num_samples: Optional[int] = None, randn_weights: Optional[Tensor] = None
):
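As background for the comment above: the RFF approximation is explicitly low rank, since with D random features the kernel matrix is Phi Phi^T with Phi of size n x 2D, which is exactly the structure that Woodbury-style inference exploits. A quick illustrative check in plain torch for a unit-lengthscale RBF kernel (not GPyTorch's implementation):

import torch

torch.manual_seed(0)
n, d, D = 200, 3, 2000
x = torch.randn(n, d)

W = torch.randn(d, D)                                        # spectral samples for a unit-lengthscale RBF
phi = torch.cat([torch.cos(x @ W), torch.sin(x @ W)], dim=-1) / D ** 0.5
K_rff = phi @ phi.T                                          # rank at most 2 * D

K_exact = torch.exp(-0.5 * torch.cdist(x, x).pow(2))
print((K_rff - K_exact).abs().max())                         # small; shrinks as D grows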