[Bug] The kernel ScaleKernel is not equipped to handle and diag. #2411

Closed
cookbook-ms opened this issue Sep 20, 2023 · 2 comments
@cookbook-ms

🐛 Bug

I have a self-defined kernel: essentially an exponential kernel, but operating in the eigenspace.

To reproduce

**Code snippet to reproduce**

import torch
from gpytorch.kernels import Kernel
from gpytorch.constraints import Interval
from linear_operator.operators import DiagLinearOperator


class EdgeDiffusionKernelOceanFlow(Kernel):
    """
    Edge diffusion kernel for simplicial complexes.

    Parameters
    ----------
    eigenpairs : tuple of torch.Tensor
        Eigenvectors and eigenvalues of the Hodge Laplacians of the simplicial complex
    kappa_bounds : tuple of float
        Interval bounds for the diffusion parameters
    """
    def __init__(self, eigenpairs, kappa_bounds=(1e-5,1e5)): 
        super().__init__()
        # self.eig_h, self.eig_g, self.eig_c, self.eigvec_h, self.eigvec_g, self.eigvec_c = eigenpairs
        self.eigvecs, self.eigvals = eigenpairs
        # register the raw parameters
        self.register_parameter(
            name='raw_kappa_down', parameter=torch.nn.Parameter(torch.zeros(1,1))
        )
        self.register_parameter(
            name='raw_kappa_up', parameter=torch.nn.Parameter(torch.zeros(1,1))
        )
        # set the kappa constraints
        self.register_constraint(
            'raw_kappa_down', Interval(*kappa_bounds)
        )
        self.register_constraint(
            'raw_kappa_up', Interval(*kappa_bounds)
        )
        # we do not set priors on the parameters

    # set up the actual parameters 
    @property
    def kappa_down(self):
        return self.raw_kappa_down_constraint.transform(self.raw_kappa_down)

    @kappa_down.setter
    def kappa_down(self, value):
        self._set_kappa_down(value)

    def _set_kappa_down(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_kappa_down)
        self.initialize(raw_kappa_down=self.raw_kappa_down_constraint.inverse_transform(value))

    @property
    def kappa_up(self):
        return self.raw_kappa_up_constraint.transform(self.raw_kappa_up)
    
    @kappa_up.setter
    def kappa_up(self, value):
        self._set_kappa_up(value)

    def _set_kappa_up(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_kappa_up)
        self.initialize(raw_kappa_up=self.raw_kappa_up_constraint.inverse_transform(value))
 
    def _eval_covar_matrix(self):
        """Compute the full covariance (kernel) matrix once, to avoid repeated computation."""
        # K1 = torch.linalg.matrix_exp(- (self.kappa_down*self.L1_down + self.kappa_up*self.L1_up))
        k = (self.kappa_down*self.eigvals).squeeze()
        K1 = self.eigvecs @ DiagLinearOperator(k) @ self.eigvecs.T
        # K2 = torch.linalg.matrix_exp(-self.kappa_up*self.L1_up)
        # This is equivalent to K1+K2-h_0 * I (remove the repeated identity part)
        return K1
    
    @property
    def covar_matrix(self):
        return self._eval_covar_matrix()
        
    # define the kernel function
    def forward(self, x1, x2=None, **params):
        # check x2 for None before converting dtypes (x2.long() would fail on None)
        if x2 is None:
            x2 = x1
        x1 = x1.long().squeeze(-1)
        x2 = x2.long().squeeze(-1)
        # index into the precomputed kernel matrix
        return self.covar_matrix[x1, :][:, x2]

My covariance module:

self.covar_module = gpytorch.kernels.ScaleKernel(kernel, outputscale_constraint=Interval(1e-5, 1e5))

When computing the variance and MSLL, the following error occurs:

pred_mean, pred_var = observed_pred.mean, observed_pred.variance

**Stack trace/error message**

RuntimeError                              Traceback (most recent call last)
/home/mmaosheng/kernel_methods/Edge-GPRegression_OceanFlow.ipynb Cell 35 line 1
----> 1 pred_mean, pred_var = observed_pred.mean, observed_pred.variance

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:309, in MultivariateNormal.variance(self)
    305 @property
    306 def variance(self) -> Tensor:
    307     if self.islazy:
    308         # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
--> 309         diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
    310         diag = diag.view(diag.shape[:-1] + self._event_shape)
    311         variance = diag.expand(self._batch_shape + self._event_shape)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1411, in LinearOperator.diagonal(self, offset, dim1, dim2)
   1409 elif not self.is_square:
   1410     raise RuntimeError("LinearOperator#diagonal is only implemented for square operators.")
-> 1411 return self._diagonal()

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:25, in recall_grad_state.<locals>.wrapped(self, *args, **kwargs)
     22 @functools.wraps(method)
     23 def wrapped(self, *args, **kwargs):
     24     with torch.set_grad_enabled(self._is_grad_enabled):
---> 25         output = method(self, *args, **kwargs)
     26     return output

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:126, in LazyEvaluatedKernelTensor._diagonal(self)
    124     expected_shape = self.shape[:-1]
    125     if res.shape != expected_shape:
--> 126         raise RuntimeError(
    127             "The kernel {} is not equipped to handle and diag. Expected size {}. "
    128             "Got size {}".format(self.kernel.__class__.__name__, expected_shape, res.shape)
    129         )
    131 if isinstance(res, LinearOperator):
    132     res = res.to_dense()

RuntimeError: The kernel ScaleKernel is not equipped to handle and diag. Expected size torch.Size([9512]). Got size torch.Size([9512, 9512])

Expected Behavior

I expect to get the variance and MSLL from the trained model and likelihood.

System information


  • GPyTorch Version: 1.11
  • PyTorch Version: 2.2.0
  • Computer OS: Ubuntu 20.04.5 LTS (Focal Fossa)

Additional context

I noticed the same type of error reported in other issues too. While trying to figure out the reason, I believe it is because I always build the full matrix
K1 = self.eigvecs @ self.eigvecs.T (with or without a diagonal matrix in between).

@gpleiss
Member

gpleiss commented Sep 21, 2023

@cookbook-ms the issue is that your forward function doesn't accept a diag=True keyword argument, which is necessary to obtain variance estimates. See the RBF or the LinearKernel implementations for an example.

I realize that the custom tutorial documentation also does not include this option. This option may become obsolete with #2342, so I'm not going to suggest fixing the tutorial at this moment.

@gpleiss gpleiss closed this as completed Sep 21, 2023
@cookbook-ms
Author

Please refer to how LinearKernel implements this:
https://docs.gpytorch.ai/en/stable/_modules/gpytorch/kernels/linear_kernel.html#LinearKernel
