Start adding typing annotations to ExactGP (#2436)
* Start adding typing annotations to ExactGP

* Update typing annotations for ExactGP

Allow train_inputs to be Iterable instead of Sequence
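
For context, a minimal sketch of what the widened annotation accepts, assuming a hypothetical MyGP subclass built from standard gpytorch modules (following the pattern in the ExactGP docstring): train_inputs may be a bare Tensor or any iterable of Tensors, such as a list or tuple.

import torch
import gpytorch

class MyGP(gpytorch.models.ExactGP):
    # Hypothetical model, mirroring the example in the ExactGP docstring.
    def __init__(self, train_inputs, train_targets, likelihood):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.randn(10, 2)
train_y = torch.randn(10)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

model = MyGP(train_x, train_y, likelihood)    # a bare Tensor is wrapped into a 1-tuple internally
model = MyGP([train_x], train_y, likelihood)  # any iterable of Tensors also passes the isinstance checks

Note that __init__ iterates train_inputs twice (the all(...) validation and the tuple(...) rebuild), so a materialized iterable like a list or tuple is what works in practice, even though the annotation permits any Iterable.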
chrisyeh96 authored Jan 28, 2025
1 parent 42f4a17 commit 41f8b28
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions gpytorch/models/exact_gp.py
@@ -1,9 +1,14 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import warnings
+
+from collections.abc import Iterable
 from copy import deepcopy

 import torch
+from torch import Tensor

 from .. import settings
 from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
@@ -52,15 +57,20 @@ class ExactGP(GP):
         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
     """

-    def __init__(self, train_inputs, train_targets, likelihood):
-        if train_inputs is not None and torch.is_tensor(train_inputs):
+    def __init__(
+        self,
+        train_inputs: Tensor | Iterable[Tensor] | None,
+        train_targets: Tensor | None,
+        likelihood: _GaussianLikelihoodBase,
+    ):
+        if train_inputs is not None and isinstance(train_inputs, Tensor):
             train_inputs = (train_inputs,)
-        if train_inputs is not None and not all(torch.is_tensor(train_input) for train_input in train_inputs):
+        if train_inputs is not None and not all(isinstance(train_input, Tensor) for train_input in train_inputs):
             raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("ExactGP can only handle Gaussian likelihoods")

-        super(ExactGP, self).__init__()
+        super().__init__()
         if train_inputs is not None:
             self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
         self.train_targets = train_targets
@@ -72,20 +82,20 @@ def __init__(self, train_inputs, train_targets, likelihood):
         self.prediction_strategy = None

     @property
-    def train_targets(self):
+    def train_targets(self) -> Tensor | None:
         return self._train_targets

     @train_targets.setter
-    def train_targets(self, value):
+    def train_targets(self, value: Tensor | None) -> None:
         object.__setattr__(self, "_train_targets", value)

     def _apply(self, fn):
         if self.train_inputs is not None:
             self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
             self.train_targets = fn(self.train_targets)
-        return super(ExactGP, self)._apply(fn)
+        return super()._apply(fn)

-    def _clear_cache(self):
+    def _clear_cache(self) -> None:
         # The precomputed caches from test time live in prediction_strategy
         self.prediction_strategy = None

@@ -99,18 +109,20 @@ def local_load_samples(self, samples_dict, memo, prefix):
         self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
         super().local_load_samples(samples_dict, memo, prefix)

-    def set_train_data(self, inputs=None, targets=None, strict=True):
+    def set_train_data(
+        self, inputs: Tensor | Iterable[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
+    ) -> None:
         """
         Set training data (does not re-fit model hyper-parameters).

-        :param torch.Tensor inputs: The new training inputs.
-        :param torch.Tensor targets: The new training targets.
-        :param bool strict: (default True) If `True`, the new inputs and
-            targets must have the same shape, dtype, and device
-            as the current inputs and targets. Otherwise, any shape/dtype/device are allowed.
+        :param inputs: The new training inputs.
+        :param targets: The new training targets.
+        :param strict: If `True`, the new inputs and targets must have the same shape,
+            dtype, and device as the current inputs and targets. Otherwise, any
+            shape/dtype/device are allowed.
         """
         if inputs is not None:
-            if torch.is_tensor(inputs):
+            if isinstance(inputs, Tensor):
                 inputs = (inputs,)
             inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
             if strict:
@@ -218,7 +230,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         except KeyError:
             fantasy_kwargs = {}

-        full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+        full_output = super().__call__(*full_inputs, **kwargs)

         # Copy model without copying training data or prediction strategy (since we'll overwrite those)
         old_pred_strat = self.prediction_strategy
@@ -257,7 +269,7 @@ def __call__(self, *args, **kwargs):
         if self.training:
             if self.train_inputs is None:
                 raise RuntimeError(
-                    "train_inputs, train_targets cannot be None in training mode. "
+                    "train_inputs cannot be None in training mode. "
                     "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                 )
             if settings.debug.on():
@@ -271,7 +283,7 @@ def __call__(self, *args, **kwargs):
         # Prior mode
         elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
             full_inputs = args
-            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+            full_output = super().__call__(*full_inputs, **kwargs)
             if settings.debug().on():
                 if not isinstance(full_output, MultivariateNormal):
                     raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
@@ -313,7 +325,7 @@ def __call__(self, *args, **kwargs):
                 full_inputs.append(torch.cat([train_input, input], dim=-2))

             # Get the joint distribution for training/test data
-            full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
+            full_output = super().__call__(*full_inputs, **kwargs)
             if settings.debug().on():
                 if not isinstance(full_output, MultivariateNormal):
                     raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
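
The retyped set_train_data pairs with its docstring above. A short usage sketch, reusing the hypothetical MyGP model and tensors from the earlier example:

new_x = torch.randn(10, 2)  # same shape, dtype, and device as train_x, so the default strict=True passes
new_y = torch.randn(10)
model.set_train_data(inputs=new_x, targets=new_y)

bigger_x = torch.randn(25, 2)  # a differently sized training set requires strict=False
bigger_y = torch.randn(25)
model.set_train_data(inputs=bigger_x, targets=bigger_y, strict=False)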