diff --git a/gpytorch/kernels/index_kernel.py b/gpytorch/kernels/index_kernel.py
index 7fa5e01f3..7fde6497e 100644
--- a/gpytorch/kernels/index_kernel.py
+++ b/gpytorch/kernels/index_kernel.py
@@ -65,9 +65,11 @@ def __init__(
         if var_constraint is None:
             var_constraint = Positive()
 
-        self.register_parameter(
-            name="covar_factor", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks, rank))
-        )
+        self.rank = rank
+        if self.rank > 0:
+            self.register_parameter(
+                name="covar_factor", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks, self.rank))
+            )
         self.register_parameter(name="raw_var", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)))
         if prior is not None:
             if not isinstance(prior, Prior):
@@ -88,13 +90,19 @@ def _set_var(self, value):
         self.initialize(raw_var=self.raw_var_constraint.inverse_transform(value))
 
     def _eval_covar_matrix(self):
-        cf = self.covar_factor
-        return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
+        if self.rank > 0:
+            cf = self.covar_factor
+            return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
+        else:
+            return torch.diag_embed(self.var)
 
     @property
     def covar_matrix(self):
         var = self.var
-        res = PsdSumLinearOperator(RootLinearOperator(self.covar_factor), DiagLinearOperator(var))
+        if self.rank > 0:
+            res = PsdSumLinearOperator(RootLinearOperator(self.covar_factor), DiagLinearOperator(var))
+        else:
+            res = DiagLinearOperator(var)
         return res
 
     def forward(self, i1, i2, **params):
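
Usage note (not part of the patch): a minimal sketch of the behavior this diff introduces, assuming it is applied. The num_tasks values are illustrative only, and to_dense() assumes a linear_operator-based GPyTorch, consistent with the DiagLinearOperator / PsdSumLinearOperator classes used above.

    import torch
    from gpytorch.kernels import IndexKernel

    # With the patch, rank=0 skips registering the low-rank `covar_factor`
    # entirely; the task covariance reduces to a diagonal matrix of per-task
    # variances, i.e. independent tasks.
    kernel = IndexKernel(num_tasks=4, rank=0)

    # `covar_matrix` is now a plain DiagLinearOperator rather than a
    # PsdSumLinearOperator; densifying it gives a (4, 4) diagonal matrix.
    print(kernel.covar_matrix.to_dense())

    # The rank > 0 path is unchanged: low-rank factor plus diagonal,
    # B = F F^T + diag(var).
    kernel_lr = IndexKernel(num_tasks=4, rank=2)
    print(kernel_lr.covar_matrix.to_dense())

Before the patch, rank=0 would still register a covar_factor of shape (num_tasks, 0), and RootLinearOperator on that degenerate factor is what the rank > 0 guards avoid.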