From e6386c6aaadcae745b9d5af86e0f020a31106ced Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Tue, 5 Nov 2024 13:58:52 -0800 Subject: [PATCH] Robust Gaussian Processes via Relevance Pursuit (#2608) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2608 This commit adds the implementation of the [Robust Gaussian Processes via Relevance Pursuit](https://arxiv.org/pdf/2410.24222) models and algorithms of the NeurIPS 2024 article. Differential Revision: D65343571 --- botorch/models/gpytorch.py | 13 +- .../likelihoods/sparse_outlier_noise.py | 453 +++++++++ botorch/models/relevance_pursuit.py | 948 ++++++++++++++++++ botorch/test_functions/base.py | 126 +++ botorch/test_functions/synthetic.py | 2 +- botorch/utils/constraints.py | 47 + botorch/utils/testing.py | 22 +- sphinx/source/models.rst | 7 + test/models/test_relevance_pursuit.py | 444 ++++++++ test/test_functions/test_base.py | 27 +- .../multi_objective/test_scalarization.py | 1 + 11 files changed, 2082 insertions(+), 8 deletions(-) create mode 100644 botorch/models/likelihoods/sparse_outlier_noise.py create mode 100644 botorch/models/relevance_pursuit.py create mode 100644 test/models/test_relevance_pursuit.py diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 708f4b8ec2..4e86c14548 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -387,7 +387,7 @@ def _apply_noise( obs_noise = observation_noise.squeeze(-1) mvn = self.likelihood( mvn, - X, + [X], noise=obs_noise.expand(noise_shape), ) elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood): @@ -395,13 +395,15 @@ def _apply_noise( observation_noise = self.likelihood.noise.mean(dim=-1, keepdim=True) mvn = self.likelihood( mvn, - X, + [X], noise=observation_noise.expand(noise_shape), ) else: - mvn = self.likelihood(mvn, X) + mvn = self.likelihood(mvn, [X]) return mvn + # pyre-ignore[14]: Inconsistent override. Could not find parameter + # `Keywords(typing.Any)` in overriding signature. def posterior( self, X: Tensor, @@ -470,6 +472,7 @@ def posterior( return posterior_transform(posterior) return posterior + # pyre-ignore[14]: Inconsistent override. Could not find parameter `noise`. def condition_on_observations( self, X: Tensor, Y: Tensor, **kwargs: Any ) -> BatchedMultiOutputGPyTorchModel: @@ -632,7 +635,7 @@ def batch_shape(self) -> torch.Size: raise NotImplementedError(msg + " that are not broadcastble.") return next(iter(batch_shapes)) - # pyre-fixme[15]: Inconsistent override in return types + # pyre-fixme[14]: Inconsistent override in return types def posterior( self, X: Tensor, @@ -838,6 +841,8 @@ def _apply_noise( ) return self.likelihood(mvn, X) + # pyre-ignore[14]: Inconsistent override. Could not find parameter + # `Keywords(typing.Any)` in overriding signature. def posterior( self, X: Tensor, diff --git a/botorch/models/likelihoods/sparse_outlier_noise.py b/botorch/models/likelihoods/sparse_outlier_noise.py new file mode 100644 index 0000000000..792241bb03 --- /dev/null +++ b/botorch/models/likelihoods/sparse_outlier_noise.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any +from warnings import warn + +import torch +from botorch.exceptions.warnings import InputDataWarning +from botorch.models.model import Model +from botorch.models.relevance_pursuit import RelevancePursuitMixin +from botorch.utils.constraints import NonTransformedInterval +from gpytorch.distributions import MultivariateNormal +from gpytorch.likelihoods import _GaussianLikelihoodBase +from gpytorch.likelihoods.noise_models import FixedGaussianNoise, Noise +from gpytorch.mlls import ExactMarginalLogLikelihood +from gpytorch.priors import Prior +from linear_operator.operators import DiagLinearOperator, LinearOperator +from linear_operator.utils.cholesky import psd_safe_cholesky +from torch import Tensor +from torch.nn.parameter import Parameter + + +class SparseOutlierGaussianLikelihood(_GaussianLikelihoodBase): + def __init__( + self, + base_noise: Noise | FixedGaussianNoise, + dim: int, + outlier_indices: list[int] | None = None, + rho_prior: Prior | None = None, + rho_constraint: NonTransformedInterval | None = None, + batch_shape: torch.Size | None = None, + convex_parameterization: bool = True, + loo: bool = True, + ) -> None: + """A likelihood that models the noise of a GP with SparseOutlierNoise, a noise + model in the Relevance Pursuit family of models, permitting additional "robust" + variance for a small set of outlier data points. Notably, the indices of the + outlier data points are inferred during the optimization of the associated log + marginal likelihood via the Relevance Pursuit algorithm. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + NOTE: Letting base_noise also use the non-transformed constraints, will lead + to more stable optimization, but is orthogonal implementation-wise. If the base + noise is a HomoskedasticNoise, one can pass the non-transformed constraint as + the `noise_constraint`. + + Example: + >>> base_noise = HomoskedasticNoise( + >>> noise_constraint=NonTransformedInterval( + >>> 1e-5, 1e-1, initial_value=1e-3 + >>> ) + >>> ) + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `RelevancePursuitMixin` + >>> sparse_module = likelihood.noise_covar + >>> backward_relevance_pursuit(sparse_module, mll) + + Args: + base_noise: The base noise model. + dim: The number of training observations, which determines the maximum + number of data-point-specific noise variances of the noise model. + outlier_indices: The indices of the outliers. + rho_prior: Prior for `self.noise_covar`'s rho parameter. + rho_constraint: Constraint for `self.noise_covar`'s rho parameter. Needs to + be a NonTransformedInterval because exact sparsity cannot be represented + using smooth transforms like a softplus or sigmoid. + batch_shape: The batch shape of the learned noise parameter (default: []). + convex_parameterization: Whether to use the convex parameterization of rho, + which generally improves optimization results and is thus recommended. + loo: Whether to use leave-one-out (LOO) update equations that can compute + the optimal values of each individual rho, keeping all else equal. + """ + noise_covar = SparseOutlierNoise( + base_noise=base_noise, + dim=dim, + outlier_indices=outlier_indices, + rho_prior=rho_prior, + rho_constraint=rho_constraint, + batch_shape=batch_shape, + convex_parameterization=convex_parameterization, + loo=loo, + ) + super().__init__(noise_covar=noise_covar) + + # pyre-ignore[14]: Inconsistent override because the super class accepts `**kwargs` + def marginal( + self, + function_dist: MultivariateNormal, + *params: Any, + ) -> MultivariateNormal: + mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix + # this scales the rhos by the diagonal of the "non-robust" covariance matrix + diag_K = covar.diagonal() if self.noise_covar.convex_parameterization else None + noise_covar = self.noise_covar.forward(*params, shape=mean.shape, diag_K=diag_K) + full_covar = covar + noise_covar + return function_dist.__class__(mean, full_covar) + + def expected_log_prob( + self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any + ) -> Tensor: + raise NotImplementedError( + "SparseOutlierGaussianLikelihood does not yet support variational inference" + ", but this is not a fundamental limitation. It will require an adding " + "the `expected_log_prob` method to SparseOutlierGaussianLikelihood." + ) + + +class SparseOutlierNoise(Noise, RelevancePursuitMixin): + def __init__( + self, + base_noise: Noise | FixedGaussianNoise, + dim: int, + outlier_indices: list[int] | None = None, + rho_prior: Prior | None = None, + rho_constraint: NonTransformedInterval | None = None, + batch_shape: torch.Size | None = None, + convex_parameterization: bool = True, + loo: bool = True, + ): + """A noise model in the Relevance Pursuit family of models, permitting + additional "robust" variance for a small set of outlier data points. + See also `SparseOutlierGaussianLikelihood`, which leverages this noise model. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + Example: + >>> base_noise = HomoskedasticNoise( + >>> noise_constraint=NonTransformedInterval( + >>> 1e-5, 1e-1, initial_value=1e-3 + >>> ) + >>> ) + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `SparseOutlierNoise` + >>> sparse_module = likelihood.noise_covar + >>> backward_relevance_pursuit(sparse_module, mll) + + Args: + base_noise: The base noise model. + dim: The number of training observations, which determines the maximum + number of data-point-specific noise variances of the noise model. + outlier_indices: The indices of the outliers. + rho_prior: Prior for the rho parameter. + rho_constraint: Constraint for the rho parameter. Needs to be a + NonTransformedInterval because exact sparsity cannot be represented + using smooth transforms like a softplus or sigmoid. + batch_shape: The batch shape of the learned noise parameter (default: []). + convex_parameterization: Whether to use the convex parameterization of rho, + which generally improves optimization results and is thus recommended. + loo: Whether to use leave-one-out (LOO) update equations that can compute + the optimal values of each individual rho, keeping all else equal. + """ + super().__init__() + RelevancePursuitMixin.__init__(self, dim=dim, support=outlier_indices) + + if batch_shape is None: + batch_shape = base_noise.noise.shape[:-1] + + self.base_noise = base_noise + device = base_noise.noise.device + if rho_constraint is None: + cvx_upper_bound = 1 - 1e-3 # < 1 to avoid singularities + rho_constraint = NonTransformedInterval( + lower_bound=0.0, + upper_bound=cvx_upper_bound if convex_parameterization else torch.inf, + initial_value=0.0, + ) + else: + if not isinstance(rho_constraint, NonTransformedInterval): + raise ValueError( + "`rho_constraint` must be a `NonTransformedInterval` if it " + "is not None." + ) + + if rho_constraint.lower_bound < 0: + raise ValueError( + "SparseOutlierNoise requires rho_constraint.lower_bound >= 0." + ) + + if convex_parameterization and rho_constraint.upper_bound > 1: + raise ValueError( + "Convex parameterization requires rho_constraint.upper_bound <= 1." + ) + + # NOTE: Prefer to keep the initialization of the sparse_parameter in the + # derived classes of the Mixin, because it might require additional logic + # that we don't want to put into RelevancePursuitMixin. + num_outliers = len(self.support) + self.register_parameter( + "raw_rho", + parameter=Parameter( + torch.zeros( + *batch_shape, + num_outliers, + dtype=base_noise.noise.dtype, + device=device, + ) + ), + ) + + if rho_prior is not None: + + def _rho_param(m): + return m.rho # TODO: coverage + + def _rho_closure(m, v): + return m._set_rho(v) # TODO: coverage + + self.register_prior("rho_prior", rho_prior, _rho_param, _rho_closure) + + self.register_constraint("raw_rho", rho_constraint) + # only publicly exposing getter of convex parameterization + # since post-hoc modification can lead to inconsistencies + # with the rho constraints. + self._convex_parameterization = convex_parameterization + self.loo = loo + self._cached_train_inputs = None + + @property + def sparse_parameter(self) -> Parameter: + return self.raw_rho + + def set_sparse_parameter(self, value: Parameter) -> None: + """Sets the sparse parameter. + + NOTE: We can't use the property setter @sparse_parameter.setter because of + the special way PyTorch treats Parameter types, including custom setters. + """ + self.raw_rho = torch.nn.Parameter(value.to(self.raw_rho)) + + @property + def convex_parameterization(self) -> bool: + return self._convex_parameterization + + @staticmethod + def _from_model(model: Model) -> RelevancePursuitMixin: + sparse_module = model.likelihood.noise_covar + if not isinstance(sparse_module, SparseOutlierNoise): + raise ValueError( + "The model's likelihood does not have a SparseOutlierNoise noise " + f"as its noise_covar module, but instead a {type(sparse_module)}." + ) + return sparse_module + + @property + def _convex_rho(self) -> Tensor: + """Transforms the raw_rho parameter such that `rho ~= 1 / (1 - raw_rho) - 1`, + which is a diffeomorphism from [0, 1] to [0, inf] whose derivative is nowhere + zero. This transforms the marginal log likelihood to be a convex function of + the `self.raw_rho` Parameter, when the covariance matrix is well conditioned. + + NOTE: The convex parameterization also includes a scaling of the rho values by + the diagonal of the covariance matrix, which is carried out in the `marginal` + call in the SparseOutlierGaussianLikelihood. + """ + # pyre-ignore[7]: It is not have an incompatible return type, pyre just doesn't + # recognize that the result gets promoted to a Tensor. + return 1 / (1 - self.raw_rho) - 1 + + @property + def rho(self) -> Tensor: + """Dense representation of the data-point-specific variances, corresponding to + the latent `self.raw_rho` values, which might be represented sparsely or in the + convex parameterization. The last dimension is equal to the number of training + points `self.dim`. + + NOTE: `rho` differs from `self.sparse_parameter` in that the latter returns the + the parameter in its sparse representation when `self.is_sparse` is true, and in + its latent convex paramzeterization when `self.convex_parameterization` is true, + while `rho` always returns the data-point-specific variances, embedded in a + dense tensor. The dense representation is used to propagate gradients to the + sparse rhos in the support. + + Returns: + A `batch_shape x self.dim`-dim Tensor of robustness variances. + """ + # NOTE: don't need to do transform / untransform since we are + # enforcing NonTransformedIntervals. + rho_outlier = self._convex_rho if self.convex_parameterization else self.raw_rho + if not self.is_sparse: # in the dense representation, we're done. + return rho_outlier + + # If rho_outlier is in the sparse representation, we need to pad the + # rho values with zeros at the correct positions. The difference + # between this and calling RelevancePursuit's `to_dense` is that + # the latter will propagate gradients through all rhos, whereas + # the path here only propagates gradients to the sparse set of + # outliers, which is important for the optimization of the support. + rho_inlier = torch.zeros( + 1, dtype=rho_outlier.dtype, device=rho_outlier.device + ).expand(rho_outlier.shape[:-1] + (1,)) + rho = torch.cat( + [rho_outlier, rho_inlier], dim=-1 + ) # batch_shape x (num_outliers + 1) + + return rho[..., self._rho_selection_indices] + + @property + def _rho_selection_indices(self) -> Tensor: + # num_train is cached in the forward pass in training mode + # if an index is not in the outlier indices, we get the zeros from the + # last index of "rho" + # is this related to a sparse to dense mapping used in RP? + rho_selection_indices = torch.full( + self.raw_rho.shape[:-1] + (self.dim,), + -1, + dtype=torch.long, + device=self.raw_rho.device, + ) + for i, j in enumerate(self.support): + rho_selection_indices[j] = i + + return rho_selection_indices + + # pyre-ignore[14]: Inconsistent override because the super class accepts `**kwargs` + def forward( + self, + *params: Any, + diag_K: Tensor | None = None, + shape: torch.Size | None = None, + ) -> LinearOperator | Tensor: + """Computes the covariance matrix of the sparse outlier noise model. + + Args: + params: The parameters of noise model, same as for GPyTorch's noise model. + diag_K: The diagonal of the covariance matrix, which is used to scale the + rho values in the convex parameterization. + shape: The shape of the covariance matrix, which is used to broadcast the + rho values to the correct shape. + + Returns: + A `batch_shape x self.dim`-dim Tensor of robustness variances. + """ + noise_covar = self.base_noise(*params, shape=shape) + # rho should always be applied to the training set, irrespective of whether or + # not we are in training mode. + rho = self.rho + # check if we should apply the rhos, based on the cached training inputs + # NOTE: Even though it is not strictly required for many likelihoods, BoTorch + # and GPyTorch generally pass the training inputs to the likelihood, e.g.: + # 1) in fit_gpytorch_mll: + # ( + # github.com/pytorch/botorch/blob/3ca48d0ac5865a017ac6b2294807b432d6472bcf/ + # botorch/optim/closures/model_closures.py#L185 + # ) + # 2) in the exact prediction strategy: + # ( + # github.com/cornellius-gp/gpytorch/blob/ + # d501c284d05a1186868dc3fb20e0fa6ad32d32ac/ + # gpytorch/models/exact_prediction_strategies.py#L387 + # ) + if len(params) > 0: + if noise_covar.shape[-1] != rho.shape[-1]: + apply_robust_variances = False + warning_reason = ( + "the last dimension of the base noise covariance " + f"({noise_covar.shape[-1]}) " + "is not compatible with the last dimension of rho " + f"({rho.shape[-1]})." + ) + elif self.training or self._cached_train_inputs is None: + apply_robust_variances = True + self._cached_train_inputs = params[0][0] + warning_reason = "" + else: + apply_robust_variances = torch.equal( + params[0][0], self._cached_train_inputs + ) + warning_reason = ( + "the passed train_inputs are not equal to the cached ones." + ) + else: + apply_robust_variances = False + warning_reason = "the training inputs were not passed to the likelihood." + + if apply_robust_variances: + if diag_K is not None: + rho = (diag_K + noise_covar.diagonal()) * rho # convex parameterization + noise_covar = noise_covar + DiagLinearOperator(rho) + else: + warn( + f"SparseOutlierNoise: Robust rho not applied because {warning_reason} " + + "This can happen when the model posterior is evaluated on test data.", + InputDataWarning, + stacklevel=2, + ) + return noise_covar + + # relevance pursuit method expansion and contraction related methods + def expansion_objective(self, mll: ExactMarginalLogLikelihood) -> Tensor: + """Computes an objective value for all the inactive parameters, i.e. + self.sparse_parameter[~self.is_active] since we can't add already active + parameters to the support. This value will be used to select the parameters. + + Args: + mll: The marginal likelihood, containing the model to optimize. + + Returns: + The expansion objective value for all the inactive parameters. + """ + f = self._optimal_rhos if self.loo else self._sparse_parameter_gradient + return f(mll) + + def _optimal_rhos(self, mll: ExactMarginalLogLikelihood) -> Tensor: + """Computes the optimal rho deltas for the given model. + + Args: + mll: The marginal likelihood, containing the model to optimize. + + Returns: + A `batch_shape x self.dim`-dim Tensor of optimal rho deltas. + """ + # train() is important, since we want to evaluate the prior with mll.model(X), + # but in eval(), __call__ gives the posterior. + mll.train() # NOTE: this changes model.train_inputs to be unnormalized. + X, Y = mll.model.train_inputs[0], mll.model.train_targets + F = mll.model(X) + L = mll.likelihood(F, X) + S = L.covariance_matrix # (Kernel Matrix + Noise Matrix) + + # NOTE: The following computation is mathematically equivalent to the formula + # in this comment, but leverages the positive-definiteness of S via its + # Cholesky factorization. + # S_inv = S.inverse() + # diag_S_inv = S_inv.diagonal(dim1=-1, dim2=-2) + # loo_var = 1 / S_inv.diagonal(dim1=-1, dim2=-2) + # loo_mean = Y - (S_inv @ Y) / diag_S_inv + + chol = psd_safe_cholesky(S, upper=True) + eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype) + inv_root = torch.linalg.solve_triangular(chol, eye, upper=True) + + # test: inv_root.square().sum(dim=-1) - S.inverse().diag() + diag_S_inv = inv_root.square().sum(dim=-1) + loo_var = 1 / diag_S_inv + S_inv_Y = torch.cholesky_solve(Y.unsqueeze(-1), chol, upper=True).squeeze(-1) + loo_mean = Y - S_inv_Y / diag_S_inv + + loo_error = loo_mean - Y + optimal_rho_deltas = loo_error.square() - loo_var + return (optimal_rho_deltas - self.rho).clamp(0)[~self.is_active] diff --git a/botorch/models/relevance_pursuit.py b/botorch/models/relevance_pursuit.py new file mode 100644 index 0000000000..07407ac917 --- /dev/null +++ b/botorch/models/relevance_pursuit.py @@ -0,0 +1,948 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Relevance Pursuit model structure and optimization routines for the sparse optimization +of Gaussian process hyper-parameters, see [Ament2024pursuit]_ for details. + +References + +.. [Ament2024pursuit] + S. Ament, E. Santorella, D. Eriksson, B. Letham, M. Balandat, and E. Bakshy. + Robust Gaussian Processes via Relevance Pursuit. Advances in Neural Information + Processing Systems 37, 2024. Arxiv: https://arxiv.org/abs/2410.24222. +""" + +from __future__ import annotations + +import math + +from abc import ABC, abstractmethod +from collections.abc import Callable +from copy import copy, deepcopy +from functools import partial +from typing import Any, cast, Optional + +import torch +from botorch.fit import fit_gpytorch_mll +from botorch.models.model import Model +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from torch import Tensor +from torch.nn.parameter import Parameter + +MLL_ITER = 10_000 # let's take convergence seriously +MLL_TOL = 1e-8 +RESET_PARAMETERS = False + + +class RelevancePursuitMixin(ABC): + """Mixin class to convert between the sparse and dense representations of the + relevance pursuit models' sparse parameters, as well as to compute the generalized + support acquisition and support deletion criteria. + """ + + dim: int # the total number of features + + # IDEA: could generalize this to sets of parameters Dict[str, List[int]] + # Beside looping over the parameters for all the sparse / dense conversions, + # we'd need to introduce a vectorial representation of all the parameters + # for the selection of the acquisition / deletion indices. + # We don't really need to enforce a vectorial parameter storage for this, we + # only need to introduce a helper that computes the (parameter, index) pair + # that maximize the acquisition criterion. + # potentially relevant: get_tensors_as_ndarray_1d + _support: list[int] # indices of the features in the support, subset of range(dim) + + def __init__( + self, + dim: int, + support: list[int] | None, + ) -> None: + """Constructor for the RelevancePursuitMixin class. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + Args: + dim: The total number of features. + support: The indices of the features in the support, subset of range(dim). + """ + + self.dim = dim + self._support = support if support is not None else [] + # Assumption: sparse_parameter is initialized in sparse representation + self._is_sparse = True + self._expansion_modifier = None + self._contraction_modifier = None + + @property + @abstractmethod + def sparse_parameter(self) -> Parameter: + """The sparse parameter.""" + pass # pragma: no cover + + @abstractmethod + def set_sparse_parameter(self, value: Parameter) -> None: + """Sets the sparse parameter. + + NOTE: We can't use the property setter @sparse_parameter.setter because of + the special way PyTorch treats Parameter types, including custom setters that + bypass the @property setters before the latter are called. + """ + pass # pragma: no cover + + @staticmethod + def _from_model(model: Model) -> RelevancePursuitMixin: + """Retrieves a RelevancePursuitMixin from a model.""" + raise NotImplementedError # pragma: no cover + + @property + def is_sparse(self) -> bool: + # Do we need to differentiate between a full support sparse representation and + # a full support dense representation? The order the of the indices could be + # different, unless we keep them sorted. + return self._is_sparse + + @property + def support(self) -> list[int]: + """The indices of the active parameters.""" + return self._support + + @property + def is_active(self) -> Tensor: + """A Boolean Tensor of length dim, indicating which of the d dimensions are in + the support, i.e. "active". + """ + is_active = [(i in self.support) for i in range(self.dim)] # TODO: coverage + return torch.tensor( + is_active, dtype=torch.bool, device=self.sparse_parameter.device + ) + + @property + def active_parameters(self) -> Tensor: # TODO: coverage + if self.is_sparse: + return self.sparse_parameter + else: + return self.sparse_parameter[self.support] + + @property + def inactive_indices(self) -> Tensor: # TODO: coverage + device = self.sparse_parameter.device + return torch.arange(self.dim, device=device)[~self.is_active] + + def to_sparse(self) -> RelevancePursuitMixin: + # should we prohibit this for the case where the support is the full set? + if not self.is_sparse: + self.set_sparse_parameter( + torch.nn.Parameter(self.sparse_parameter[self.support]) + ) + self._is_sparse = True + return self + + def to_dense(self) -> RelevancePursuitMixin: + if not self.is_sparse: + return self # already dense + dtype = self.sparse_parameter.dtype + device = self.sparse_parameter.device + zero = torch.tensor( + 0.0, + dtype=dtype, + device=device, + ) + dense_parameter = [ + ( + self.sparse_parameter[self.support.index(i)] + if i in self.support + else zero + ) + for i in range(self.dim) + ] + dense_parameter = torch.tensor(dense_parameter, dtype=dtype, device=device) + self.set_sparse_parameter(torch.nn.Parameter(dense_parameter)) + self._is_sparse = False + return self + + def expand_support(self, indices: list[int]) -> RelevancePursuitMixin: + for i in indices: + if i in self.support: + raise ValueError(f"Feature {i} already in support.") # TODO: coverage + + self.support.extend(indices) + # we need to add the parameter in the sparse representation + if self.is_sparse: + self.set_sparse_parameter( + torch.nn.Parameter( + torch.cat( + ( + self.sparse_parameter, + torch.zeros(len(indices)).to(self.sparse_parameter), + ) + ) + ) + ) + return self + + def contract_support(self, indices: list[int]) -> RelevancePursuitMixin: + # indices into the sparse representation of features to *keep* + sparse_indices = list(range(len(self.support))) + original_support = copy(self.support) + for i in indices: + if i not in self.support: + raise ValueError(f"Feature {i} is not in support.") + sparse_indices.remove(original_support.index(i)) + self.support.remove(i) + + # we need to add the parameter in the sparse representation + if self.is_sparse: + self.set_sparse_parameter(Parameter(self.sparse_parameter[sparse_indices])) + else: + requires_grad = self.sparse_parameter.requires_grad # TODO: coverage + self.sparse_parameter.requires_grad_(False) + self.sparse_parameter[indices] = 0.0 + self.sparse_parameter.requires_grad_(requires_grad) # restore + return self + + def drop_zeros_from_support(self, threshold: float = 0.0) -> RelevancePursuitMixin: + # drops indices from support whose corresponding values are zero + # TODO: figure out batch_shape if necessary, this seems complicated + # to make batched, unless we force the support to be the same for + # all batches. + is_zero = self.sparse_parameter <= threshold # TODO: coverage + if self.is_sparse: + indices = [self.support[i] for i, b in enumerate(is_zero) if b] + else: + indices = [i for i, b in enumerate(is_zero) if b and i in self.support] + self.contract_support(indices) + return self + + def drop_threshold_from_support( + self, lower: float, upper: float + ) -> RelevancePursuitMixin: + # drops indices from support whose corresponding values are zero + # TODO: figure out batch_shape if necessary, this seems complicated + # to make batched, unless we force the support to be the same for + # all batches. + is_small = self.sparse_parameter <= lower # TODO: coverage + is_large = self.sparse_parameter >= upper + to_drop = is_small | is_large + if self.is_sparse: + indices = [self.support[i] for i, b in enumerate(to_drop) if b] + else: + indices = [i for i, b in enumerate(to_drop) if b and i in self.support] + self.contract_support(indices) + return self + + # support initialization helpers + def full_support(self) -> RelevancePursuitMixin: + self.expand_support([i for i in range(self.dim) if i not in self.support]) + self.to_dense() # no reason to be sparse with full support + return self + + def remove_support(self) -> RelevancePursuitMixin: + self._support = [] + requires_grad = self.sparse_parameter.requires_grad + if self.is_sparse: + self.set_sparse_parameter( + torch.nn.Parameter(torch.tensor([]).to(self.sparse_parameter)) + ) + else: + self.sparse_parameter.requires_grad_(False) + self.sparse_parameter[:] = 0.0 + self.sparse_parameter.requires_grad_(requires_grad) + return self + + def random_support(self, n: int) -> RelevancePursuitMixin: + # randperm could also be interesting as an expansion tactic in cases + # where we want to avoid evaluating other criteria + self.remove_support() + if n == self.dim: + self.full_support() + elif 0 < n and n < self.dim: + # random support initialization + self.expand_support(torch.randperm(self.dim)[:n].tolist()) + else: + raise ValueError(f"Cannot add more than {self.dim} indices to support.") + return self + + # the following two methods are the only ones that are specific to the marginal + # likelihood optimization problem + def support_expansion( + self, + mll: ExactMarginalLogLikelihood, + n: int = 1, + modifier: Callable[[Tensor], Tensor] | None = None, + ) -> bool: + """Computes the indices of the features that maximize the gradient of the sparse + parameter and that are not already in the support, and subsequently expands the + support to include the features if their gradient is positive. + + Args: + mll: The marginal likelihood, containing the model to optimize. + NOTE: Virtually all of the rest of the code is not specific to the + marginal likelihood optimization, so we could generalize this to work + with any objective. + n: The number of features to select. + modifier: A function that modifies the gradient before computing + the support expansion criterion. This is useful, for example, + when we want to select the maximum gradient magnitude for real-valued + (not non-negative) parameters, in which case modifier = torch.abs. + + Returns: + True if the support was expanded, False otherwise. + """ + g = self.expansion_objective(mll) + + modifier = modifier if modifier is not None else self._expansion_modifier + if modifier is not None: + # IDEA: could compute a Newton step here / use the approximation to the + # Hessian that is returned by L-BFGS. + g = modifier(g) + + # support is already removed from consideration + # gradient of the support parameters is not necessarily zero, + # even for a converged solution in the presence of constraints. + # IDEA: could use the vectorized representation of all + # parameters in the optimizer stack to make this selection + # over multiple parameter groups. + # NOTE: these indices are relative to self.inactive_indices. + indices = g.argsort(descending=True)[:n] + indices = indices[g[indices] > 0] + if indices.numel() == 0: # no indices with positive gradient + return False + self.expand_support(self.inactive_indices[indices].tolist()) + + return True + + # NOTE: could also generalize contraction_objective + def expansion_objective(self, mll: ExactMarginalLogLikelihood) -> Tensor: + """Computes an objective value for all the inactive parameters, i.e. + self.sparse_parameter[~self.is_active] since we can't add already active + parameters to the support. This value will be used to select the parameters. + + Args: + mll: The marginal likelihood, containing the model to optimize. + + Returns: + The expansion objective value for all the inactive parameters. + """ + return self._sparse_parameter_gradient(mll) + + def _sparse_parameter_gradient(self, mll: ExactMarginalLogLikelihood) -> Tensor: + """Computes the gradient of the marginal likelihood with respect to the + sparse parameter. + + Args: + mll: The marginal likelihood, containing the model to optimize. + + Returns: + The gradient of the marginal likelihood with respect to the inactive + sparse parameters. + """ + # evaluate gradient of the sparse parameter + is_sparse = self.is_sparse # in order to restore the original representation + self.to_dense() # need the parameter in its dense parameterization + + requires_grad = self.sparse_parameter.requires_grad + self.sparse_parameter.requires_grad_(True) + if self.sparse_parameter.grad is not None: + self.sparse_parameter.grad.zero_() + mll.train() # NOTE: this changes model.train_inputs + X, Y = mll.model.train_inputs[0], mll.model.train_targets + cast(Tensor, mll(mll.model(X), Y)).backward() # evaluation + self.sparse_parameter.requires_grad_(requires_grad) + + g = self.sparse_parameter.grad + if g is None: + raise ValueError("Gradient is not available.") + + if is_sparse: + self.to_sparse() + + return g[~self.is_active] # only need the inactive parameters + + def support_contraction( + self, + mll: ExactMarginalLogLikelihood, + n: int = 1, + modifier: Callable[[Tensor], Tensor] | None = None, + ) -> bool: + """Computes the indices of the features that have the smallest coefficients, + and subsequently contracts the exlude the features. + + Args: + mll: The marginal likelihood, containing the model to optimize. + NOTE: Virtually all of the rest of the code is not specific to the + marginal likelihood optimization, so we could generalize this to work + with any objective. + n: The number of features to select for removal. + modifier: A function that modifies the parameter values before computing + the support contraction criterion. + + Returns: + True if the support was expanded, False otherwise. + """ + if len(self.support) == 0: + return False + + is_sparse = self.is_sparse + self.to_sparse() + x = self.sparse_parameter + + modifier = modifier if modifier is not None else self._contraction_modifier + if modifier is not None: + x = modifier(x) + + # IDEA: for non-negative parameters, could break ties at zero + # depending with derivative + sparse_indices = x.argsort(descending=False)[:n] + indices = [self.support[i] for i in sparse_indices] + self.contract_support(indices) + if not is_sparse: + self.to_dense() + return True + + def optimize_mll( + self, + mll: ExactMarginalLogLikelihood, + model_trace: list[Model] | None = None, + reset_parameters: bool = RESET_PARAMETERS, + reset_dense_parameters: bool = RESET_PARAMETERS, + optimizer_kwargs: dict[str, Any] | None = None, + ): + """Optimizes the marginal likelihood. + + Args: + mll: The marginal likelihood, containing the model to optimize. + model_trace: If not None, a list to which a deepcopy of the model state is + appended. NOTE This operation is *in place*. + reset_parameters: If True, initializes the sparse parameter to the all-zeros + vector before every marginal likelihood optimization step. If False, the + optimization is warm-started with the previous iteration's parameters. + reset_dense_parameters: If True, re-initializes the dense parameters, e.g. + other GP hyper-parameters that are *not* part of the Relevance Pursuit + module, to the initial values provided by their associated constraints. + optimizer_kwargs: A dictionary of keyword arguments for the optimizer. + + Returns: + The marginal likelihood after optimization. + """ + if reset_parameters: + # this might be beneficial because the parameters can + # end up at a constraint boundary, which can anecdotally make + # it more difficult to move the newly added parameters. + # should we only do this after expansion? + # IDEA: should we also reset the dense parameters? + with torch.no_grad(): + self.sparse_parameter.zero_() + + if reset_dense_parameters: + # re-initialize dense parameters + initialize_dense_parameters(mll.model) + + # move to sparse representation for optimization + # NOTE: this function should never force the dense representation, because some + # models might never need it, and it would be inefficient. + self.to_sparse() + fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs) + if model_trace is not None: + # need to record the full model here, rather than just the sparse parameter + # since other hyper-parameters are co-adapted to the sparse parameter. + model_trace.append(deepcopy(mll.model)) + return mll + + +# Optimization Algorithms +def relevance_pursuit( + sparse_module: RelevancePursuitMixin, + mll: ExactMarginalLogLikelihood, + num_iter: int, + num_expand: int = 1, + num_contract: int = 0, + mll_iter: int = MLL_ITER, + mll_tol: float = MLL_TOL, + optimizer_kwargs: dict[str, Any] | None = None, + reset_parameters: bool = RESET_PARAMETERS, + reset_dense_parameters: bool = RESET_PARAMETERS, + record_model_trace: bool = False, +) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]: + """Relevance pursuit algorithm for the sparse marginal likelihood optimization + of Gaussian process parameters. In its most general form, it is a forward-backward + algorithm, but the forward and backward stages can be called independently, and + modulated with the num_expand and num_contract arguments. + + NOTE: For the robust `SparseOutlierNoise` model of [Ament2024pursuit]_, we recommend + using the backward algorithm `backward_relevance_pursuit`, which leads to the most + robust results in the presence of a large number of outliers. + + Ideas: + - Could re-optimize after every single expansion, even if num_expand > 1. + - Could drop exact zeros. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + Example: + >>> base_noise = HomoskedasticNoise( + >>> noise_constraint=NonTransformedInterval( + >>> 1e-5, 1e-1, initial_value=1e-3 + >>> ) + >>> ) + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `RelevancePursuitMixin` + >>> sparse_module = likelihood.noise_covar + >>> sparse_module.remove_support() + >>> relevance_pursuit(sparse_module, mll, num_iter=3, num_expand=2) + + Args: + mll: The marginal likelihood, containing the model to optimize. + num_iter: The number of iterations to run. + num_expand: The number of features to add during each iteration. + num_contract: The number of features to remove during each iteration. + mll_iter: The maximum number of iterations to run the MLL optimizer. Only used + when `optimizer_kwargs` is None. + mll_tol: The convergence tolerance for the MLL optimizer. Only used when + `optimizer_kwargs` is None. + optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. + reset_parameters: If true, initializes the sparse parameter to the all zeros + vector before every marginal likelihood optimization step. If false, the + optimization is warm-started with the parameters of the previous iteration. + reset_dense_parameters: If true, re-initializes the dense parameters, e.g. + other GP hyper-parameters that are *not* part of the Relevance Pursuit + module, to the initial values provided by their associated constraints. + record_model_trace: If true, records the model state after every iteration. + + Returns: + The marginal likelihood after relevance pursuit optimization. + """ + if optimizer_kwargs is None: + optimizer_kwargs = { + "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol} + } + + model_trace = [] if record_model_trace else None + + def optimize_mll(mll): + return sparse_module.optimize_mll( + mll=mll, + model_trace=model_trace, + reset_parameters=reset_parameters, + reset_dense_parameters=reset_dense_parameters, + optimizer_kwargs=optimizer_kwargs, + ) + + optimize_mll(mll) # initial optimization + + for _ in range(num_iter): + expanded = False + if num_expand > 0: + expanded = sparse_module.support_expansion(mll=mll, n=num_expand) + optimize_mll(mll) # re-optimize support + + contracted = False + if num_contract > 0: + contracted = sparse_module.support_contraction(mll=mll, n=num_contract) + optimize_mll(mll) # re-optimize support + + # IDEA: could stop here if the marginal likelihood decreases, assuming that + # the posterior pdf of the support size is uni-modal. + if not expanded and not contracted: # stationary support + break + + return sparse_module, model_trace + + +def forward_relevance_pursuit( + sparse_module: RelevancePursuitMixin, + mll: ExactMarginalLogLikelihood, + sparsity_levels: list[int] | None = None, + mll_iter: int = MLL_ITER, + mll_tol: float = MLL_TOL, + optimizer_kwargs: dict[str, Any] | None = None, + reset_parameters: bool = RESET_PARAMETERS, + reset_dense_parameters: bool = RESET_PARAMETERS, + record_model_trace: bool = True, + initial_support: list[int] | None = None, +) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]: + """Forward Relevance Pursuit. + + NOTE: For the robust `SparseOutlierNoise` model of [Ament2024pursuit]_, the forward + algorithm is generally faster than the backward algorithm, particularly when the + maximum sparsity level is small, but it leads to less robust results when the number + of outliers is large. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + Example: + >>> base_noise = HomoskedasticNoise( + >>> noise_constraint=NonTransformedInterval( + >>> 1e-5, 1e-1, initial_value=1e-3 + >>> ) + >>> ) + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `RelevancePursuitMixin` + >>> sparse_module = likelihood.noise_covar + >>> sparse_module, model_trace = forward_relevance_pursuit(sparse_module, mll) + + Args: + sparse_module: The relevance pursuit module. + mll: The marginal likelihood, containing the model to optimize. + sparsity_levels: The sparsity levels to expand the support to. + mll_iter: The maximum number of iterations to run the MLL optimizer. Only used + when `optimizer_kwargs` is None. + mll_tol: The convergence tolerance for the MLL optimizer. Only used when + `optimizer_kwargs` is None. + optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. + reset_parameters: If true, initializes the sparse parameter to the all zeros + after each iteration. + reset_dense_parameters: If true, re-initializes the dense parameters, e.g. + other GP hyper-parameters that are *not* part of the Relevance Pursuit + module, to the initial values provided by their associated constraints. + record_model_trace: If true, records the model state after every iteration. + initial_support: The support with which to initialize the sparse module. By + default, the support is initialized to the empty set. + + Returns: + The relevance pursuit module after forward relevance pursuit optimization, and + a list of models with different supports that were optimized. + """ + sparse_module.remove_support() + if initial_support is not None: + sparse_module.expand_support(initial_support) + + if sparsity_levels is None: + sparsity_levels = list(range(len(sparse_module.support), sparse_module.dim + 1)) + + # since this is the forward algorithm, potential sparsity levels + # must be in increasing order and unique. + sparsity_levels = list(set(sparsity_levels)) + sparsity_levels.sort(reverse=False) + + if optimizer_kwargs is None: + optimizer_kwargs = { + "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol} + } + + model_trace = [] if record_model_trace else None + + def optimize_mll(mll): + return sparse_module.optimize_mll( + mll=mll, + model_trace=model_trace, + reset_parameters=reset_parameters, + reset_dense_parameters=reset_dense_parameters, + optimizer_kwargs=optimizer_kwargs, + ) + + # if sparsity levels contains the initial support, remove it + if sparsity_levels[0] == len(sparse_module.support): + sparsity_levels.pop(0) + + optimize_mll(mll) # initial optimization + + for sparsity in sparsity_levels: + support_size = len(sparse_module.support) + num_expand = sparsity - support_size + if num_expand <= 0: + raise ValueError( + "sparsity_levels need to be increasing and larger than initial support." + ) + + expanded = sparse_module.support_expansion(mll=mll, n=num_expand) + # IDEA: could stop here if the marginal likelihood decreases, assuming that + # the posterior pdf of the support size is uni-modal. + if not expanded: # stationary support + break + + optimize_mll(mll) # re-optimize support + + return sparse_module, model_trace + + +def backward_relevance_pursuit( + sparse_module: RelevancePursuitMixin, + mll: ExactMarginalLogLikelihood, + sparsity_levels: list[int] | None = None, + mll_iter: int = MLL_ITER, + mll_tol: float = MLL_TOL, + optimizer_kwargs: dict[str, Any] | None = None, + reset_parameters: bool = RESET_PARAMETERS, + reset_dense_parameters: bool = RESET_PARAMETERS, + record_model_trace: bool = True, + initial_support: list[int] | None = None, +) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]: + """Backward Relevance Pursuit. + + NOTE: For the robust `SparseOutlierNoise` model of [Ament2024pursuit]_, the backward + algorithm generally leads to more robust results than the forward algorithm, + especially when the number of outliers is large, but is more expensive unless the + support is contracted by more than one in each iteration. + + For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222. + + Example: + >>> base_noise = HomoskedasticNoise( + >>> noise_constraint=NonTransformedInterval( + >>> 1e-5, 1e-1, initial_value=1e-3 + >>> ) + >>> ) + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `RelevancePursuitMixin` + >>> sparse_module = likelihood.noise_covar + >>> sparse_module, model_trace = backward_relevance_pursuit(sparse_module, mll) + + Args: + sparse_module: The relevance pursuit module. + mll: The marginal likelihood, containing the model to optimize. + sparsity_levels: The sparsity levels to expand the support to. + mll_iter: The maximum number of iterations to run the MLL optimizer. Only used + when `optimizer_kwargs` is None. + mll_tol: The convergence tolerance for the MLL optimizer. Only used when + `optimizer_kwargs` is None. + optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. + reset_parameters: If true, initializes the sparse parameter to the all zeros + after each iteration. + reset_dense_parameters: If true, re-initializes the dense parameters, e.g. + other GP hyper-parameters that are *not* part of the Relevance Pursuit + module, to the initial values provided by their associated constraints. + record_model_trace: If true, records the model state after every iteration. + initial_support: The support with which to initialize the sparse module. By + default, the support is initialized to the full set. + + Returns: + The relevance pursuit module after forward relevance pursuit optimization, and + a list of models with different supports that were optimized. + """ + if initial_support is not None: + sparse_module.remove_support() + sparse_module.expand_support(initial_support) + else: + sparse_module.full_support() + + if sparsity_levels is None: + sparsity_levels = list(range(len(sparse_module.support) + 1)) + + # since this is the backward algorithm, potential sparsity levels + # must be in decreasing order, unique, and less than the initial support. + sparsity_levels = list(set(sparsity_levels)) + sparsity_levels.sort(reverse=True) + + if optimizer_kwargs is None: + optimizer_kwargs = { + "options": {"maxiter": mll_iter, "ftol": mll_tol, "gtol": mll_tol} + } + + model_trace = [] if record_model_trace else None + + def optimize_mll(mll): + return sparse_module.optimize_mll( + mll=mll, + model_trace=model_trace, + reset_parameters=reset_parameters, + reset_dense_parameters=reset_dense_parameters, + optimizer_kwargs=optimizer_kwargs, + ) + + # if sparsity levels contains the initial support, remove it + if sparsity_levels[0] == len(sparse_module.support): + sparsity_levels.pop(0) + + optimize_mll(mll) # initial optimization + + for sparsity in sparsity_levels: + support_size = len(sparse_module.support) + num_contract = support_size - sparsity + if num_contract <= 0: + raise ValueError( + "sparsity_levels need to be decreasing and less than initial support." + ) + + contracted = sparse_module.support_contraction(mll=mll, n=num_contract) + # IDEA: could stop here if the marginal likelihood decreases, assuming that + # the posterior pdf of the support size is uni-modal. + if not contracted: # stationary support + break + + optimize_mll(mll) # re-optimize support + + return sparse_module, model_trace + + +# Bayesian Model Comparison +def get_posterior_over_support( + rp_class: type[RelevancePursuitMixin], + model_trace: list[Model], + log_support_prior: Callable[[Tensor], Tensor] | None = None, + prior_mean_of_support: float | None = None, +) -> tuple[Tensor, Tensor]: + """Computes the posterior distribution over a list of models. + Assumes we are storing both likelihood and GP model in the model_trace. + + Example: + >>> likelihood = SparseOutlierGaussianLikelihood( + >>> base_noise=base_noise, + >>> dim=X.shape[0], + >>> ) + >>> model = SingleTaskGP(train_X=X, train_Y=Y, likelihood=likelihood) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> # NOTE: `likelihood.noise_covar` is the `RelevancePursuitMixin` + >>> sparse_module = likelihood.noise_covar + >>> sparse_module, model_trace = backward_relevance_pursuit(sparse_module, mll) + >>> # NOTE: SparseOutlierNoise is the type of `sparse_module` + >>> support_size, bmc_probabilities = get_posterior_over_support( + >>> SparseOutlierNoise, model_trace, prior_mean_of_support=2.0 + >>> ) + + Args: + rp_class: The relevance pursuit class to use for computing the support size. + This is used to get the RelevancePursuitMixin from the Model via the static + method `_from_model`. We could generalize this and let the user pass this + getter instead. + model_trace: A list of models with different support sizes, usually generated + with relevance_pursuit. + log_support_prior: Callable that computes the log prior probability of a + support size. If None, uses a default exponential prior with a mean + specified by `prior_mean_of_support`. + prior_mean_of_support: A mean value for the default exponential prior + distribution over the support size. Ignored if `log_support_prior` + is passed. + + Returns: + A tensor of posterior marginal likelihoods, one for each model in the trace. + """ + if log_support_prior is None: + if prior_mean_of_support is None: + raise ValueError( + "Please provide a prior mean of the support size or pass a Callable " + "to evaluate the log prior density as log_support_prior." + ) + log_support_prior = partial(_exp_log_pdf, mean=prior_mean_of_support) + + log_support_prior = cast(Callable[[Tensor], Tensor], log_support_prior) + + def log_prior( + model: Model, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[Tensor, Tensor]: + sparse_module = rp_class._from_model(model) + num_support = torch.tensor( + len(sparse_module.support), dtype=dtype, device=device + ) + return num_support, log_support_prior(num_support) # pyre-ignore[29] + + log_mll_trace = [] + log_prior_trace = [] + support_size_trace = [] + for model in model_trace: + mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model) + mll.train() + X, Y = mll.model.train_inputs[0], mll.model.train_targets + F = mll.model(X) + mll_i = cast(Tensor, mll(F, Y, X)) + log_mll_trace.append(mll_i) + support_size, log_prior_i = log_prior( + model, + dtype=mll_i.dtype, + device=mll_i.device, + ) + support_size_trace.append(support_size) + log_prior_trace.append(log_prior_i) + + log_mll_trace = torch.stack(log_mll_trace) + log_prior_trace = torch.stack(log_prior_trace) + support_size_trace = torch.stack(support_size_trace) + + unnormalized_posterior_trace = log_mll_trace + log_prior_trace + evidence = unnormalized_posterior_trace.logsumexp(dim=-1) + posterior_probabilities = (unnormalized_posterior_trace - evidence).exp() + return support_size_trace, posterior_probabilities + + +def _exp_log_pdf(x: Tensor, mean: Tensor) -> Tensor: + """Compute the exponential log probability density. + + Args: + x: A tensor of values. + mean: A tensor of means. + + Returns: + A tensor of log probabilities. + """ + return -x / mean - math.log(mean) + + +def initialize_dense_parameters(model: Model) -> tuple[Model, dict[str, Any]]: + """Sets the dense parameters of a model to their initial values. Infers initial + values from the constraints their bounds, if no initial values are provided. If + a parameter does not have a constraint, it is initialized to zero. + + Args: + model: The model to initialize. + + Returns: + The re-initialized model, and a dictionary of initial values. + """ + constraints = dict(model.named_constraints()) + parameters = dict(model.named_parameters()) + initial_values = { + n: getattr(constraints.get(n + "_constraint", None), "_initial_value", None) + for n in parameters + } + lower_bounds = { + n: getattr( + constraints.get(n + "_constraint", None), + "lower_bound", + torch.tensor(-torch.inf), + ) + for n in parameters + } + upper_bounds = { + n: getattr( + constraints.get(n + "_constraint", None), + "upper_bound", + torch.tensor(torch.inf), + ) + for n in parameters + } + for n, v in initial_values.items(): + # if no initial value is provided, or the initial value is outside the bounds, + # use a rule-based initialization. + if v is None or not ((lower_bounds[n] <= v) and (v <= upper_bounds[n])): + if upper_bounds[n].isinf(): + if lower_bounds[n].isinf(): + v = 0.0 + else: + v = lower_bounds[n] + 1 + elif lower_bounds[n].isinf(): # implies u[n] is finite + v = upper_bounds[n] - 1 + else: # both are finite + v = lower_bounds[n] + torch.minimum( + torch.ones_like(lower_bounds[n]), + (upper_bounds[n] - lower_bounds[n]) / 2, + ) + initial_values[n] = v + + # the initial values need to be converted to the transformed space + for n, v in initial_values.items(): + c = constraints.get(n + "_constraint", None) + # convert the constraint into the latent space + if c is not None: + initial_values[n] = c.inverse_transform(v) + model.initialize(**initial_values) + parameters = dict(model.named_parameters()) + return model, initial_values diff --git a/botorch/test_functions/base.py b/botorch/test_functions/base.py index 19abe18ebb..a6ed1a5c0d 100644 --- a/botorch/test_functions/base.py +++ b/botorch/test_functions/base.py @@ -11,9 +11,11 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Any, Iterable, Iterator, Protocol import torch from botorch.exceptions.errors import InputDataError +from pyre_extensions import none_throws from torch import Tensor from torch.nn import Module @@ -203,3 +205,127 @@ def max_hv(self) -> float: def gen_pareto_front(self, n: int) -> Tensor: r"""Generate `n` pareto optimal points.""" raise NotImplementedError + + +class SeedingMixin(ABC): + _seeds: Iterator[int] | None + _current_seed: int | None + + @property + def has_seeds(self) -> bool: + return self._seeds is not None + + def increment_seed(self) -> int: + self._current_seed = next(none_throws(self._seeds)) + return none_throws(self._current_seed) + + @property + def seed(self) -> int | None: + return self._current_seed + + +# Outlier problems +class OutlierGenerator(Protocol): + def __call__(self, problem: BaseTestProblem, X: Tensor, bounds: Tensor) -> Tensor: + """Call signature for outlier generators for single-objective problems. + + Args: + problem: The test problem. + X: The input tensor. + bounds: The bounds of the test problem. + + Returns: + A tensor of outliers with shape X.shape[:-1] (1d if unbatched). + """ + pass # pragma: no cover + + +def constant_outlier_generator( + problem: Any, X: Tensor, bounds: Any, constant: float +) -> Tensor: + """ + Generates outliers that are all the same constant. To be used in conjunction with + `partial` to fix the constant value and conform to the `OutlierGenerator` protocol. + + Example: + >>> generator = partial(constant_outlier_generator, constant=1.0) + + Args: + problem: Not used. + X: The `batch_shape x n x d`-dim inputs. Also determines the number, dtype, + and device of the returned tensor. + bounds: Not used. + constant: The constant value of the outliers. + + Returns: + Tensor of shape `batch_shape x n` (1d if unbatched). + """ + return torch.full(X.shape[:-1], constant, dtype=X.dtype, device=X.device) + + +class CorruptedTestProblem(BaseTestProblem, SeedingMixin): + def __init__( + self, + base_test_problem: BaseTestProblem, + outlier_generator: OutlierGenerator, + outlier_fraction: float, + bounds: list[tuple[float, float]] | None = None, + seeds: Iterable[int] | None = None, + ) -> None: + """A problem with outliers. + + NOTE: Both noise_std and negate will be taken from the base test problem. + + Args: + base_test_problem: The base function to be corrupted. + outlier_generator: A function that generates outliers. It will be called + with arguments `f`, `X` and `bounds`, where `f` is the + `base_test_problem`, `X` is the + argument passed to the `forward` method, and `bounds` + are as here, and it returns the values of outliers. + outlier_fraction: The fraction of outliers. + bounds: The bounds of the function. + seeds: The seeds to use for the outlier generator. If seeds are provided, + the problem will iterate through the list of seeds, changing the seed + with a call to `next(seeds)` with every `forward` call. If a list is + provided, it will first be converted to an iterator. + """ + self.dim: int = base_test_problem.dim + self._bounds: list[tuple[float, float]] = ( + bounds if bounds is not None else base_test_problem._bounds + ) + super().__init__( + noise_std=base_test_problem.noise_std, + negate=base_test_problem.negate, + ) + self.base_test_problem = base_test_problem + self.outlier_generator = outlier_generator + self.outlier_fraction = outlier_fraction + self._current_seed: int | None = None + self._seeds: Iterator[int] | None = None if seeds is None else iter(seeds) + + def evaluate_true(self, X: Tensor) -> Tensor: + return self.base_test_problem.evaluate_true(X) + + def forward(self, X: Tensor, noise: bool = True) -> Tensor: + """ + Generate data at X and corrupt it, if noise is True. + + Args: + X: The `batch_shape x n x d`-dim inputs. + noise: Whether to corrupt the data. + + Returns: + A `batch_shape x n`-dim tensor. + """ + Y = super().forward(X, noise=noise) + if noise: + if self.has_seeds: + self.increment_seed() + torch.manual_seed(self.seed) + corrupt = torch.rand(X.shape[:-1]) < self.outlier_fraction + outliers = self.outlier_generator( + problem=self.base_test_problem, X=X, bounds=self.bounds + ) + Y = torch.where(corrupt, outliers, Y) + return Y diff --git a/botorch/test_functions/synthetic.py b/botorch/test_functions/synthetic.py index b4efc0920c..001c5dcae5 100644 --- a/botorch/test_functions/synthetic.py +++ b/botorch/test_functions/synthetic.py @@ -160,7 +160,7 @@ def __init__( def evaluate_true(self, X: Tensor) -> Tensor: a, b, c = self.a, self.b, self.c - part1 = -a * torch.exp(-b / math.sqrt(self.dim) * torch.linalg.norm(X, dim=-1)) + part1 = -a * torch.exp(-torch.linalg.norm(X, dim=-1) * b / math.sqrt(self.dim)) part2 = -(torch.exp(torch.mean(torch.cos(c * X), dim=-1))) return part1 + part2 + a + math.e diff --git a/botorch/utils/constraints.py b/botorch/utils/constraints.py index 1ee46984d6..4af5d408ab 100644 --- a/botorch/utils/constraints.py +++ b/botorch/utils/constraints.py @@ -15,6 +15,8 @@ from functools import partial import torch +from gpytorch.constraints import Interval + from torch import Tensor @@ -96,3 +98,48 @@ def get_monotonicity_constraints( if descending: A = -A return A, b + + +class NonTransformedInterval(Interval): + """Modification of the GPyTorch interval class that does not apply transformations. + + This is generally useful, and it is a requirement for the sparse parameters of the + Relevance Pursuit model [Ament2024pursuit]_, since it is not possible to achieve + exact zeros with the sigmoid transformations that are applied by default in the + GPyTorch Interval class. The variant implemented here does not apply transformations + to the parameters, instead passing the bounds constraint to the scipy L-BFGS + optimizer. This allows for the expression of exact zeros for sparse optimization + algorithms. + + NOTE: On a high level, the cleanest solution for this would be to separate out the + 1) definition and book-keeping of parameter constraints on the one hand, and + 2) the re-parameterization of the variables with some monotonic transformation, + since the two steps are orthogonal, but this would require refactoring GPyTorch. + """ + + def __init__( + self, + lower_bound: float | Tensor, + upper_bound: float | Tensor, + initial_value: float | Tensor | None = None, + ): + """Constructor of the NonTransformedInterval class. + + Args: + lower_bound: The lower bound of the interval. + upper_bound: The upper bound of the interval. + initial_value: The initial value of the parameter. + """ + super().__init__( + lower_bound=lower_bound, + upper_bound=upper_bound, + transform=None, + inv_transform=None, + initial_value=initial_value, + ) + + def transform(self, tensor: Tensor) -> Tensor: + return tensor + + def inverse_transform(self, transformed_tensor: Tensor) -> Tensor: + return transformed_tensor diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py index 3592e19991..2e15d5d794 100644 --- a/botorch/utils/testing.py +++ b/botorch/utils/testing.py @@ -25,7 +25,8 @@ from botorch.sampling.base import MCSampler from botorch.sampling.get_sampler import GetSampler from botorch.sampling.stochastic_samplers import StochasticSampler -from botorch.test_functions.base import BaseTestProblem +from botorch.test_functions.base import BaseTestProblem, CorruptedTestProblem +from botorch.test_functions.synthetic import Rosenbrock from botorch.utils.transforms import unnormalize from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from linear_operator.operators import AddedDiagLinearOperator, DiagLinearOperator @@ -232,6 +233,25 @@ def test_evaluate_slack(self): self.assertTrue(is_equal.all().item()) +class TestCorruptedProblemsMixin(BotorchTestCase): + def setUp(self, suppress_input_warnings: bool = True) -> None: + super().setUp(suppress_input_warnings=suppress_input_warnings) + + def outlier_generator( + problem: torch.Tensor | BaseTestProblem, X: Any, bounds: Any + ) -> torch.Tensor: + return torch.ones(X.shape[0]) + + self.outlier_generator = outlier_generator + + self.rosenbrock_problem = CorruptedTestProblem( + base_test_problem=Rosenbrock(), + outlier_fraction=1.0, + outlier_generator=outlier_generator, + seeds=[1, 2], + ) + + class MockPosterior(Posterior): r"""Mock object that implements dummy methods and feeds through specified outputs""" diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index ec6995565f..4880d22931 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -99,6 +99,10 @@ Fully Bayesian Multitask GP Models .. automodule:: botorch.models.fully_bayesian_multitask :members: +Relevance Pursuit Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.relevance_pursuit + :members: Model Components ------------------------------------------- @@ -134,6 +138,9 @@ Likelihoods .. automodule:: botorch.models.likelihoods.pairwise :members: +.. automodule:: botorch.models.likelihoods.sparse_outlier_noise + :members: + Transforms ------------------------------------------- diff --git a/test/models/test_relevance_pursuit.py b/test/models/test_relevance_pursuit.py new file mode 100644 index 0000000000..0cb7a4b3d8 --- /dev/null +++ b/test/models/test_relevance_pursuit.py @@ -0,0 +1,444 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import itertools + +from functools import partial +from typing import Callable + +import gpytorch +import torch +from botorch.exceptions.warnings import InputDataWarning +from botorch.models import SingleTaskGP +from botorch.models.likelihoods.sparse_outlier_noise import ( + SparseOutlierGaussianLikelihood, + SparseOutlierNoise, +) +from botorch.models.relevance_pursuit import ( + backward_relevance_pursuit, + forward_relevance_pursuit, + get_posterior_over_support, + relevance_pursuit, +) +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.test_functions.base import constant_outlier_generator, CorruptedTestProblem + +from botorch.test_functions.synthetic import Ackley +from botorch.utils.constraints import NonTransformedInterval +from botorch.utils.testing import BotorchTestCase +from gpytorch.constraints import Interval + +from gpytorch.kernels import RBFKernel, ScaleKernel +from gpytorch.likelihoods.noise_models import HomoskedasticNoise +from gpytorch.means import ZeroMean +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from pyre_extensions import none_throws +from torch import Tensor + + +class TestRobustGP(BotorchTestCase): + def _make_dataset( + self, + n: int, + num_outliers: int, + dtype: torch.dtype, + seed: int = 1, + ) -> tuple[Tensor, Tensor, list[int]]: + torch.manual_seed(seed) + + X = torch.rand(n, 1, dtype=dtype, device=self.device) + F = torch.sin(2 * torch.pi * (2 * X)).sum(dim=-1, keepdim=True) + sigma = 1e-2 + Y = F + torch.randn_like(F) * sigma + outlier_indices = list(range(n - num_outliers, n)) + Y[outlier_indices] = -Y[outlier_indices] + return X, Y, outlier_indices + + def _get_robust_model( + self, + X: Tensor, + Y: Tensor, + likelihood: SparseOutlierGaussianLikelihood, + ) -> SingleTaskGP: + min_lengthscale = 0.1 + lengthscale_constraint = NonTransformedInterval( + min_lengthscale, torch.inf, initial_value=0.2 + ) + d = X.shape[-1] + + kernel = ScaleKernel( + RBFKernel(ard_num_dims=d, lengthscale_constraint=lengthscale_constraint), + outputscale_constraint=NonTransformedInterval( + 0.01, 10.0, initial_value=0.1 + ), + ).to(dtype=X.dtype, device=self.device) + + model = SingleTaskGP( + train_X=X, + train_Y=Y, + mean_module=ZeroMean(), + covar_module=kernel, + input_transform=Normalize(d=X.shape[-1]), + outcome_transform=Standardize(m=Y.shape[-1]), + likelihood=likelihood, + ) + model.to(dtype=X.dtype, device=self.device) + return model + + def test_robust_gp_end_to_end(self) -> None: + self._test_robust_gp_end_to_end(convex_parameterization=False, mll_tol=1e-8) + + def test_robust_convex_gp_end_to_end(self) -> None: + self._test_robust_gp_end_to_end(convex_parameterization=True, mll_tol=1e-7) + + def _test_robust_gp_end_to_end( + self, + convex_parameterization: bool, + mll_tol: float, + ) -> None: + """End-to-end robust GP test.""" + n = 32 + dtype = torch.double + num_outliers = 6 + X, Y, outlier_indices = self._make_dataset( + n=n, num_outliers=num_outliers, dtype=dtype, seed=1 + ) + min_noise = 1e-6 # minimum noise variance constraint + max_noise = 1e-2 + base_noise = HomoskedasticNoise( + noise_constraint=NonTransformedInterval( + min_noise, max_noise, initial_value=1e-3 + ) + ).to(dtype=dtype, device=self.device) + + rp_likelihood = SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + ) + + model = self._get_robust_model(X=X, Y=Y, likelihood=rp_likelihood) + + X_test = torch.rand(3, 1, dtype=dtype, device=self.device) + with self.assertWarnsRegex(InputDataWarning, "SparseOutlierNoise"): + model.posterior(X_test, observation_noise=True) + + # optimization via backward relevance pursuit (num_contract=1) + sparse_module = model.likelihood.noise_covar + sparse_module.full_support() + mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model) + # IDEA: could speed up this test further, by initially reducing the support set + # significantly, since the first few steps should be the easiest. + # In the canonical parameterization, this actually leads this test to fail, + # since outliers in this example are very challenging. Similarly, the canonical + # parameterization is sensitive to increasing optimizer tolerances, which is + # remedied by the convex parameterization. + sparse_module, model_trace = relevance_pursuit( + sparse_module=sparse_module, + mll=mll, + num_iter=n - 1, + num_expand=0, + num_contract=1, + record_model_trace=True, + reset_parameters=False, + # NOTE: When `mll_iter` > 100, runtime seems to be primarily controlled + # by the convergence tolerance + mll_iter=1024, + mll_tol=mll_tol, + ) + model_trace = none_throws(model_trace) + + # Bayesian model comparison + prior_mean_of_support = 2.0 + support_size, bmc_probabilities = get_posterior_over_support( + SparseOutlierNoise, model_trace, prior_mean_of_support=prior_mean_of_support + ) + self.assertEqual(len(support_size), n) + self.assertEqual(len(support_size), len(bmc_probabilities)) + self.assertAlmostEqual(bmc_probabilities.sum().item(), 1.0) + map_index = torch.argmax(bmc_probabilities) + + with self.assertRaisesRegex( + ValueError, "Please provide a prior mean of the support size" + ): + get_posterior_over_support(SparseOutlierNoise, model_trace) + + # The MAP model on this data with this specific seed is not detecting one of the + # "outliers" since it's actually consistent with the data. + map_model = model_trace[map_index] + sparse_module = map_model.likelihood.noise_covar + undetected_outliers = set(outlier_indices) - set(sparse_module.support) + self.assertLessEqual(len(undetected_outliers), 1) + if len(undetected_outliers) == 1: + self.assertEqual(undetected_outliers.pop(), 25) + + def test_robust_relevance_pursuit(self) -> None: + for optimizer, convex_parameterization, dtype in itertools.product( + [forward_relevance_pursuit, backward_relevance_pursuit], + [True, False], + [torch.float32, torch.float64], + ): + with self.subTest( + optimizer=optimizer, + convex_parameterization=convex_parameterization, + dtype=dtype, + ): + # testing the loo functionality only with the forward algorithm + # and the convex parameterization, to save test runtime. + loo = ( + optimizer is forward_relevance_pursuit + ) and convex_parameterization + self._test_robust_relevance_pursuit( + optimizer=optimizer, + convex_parameterization=convex_parameterization, + dtype=dtype, + loo=loo, + ) + + def _test_robust_relevance_pursuit( + self, + optimizer: Callable, + convex_parameterization: bool, + dtype: torch.dtype, + loo: bool, + ) -> None: + """ + Test executing with different combinations of arguments, without checking the + model fit end-to-end. + """ + n = 32 + dtype = torch.double + X, Y, _ = self._make_dataset(n=n, num_outliers=6, dtype=dtype, seed=1) + min_noise = 1e-6 # minimum noise variance constraint + max_noise = 1e-2 + base_noise = HomoskedasticNoise( + noise_constraint=NonTransformedInterval( + min_noise, max_noise, initial_value=1e-3 + ) + ).to(dtype=dtype, device=self.device) + + with self.assertRaisesRegex( + ValueError, + "`rho_constraint` must be a `NonTransformedInterval` if it is not None.", + ): + SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + rho_constraint=Interval(0.0, 1.0), # pyre-ignore[6] + ) + + with self.assertRaisesRegex(ValueError, "rho_constraint.lower_bound >= 0"): + SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + rho_constraint=NonTransformedInterval(-1.0, 1.0), + ) + + if convex_parameterization: + with self.assertRaisesRegex(ValueError, "rho_constraint.upper_bound <= 1"): + SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + rho_constraint=NonTransformedInterval(0.0, 2.0), + ) + else: # with the canonical parameterization, any upper bound on rho is valid. + likelihood_with_other_bounds = SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + rho_constraint=NonTransformedInterval(0.0, 2.0), + ) + noise_w_other_bounds = likelihood_with_other_bounds.noise_covar + self.assertEqual(noise_w_other_bounds.raw_rho_constraint.lower_bound, 0.0) + self.assertEqual(noise_w_other_bounds.raw_rho_constraint.upper_bound, 2.0) + + rp_likelihood = SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + ) + + sparse_noise = rp_likelihood.noise_covar + sparse_noise.to_dense() + self.assertFalse(sparse_noise.is_sparse) + dense_rho = sparse_noise.rho + sparse_noise.to_sparse() + self.assertTrue(sparse_noise.is_sparse) + sparse_rho = sparse_noise.rho + self.assertAllClose(dense_rho, sparse_rho) + + with self.assertRaisesRegex(NotImplementedError, "variational inference"): + rp_likelihood.expected_log_prob(target=None, input=None) # pyre-ignore[6] + + # testing prior initialization + likelihood_with_prior = SparseOutlierGaussianLikelihood( + base_noise=base_noise, + dim=X.shape[0], + convex_parameterization=convex_parameterization, + rho_prior=gpytorch.priors.NormalPrior(loc=1 / 2, scale=0.1), + loo=loo, + ) + self.assertIsInstance( + likelihood_with_prior.noise_covar.rho_prior, gpytorch.priors.NormalPrior + ) + + # combining likelihood with full GP model + model = self._get_robust_model(X=X, Y=Y, likelihood=rp_likelihood) + + # testing the _from_model method + with self.assertRaisesRegex( + ValueError, + "The model's likelihood does not have a SparseOutlierNoise noise", + ): + SparseOutlierNoise._from_model(SingleTaskGP(train_X=X, train_Y=Y)) + + self.assertEqual( + SparseOutlierNoise._from_model(model), rp_likelihood.noise_covar + ) + + # Test that there is a warning because + # model.likelihood.noise_covar._cached_train_inputs is None + # and the shape of the test inputs are not compatible with the noise module. + X_test = torch.rand(3, 1, dtype=dtype, device=self.device) + with self.assertWarnsRegex( + InputDataWarning, + "Robust rho not applied because the last dimension of the base noise " + "covariance", + ): + model.likelihood.noise_covar.forward(X_test) + + # executing once successfully so that _cached_train_inputs is populated + model.posterior(X, observation_noise=True) + + X_test = torch.rand_like(X) # same size as training inputs but not the same + with self.assertWarnsRegex( + InputDataWarning, + "Robust rho not applied because the passed train_inputs are not equal to", + ): + model.posterior(X_test, observation_noise=True) + + # optimization via backward relevance pursuit + sparse_module = model.likelihood.noise_covar + mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model) + + with self.assertWarnsRegex( + InputDataWarning, + "not applied because the training inputs were not passed to the likelihood", + ): + mll(mll.model(*mll.model.train_inputs), mll.model.train_targets) + + extra_kwargs = {} + if convex_parameterization: + # initial_support doesn't have any special effect in the convex + # parameterization, adding this here so we don't have to test every + # combination of arguments. + extra_kwargs["initial_support"] = torch.randperm( + X.shape[0], device=self.device + )[: n // 2] # initializing with half the support + + # testing the reset_parameters functionality in conjunction with the convex + # parameterization to limit test runtime, but they are orthogonal. + reset_parameters = convex_parameterization + sparse_module, model_trace = optimizer( + sparse_module=sparse_module, + mll=mll, + record_model_trace=True, + reset_parameters=reset_parameters, + reset_dense_parameters=reset_parameters, + mll_iter=2, + **extra_kwargs, + ) + model_trace = none_throws(model_trace) + # Bayesian model comparison + prior_mean_of_support = 2.0 + support_size, bmc_probabilities = get_posterior_over_support( + SparseOutlierNoise, model_trace, prior_mean_of_support=prior_mean_of_support + ) + if optimizer is backward_relevance_pursuit: + expected_length = len(X) + 1 + if "initial_support" in extra_kwargs: + expected_length = len(X) // 2 + 1 + self.assertEqual(len(model_trace), expected_length) + self.assertEqual(support_size.max(), expected_length - 1) + self.assertEqual(support_size[-1].item(), 0) # includes zero + # results of forward are sorted in decreasing order of support size + self.assertAllClose(support_size, support_size.sort(descending=True).values) + elif optimizer is forward_relevance_pursuit: + # the forward algorithm will only add until no additional rho has a + # non-negative gradient, which can happen before the full support is added. + min_expected_length = 10 + self.assertGreaterEqual(len(model_trace), min_expected_length) + lower_bound = len(X) // 2 if "initial_support" in extra_kwargs else 0 + self.assertGreaterEqual(support_size.min().item(), lower_bound) + self.assertEqual( + support_size[0].item(), + n // 2 if "initial_support" in extra_kwargs else 0, + ) + + # results of forward are sorted in increasing order of support size + self.assertAllClose( + support_size, support_size.sort(descending=False).values + ) + + def test_experimental_utils(self) -> None: + base_f = Ackley(dim=3) + outlier_value = 100.0 + outlier_generator = partial(constant_outlier_generator, constant=outlier_value) + + # no outliers + f = CorruptedTestProblem( + base_test_problem=base_f, + outlier_generator=outlier_generator, + outlier_fraction=0.0, + ) + n, d = 16, base_f.dim + X = torch.randn(n, d) + Y = f(X) + self.assertAllClose(Y, base_f(X)) + + # all outliers + f = CorruptedTestProblem( + base_test_problem=base_f, + outlier_generator=outlier_generator, + outlier_fraction=1.0, + ) + n, d = 16, base_f.dim + X = torch.randn(n, d) + Y = f(X) + self.assertTrue(((Y - base_f(X)).abs() > 1).all()) + + # testing seeds + num_seeds = 3 + f = CorruptedTestProblem( + base_test_problem=base_f, + outlier_generator=outlier_generator, + outlier_fraction=1 / 2, + seeds=range(num_seeds), + ) + n, d = 8, base_f.dim + X = torch.randn(n, d) + Y_last = base_f(X) + for _ in range(num_seeds): + Y = f(X) + # with these seeds we should have at least 1 outlier and less than n, + # which shows that the masking works correctly. + num_outliers = (Y == outlier_value).sum() + self.assertGreater(num_outliers, 1) + self.assertLess(num_outliers, n) + # testing that the outliers are not the same, even if we evaluate on the + # same input. + self.assertTrue((Y_last - Y).norm() > 1.0) + Y_last = Y + + # after num_seeds has been exhausted, the evaluation will error. + with self.assertRaises(StopIteration): + f(X) diff --git a/test/test_functions/test_base.py b/test/test_functions/test_base.py index b39fa7a049..3ee76558c6 100644 --- a/test/test_functions/test_base.py +++ b/test/test_functions/test_base.py @@ -6,13 +6,13 @@ import torch from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem -from botorch.utils.testing import BotorchTestCase +from botorch.utils.testing import BotorchTestCase, TestCorruptedProblemsMixin from torch import Tensor class DummyTestProblem(BaseTestProblem): dim = 2 - _bounds = [(0, 1), (2, 3)] + _bounds = [(0.0, 1.0), (2.0, 3.0)] def evaluate_true(self, X: Tensor) -> Tensor: return -X.pow(2).sum(dim=-1) @@ -56,3 +56,26 @@ def test_constrained_base_test_problem(self): feas = problem.is_feasible(X=X) self.assertFalse(feas[0].item()) self.assertTrue(feas[1].item()) + + +class TestSeedingMixin(TestCorruptedProblemsMixin): + def test_seed_iteration(self) -> None: + problem = self.rosenbrock_problem + + self.assertTrue(problem.has_seeds) + self.assertIsNone(problem.seed) # increment_seed needs to be called first + problem.increment_seed() + self.assertEqual(problem.seed, 1) + problem.increment_seed() + self.assertEqual(problem.seed, 2) + with self.assertRaises(StopIteration): + problem.increment_seed() + + +class TestCorruptedTestProblem(TestCorruptedProblemsMixin): + def test_basic_rosenbrock(self) -> None: + problem = self.rosenbrock_problem + x = torch.rand(5, 1) + result = problem(x) + # the outlier_generator sets corruptions to 1 + self.assertTrue((result == 1).all()) diff --git a/test/utils/multi_objective/test_scalarization.py b/test/utils/multi_objective/test_scalarization.py index 723135eef8..4ae745b5aa 100644 --- a/test/utils/multi_objective/test_scalarization.py +++ b/test/utils/multi_objective/test_scalarization.py @@ -15,6 +15,7 @@ class TestGetChebyshevScalarization(BotorchTestCase): def test_get_chebyshev_scalarization(self): + torch.manual_seed(1234) tkwargs = {"device": self.device} Y_train = torch.rand(4, 2, **tkwargs) neg_Y_train = -Y_train