
FixedNoiseGaussianLikelihood results in negative variance #864

Open
mgarort opened this issue Sep 11, 2019 · 17 comments
Labels
bug stability When models return NaNs and stuff

Comments

@mgarort

mgarort commented Sep 11, 2019

Hi,

I am trying to train a most likely heteroscedastic GP (from "Most likely heteroscedastic GP regression", Kersting et al. 2007). To this end I am using the likelihood FixedNoiseGaussianLikelihood. I am setting the noise to positive values r.

(Pdb) print(r)
tensor([0.0086, 0.0071, 0.0071, 0.0069, 0.0067, 0.0067, 0.0065, 0.0065, 0.0065,
        0.0065, 0.0065, 0.0065, 0.0065, 0.0066, 0.0066, 0.0066, 0.0070, 0.0076,
        0.0076, 0.0090, 0.0107, 0.0110, 0.0117, 0.0122, 0.0130, 0.0135, 0.0140,
        0.0154, 0.0202, 0.0208, 0.0218, 0.0226, 0.0229, 0.0265, 0.0270, 0.0280,
        0.0282, 0.0285, 0.0287, 0.0289, 0.0290, 0.0290, 0.0289, 0.0288, 0.0287,
        0.0276, 0.0269, 0.0241, 0.0219, 0.0218, 0.0191, 0.0177, 0.0175, 0.0166,
        0.0164, 0.0160, 0.0146, 0.0145, 0.0131, 0.0120, 0.0112, 0.0105, 0.0100,
        0.0094, 0.0082, 0.0081, 0.0079, 0.0072, 0.0072, 0.0070, 0.0070, 0.0059,
        0.0058, 0.0057, 0.0057, 0.0056, 0.0052, 0.0051, 0.0049, 0.0049, 0.0048,
        0.0048, 0.0045, 0.0045, 0.0044, 0.0043, 0.0043, 0.0042, 0.0042, 0.0042,
        0.0042, 0.0042, 0.0042, 0.0042, 0.0042, 0.0042, 0.0045, 0.0046, 0.0046,
        0.0047, 0.0050, 0.0051, 0.0059, 0.0067, 0.0086, 0.0104, 0.0115, 0.0161,
        0.0213, 0.0226, 0.0278, 0.0399, 0.0413, 0.0418, 0.0463, 0.0567, 0.0592,
        0.2299, 0.2421, 0.2920, 0.3486, 0.5690, 0.7409, 0.8167, 1.3840, 1.4557,
        1.3335, 1.2206, 1.1017, 0.8272])
lik_3 = FixedNoiseGaussianLikelihood(noise=r, learn_additional_noise=False) 
GP3 = ExactGPModel(self.train_x,self.train_y,lik_3)
GP3, lik_3 = train_a_GP(GP3,self.train_x,self.train_y,lik_3,self.training_iter)

where train_a_GP is simply the following training function (copied from a GPyTorch regression tutorial):

def train_a_GP(model, train_x, train_y, likelihood, training_iter):
    # train GP_model for training_iter iterations
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=0.1)

    # "Loss" for GPs - the marginal log likelihood
    mll = ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
        ))
        optimizer.step()

    # Switch to evaluation mode only once training has finished
    model.eval()
    likelihood.eval()
    return model, likelihood

However, when I try to obtain predictions, the variance of the returned MultivariateNormal is negative for most points:

GP3.eval()
lik_3.eval()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    train_pred = lik_3(GP3(self.train_x),noise=r)
(Pdb) train_pred.variance
tensor([-0.1621, -0.1650, -0.1650, -0.1652, -0.1653, -0.1653, -0.1651, -0.1651,
        -0.1650, -0.1650, -0.1648, -0.1648, -0.1647, -0.1644, -0.1642, -0.1642,
        -0.1631, -0.1619, -0.1618, -0.1586, -0.1548, -0.1542, -0.1528, -0.1515,
        -0.1497, -0.1487, -0.1475, -0.1444, -0.1334, -0.1320, -0.1295, -0.1279,
        -0.1271, -0.1192, -0.1182, -0.1162, -0.1159, -0.1155, -0.1152, -0.1150,
        -0.1152, -0.1159, -0.1160, -0.1167, -0.1168, -0.1206, -0.1226, -0.1303,
        -0.1358, -0.1360, -0.1424, -0.1454, -0.1460, -0.1479, -0.1483, -0.1493,
        -0.1522, -0.1524, -0.1552, -0.1574, -0.1590, -0.1602, -0.1611, -0.1622,
        -0.1642, -0.1645, -0.1648, -0.1660, -0.1660, -0.1662, -0.1663, -0.1680,
        -0.1682, -0.1683, -0.1684, -0.1684, -0.1691, -0.1693, -0.1696, -0.1696,
        -0.1697, -0.1698, -0.1701, -0.1702, -0.1703, -0.1704, -0.1705, -0.1706,
        -0.1707, -0.1707, -0.1707, -0.1707, -0.1707, -0.1708, -0.1707, -0.1707,
        -0.1705, -0.1704, -0.1703, -0.1702, -0.1700, -0.1699, -0.1690, -0.1681,
        -0.1660, -0.1638, -0.1626, -0.1569, -0.1504, -0.1486, -0.1416, -0.1242,
        -0.1221, -0.1214, -0.1145, -0.0983, -0.0943,  0.2122,  0.2350,  0.3277,
         0.4312,  0.8048,  1.0528,  1.1486,  1.5613,  1.5205,  1.2962,  1.1957,
         1.1100,  0.8646])

What am I doing wrong? Any help would be greatly appreciated.

Thanks a lot!

Miguel
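For reference, here is a minimal NumPy sketch of the exact quantity being asked for: the GP posterior when a fixed per-point noise vector is added to the diagonal of the kernel matrix, which is what FixedNoiseGaussianLikelihood is meant to model. It assumes a zero prior mean and an RBF kernel with fixed, made-up hyperparameters; the helper names and toy data are hypothetical and not part of the GPyTorch API. Computed exactly through a Cholesky factor, the latent variance is a prior variance minus a squared norm bounded by it, so it cannot be negative; negative values therefore suggest a numerical or approximation problem rather than a modelling error.

```python
import numpy as np


def rbf_kernel(a, b, lengthscale=1.0, outputscale=1.0):
    """Squared-exponential (RBF) kernel matrix between 1-D input arrays a and b."""
    sqdist = (a[:, None] - b[None, :]) ** 2
    return outputscale * np.exp(-0.5 * sqdist / lengthscale**2)


def exact_heteroscedastic_predict(train_x, train_y, noise, test_x,
                                  lengthscale=1.0, outputscale=1.0):
    """Exact GP posterior (zero prior mean) with a fixed per-point noise vector.

    Returns the latent mean and the latent (noise-free) variance at test_x.
    """
    K = rbf_kernel(train_x, train_x, lengthscale, outputscale)
    K_sx = rbf_kernel(test_x, train_x, lengthscale, outputscale)
    # The heteroscedastic noise only enters as the diagonal of K + diag(noise).
    L = np.linalg.cholesky(K + np.diag(noise))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, train_y))
    mean = K_sx @ alpha
    # diag(K_ss) - diag(K_sx (K + D)^-1 K_xs), computed as outputscale - ||L^-1 k_x*||^2.
    V = np.linalg.solve(L, K_sx.T)
    var = outputscale - np.sum(V**2, axis=0)
    return mean, var


# Made-up toy data: a smooth signal whose noise jumps up near the right edge.
rng = np.random.default_rng(0)
train_x = np.linspace(0, 6, 100)
noise = np.where(train_x > 5.0, 0.5, 0.005)
train_y = np.sin(train_x) + np.sqrt(noise) * rng.standard_normal(100)

mean, latent_var = exact_heteroscedastic_predict(train_x, train_y, noise, train_x)
# For predictions of observed y, the fixed noise is added on top of the latent variance.
observed_var = latent_var + noise
```

The Cholesky route is chosen because it is exact and numerically stable at this size; it is a reference computation, not a description of how GPyTorch evaluates predictive variances internally.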

@jacobrgardner
Member

Hi @mgarort --

Nothing immediately jumps out at me as wrong. Would you be able to provide example data for which this happens, as well as the definition of ExactGPModel if it differs from our basic examples?

@mgarort
Author

mgarort commented Sep 11, 2019

Hi @jacobrgardner

Thanks for the quick reply. Here's a simple script that reproduces the error, together with the 3 files needed (x, y and noise r).

reproduce_negative_variance.zip

@Balandat
Collaborator

If it helps, you can also take a look at our convenience wrapper model for observed variances in BoTorch: https://github.com/pytorch/botorch/blob/master/botorch/models/gp_regression.py#L118-L127

There is also a PR open for the most likely heteroskedastic GP fitting (which I'll probably get to in the near future): pytorch/botorch#250. Note that this uses a full heteroskedastic GP, where the noise model is itself a GP (rather than fixed variances).

@mgarort
Author

mgarort commented Sep 11, 2019

Hi @Balandat

Thanks! I'll take a look at the wrapper model in BoTorch.

I think I have modelled the noise correctly according to the most likely heteroscedastic GP (although of course there's always room for error!). The vector r of variances passed to GP3 is produced by a previous GP, GP2, that models the noise (following the notation of section 4 of the paper, "Optimization"). If anything, I suspect I have made an incorrect assumption about how GPyTorch incorporates the heteroscedastic noise r into the model...

Best,

Miguel

@mgarort
Author

mgarort commented Sep 11, 2019

@Balandat @jacobrgardner

Here's my full attempt at an implementation of the most likely heteroscedastic GP.

import torch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.constraints import Positive

def train_a_GP(model, train_x, train_y, likelihood, training_iter):
    """
    Simple utility function to train a Gaussian process (GP) model with Adam (following the examples on the docs).

    :param model: GP model
    :param train_x: tensor with training features X
    :param train_y: tensor with training targets Y
    :param likelihood: likelihood function
    :param training_iter: number of iterations to train
    :return: trained GP model, trained likelihood
    """
    # train GP_model for training_iter iterations
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=0.1)

    # "Loss" for GPs - the marginal log likelihood
    mll = ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
        ))
        optimizer.step()

    # Switch to evaluation mode only once training has finished
    model.eval()
    likelihood.eval()
    return model, likelihood


class ExactGPModel(gpytorch.models.ExactGP):
    """
    Exact Gaussian process model (following the examples in the docs).
    """
    def __init__(self, train_x, train_y, likelihood):
        """
        Initializer function. Specifies the mean and the covariance functions.

        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param likelihood: likelihood function
        :return: None
        """
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x):
        """
        Forward method to evaluate GP.

        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class hetGPModel():
    """
    Most likely heteroscedastic GP model.
    """
    def __init__(self,train_x,train_y,training_iter=100,het_fitting_iter=10,var_estimator_n=50):
        """
        Initializer function.

        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param training_iter: number of iterations to train GP1, GP2 and GP3
        :param het_fitting_iter: number of iterations to run the pseudo expectation maximization (EM) algorithm while refining GP3
        :param var_estimator_n: number of samples to estimate the variance at each training point
        :return: None
        """
        self.train_x = train_x
        self.train_y = train_y
        self.training_iter = training_iter
        self.var_estimator_n = var_estimator_n
        self.het_fitting_iter = het_fitting_iter
        self.final_GP = None
        self.final_lik = None
        self.final_r = None

    def predict(self,x):
        """
        Predict method to evaluate GP.

        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        if self.final_GP is None:
            raise RuntimeError('hetGPModel needs to be trained before using it')
        return self.final_GP(x)

    def train_model(self):
        """
        Train most likely heteroscedastic GP, in which one GP predicts the mean and another GP predicts the variance. This function
        corresponds to section '4. Optimization' in the original most likely heteroscedastic GP paper (Kersting et al. 2007).

        :return: None
        """
        # train GP1, a standard homoscedastic GP, on the data
        lik_1 = GaussianLikelihood()
        GP1 = ExactGPModel(self.train_x,self.train_y,lik_1)
        GP1, lik_1 = train_a_GP(GP1,self.train_x,self.train_y,lik_1,self.training_iter)
        for i in range(self.het_fitting_iter):
            # estimate the noise levels z
            z = torch.log(self.get_r_hat(GP1,lik_1))
            # fit the noise z at train_x
            lik_2 = GaussianLikelihood()
            GP2 = ExactGPModel(self.train_x,z,lik_2)
            GP2, lik_2 = train_a_GP(GP2,self.train_x,z,lik_2,self.training_iter)
            # create a heteroscedastic GP
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                r_pred = lik_2(GP2(self.train_x))
            r = torch.exp(r_pred.mean)
            lik_3 = FixedNoiseGaussianLikelihood(noise=r, learn_additional_noise=False) 
            GP3 = ExactGPModel(self.train_x,self.train_y,lik_3)
            GP3, lik_3 = train_a_GP(GP3,self.train_x,self.train_y,lik_3,self.training_iter)
            GP1 = GP3
            lik_1 = lik_3
            
        self.final_GP = GP3
        self.final_lik = lik_3
        self.final_r = r


    def get_r_hat(self,GP,likelihood):
        """
        Estimate variance at each training point.

        :param GP: GP model that predicts the mean in the heteroscedastic GP model.
        :param likelihood: likelihood paired with GP
        :return: tensor r_hat with the estimated variances
        """
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            train_pred = likelihood(GP(self.train_x))
        # Average 0.5 * (y_i - t_i')^2 over var_estimator_n samples t' from the predictive distribution
        samples = train_pred.sample(torch.Size([self.var_estimator_n]))
        r_hat = torch.sum(0.5*(self.train_y.reshape(1,-1) - samples)**2,dim=0)/self.var_estimator_n
        return r_hat
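The sample-based noise estimate in get_r_hat (the estimator used in section 4 of Kersting et al.) can be sanity-checked in isolation. In expectation, averaging 0.5 * (y - t')^2 over draws t' ~ N(mu, sigma^2) gives 0.5 * ((y - mu)^2 + sigma^2). A minimal NumPy sketch follows; the function name and the three-point predictive means and standard deviations are hypothetical. It draws independently per point, whereas get_r_hat samples the joint MultivariateNormal, but a per-point average depends only on the marginals, so the comparison still holds.

```python
import numpy as np


def sample_based_noise_estimate(y, pred_mean, pred_std, n_samples=50, rng=None):
    """Estimate the noise variance at each point as mean_j 0.5 * (y_i - t_ij)^2.

    t_ij are draws from the per-point predictive N(pred_mean_i, pred_std_i^2).
    """
    rng = np.random.default_rng() if rng is None else rng
    # samples has shape (n_samples, n_points), like drawing var_estimator_n samples in get_r_hat
    samples = pred_mean + pred_std * rng.standard_normal((n_samples, y.shape[0]))
    return np.mean(0.5 * (y[None, :] - samples) ** 2, axis=0)


# Made-up predictive means and standard deviations at three training points.
y = np.array([0.0, 1.0, -2.0])
pred_mean = np.array([0.0, 0.5, -1.0])
pred_std = np.array([0.1, 0.2, 0.3])

r_hat = sample_based_noise_estimate(y, pred_mean, pred_std, n_samples=200_000,
                                    rng=np.random.default_rng(0))
# Closed form of the same expectation: 0.5 * ((y - mu)^2 + sigma^2)
expected = 0.5 * ((y - pred_mean) ** 2 + pred_std ** 2)
z = np.log(r_hat)  # log-noise targets, analogous to z = torch.log(self.get_r_hat(...))
```

With the default 50 samples the estimate is much noisier than with 200,000, which is worth keeping in mind since GP2 is then fitted to log(r_hat).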

@mgarort
Author

mgarort commented Sep 12, 2019

Hi @Balandat and @jacobrgardner

I have tried rewriting my heteroscedastic GP using the BoTorch wrapper FixedNoiseGP instead of FixedNoiseGaussianLikelihood, and I get the same result (negative variance).

My code for the heteroscedastic GP with FixedNoiseGP is the following.

Thanks a lot in advance.

import torch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.constraints import Positive

from botorch.models.gp_regression import FixedNoiseGP

def train_a_GP(model, train_x, train_y, likelihood, training_iter):
    """
    Simple utility function to train a Gaussian process (GP) model with Adam (following the examples on the docs).

    :param model: GP model
    :param train_x: tensor with training features X
    :param train_y: tensor with training targets Y
    :param likelihood: likelihood function
    :param training_iter: number of iterations to train
    :return: trained GP model, trained likelihood
    """
    # train GP_model for training_iter iterations
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=0.1)

    # "Loss" for GPs - the marginal log likelihood
    mll = ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
        ))
        optimizer.step()

    # Switch to evaluation mode only once training has finished
    model.eval()
    likelihood.eval()
    return model, likelihood


class ExactGPModel(gpytorch.models.ExactGP):
    """
    Exact Gaussian process model (following the examples in the docs).
    """
    def __init__(self, train_x, train_y, likelihood):
        """
        Initializer function. Specifies the mean and the covariance functions.

        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param likelihood: likelihood function
        :return: None
        """
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x):
        """
        Forward method to evaluate GP.

        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class hetGPModel():
    """
    Most likely heteroscedastic GP model.
    """
    def __init__(self,train_x,train_y,training_iter=100,het_fitting_iter=10,var_estimator_n=50):
        """
        Initializer function.

        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param training_iter: number of iterations to train GP1, GP2 and GP3
        :param het_fitting_iter: number of iterations to run the pseudo expectation maximization (EM) algorithm while refining GP3
        :param var_estimator_n: number of samples to estimate the variance at each training point
        :return: None
        """
        self.train_x = train_x
        self.train_y = train_y
        self.training_iter = training_iter
        self.var_estimator_n = var_estimator_n
        self.het_fitting_iter = het_fitting_iter
        self.final_GP = None
        self.final_lik = None
        self.final_r_func = None

    def predict(self,x):
        """
        Predict method to evaluate GP.

        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        if self.final_GP is None:
            raise RuntimeError('hetGPModel needs to be trained before using it')
        return self.final_GP(x)

    def train_model(self):
        """
        Train most likely heteroscedastic GP, in which one GP predicts the mean and another GP predicts the variance. This function
        corresponds to section '4. Optimization' in the original most likely heteroscedastic GP paper (Kersting et al. 2007).

        :return: None
        """
        # train GP1, a standard homoscedastic GP, on the data
        lik_1 = GaussianLikelihood()
        GP1 = ExactGPModel(self.train_x,self.train_y,lik_1)
        GP1, lik_1 = train_a_GP(GP1,self.train_x,self.train_y,lik_1,self.training_iter)
        for i in range(self.het_fitting_iter):
            # estimate the noise levels z
            z = torch.log(self.get_r_hat(GP1,lik_1))
            # fit the noise z at train_x
            lik_2 = GaussianLikelihood()
            GP2 = ExactGPModel(self.train_x,z,lik_2)
            GP2, lik_2 = train_a_GP(GP2,self.train_x,z,lik_2,self.training_iter)
            # create a heteroscedastic GP
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                r_func_pred = lik_2(GP2(self.train_x))
            r_func = torch.exp(r_func_pred.mean)
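            GP3 = FixedNoiseGP(self.train_x.reshape(-1,1),self.train_y.reshape(-1,1),r_func.reshape(-1,1))
            # FixedNoiseGP builds its own FixedNoiseGaussianLikelihood from r_func;
            # train with that one rather than a separate GaussianLikelihood the model never uses
            lik_3 = GP3.likelihood
            GP3, lik_3 = train_a_GP(GP3,self.train_x,self.train_y,lik_3,self.training_iter)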
            GP1 = GP3
            lik_1 = lik_3
            
        self.final_GP = GP3
        self.final_lik = lik_3
        self.final_r_func = r_func


    def get_r_hat(self,GP,likelihood):
        """
        Estimate variance at each training point.

        :param GP: GP model that predicts the mean in the heteroscedastic GP model.
        :param likelihood: likelihood paired with GP
        :return: tensor r_hat with the estimated variances
        """
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            train_pred = likelihood(GP(self.train_x))
        # Average 0.5 * (y_i - t_i')^2 over var_estimator_n samples t' from the predictive distribution
        samples = train_pred.sample(torch.Size([self.var_estimator_n]))
        r_hat = torch.sum(0.5*(self.train_y.reshape(1,-1) - samples)**2,dim=0)/self.var_estimator_n
        return r_hat

@stanbiryukov

Hey there - I had this problem a few weeks ago and solved it by adding a constraint to the likelihood:
likelihood=gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1e-3)).to(device)

or for fixednoise:
gpytorch.likelihoods.FixedNoiseGaussianLikelihood(noise=noise, noise_constraint=gpytorch.constraints.GreaterThan(1e-3))
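A constraint such as GreaterThan(1e-3) acts on the noise parameter: the optimizer works on an unconstrained raw value, which is mapped to a value that cannot fall below the bound. The sketch below shows one common way to do this, an offset softplus; we believe this is GPyTorch's default transform for GreaterThan, but treat that as an assumption. The function names are hypothetical. Because the bound applies to the noise parameter, it does not by itself bound the predictive variance that an approximate solver returns.

```python
import math


def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def lower_bounded(raw, lower_bound=1e-3):
    """Map an unconstrained raw value to a value bounded below by lower_bound."""
    return lower_bound + softplus(raw)


# Very negative raw values approach the bound; positive ones grow roughly linearly.
values = [lower_bounded(raw) for raw in (-5.0, 0.0, 5.0)]
```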

@mgarort
Author

mgarort commented Sep 12, 2019

Hi @stanbiryukov

Thanks for the tip. Unfortunately I've tried FixedNoiseGaussianLikelihood(noise=noise, noise_constraint=gpytorch.constraints.GreaterThan(1e-3)) and variances are still negative :/

@jacobrgardner
Member

@mgarort I need sample data that produces this issue. When I run your code with toy data:

train_x = torch.linspace(0, 6, 100)
train_y = torch.sin(train_x) + 0.01 * torch.randn(100)

I get positive variances.

@mgarort
Copy link
Author

mgarort commented Sep 14, 2019

Hi @jacobrgardner

I attached sample data and a script that results in negative variance in a zip a few posts back, but I guess it has ended up hidden under the subsequent posts!

Here's the zip again. You just need to run the script in the same folder as x.txt, y.txt and r.txt
reproduce_negative_variance.zip

Thanks,

Miguel

@jacobrgardner
Member

@mgarort Okay, I'm pretty sure I know what's going on here. It's actually pretty technical.

Basically, for fast predictive variances we decompose $(K + \sigma^{2} I)^{-1}$ in a way that is fine for homoscedastic noise, because adding $\sigma^{2} I$ doesn't change the eigenvalue clustering; it only shifts the whole spectrum by $\sigma^{2}$. In the heteroscedastic setting this assumption is violated: adding an arbitrary diagonal component does change the eigenvalue clustering. Turning off fast predictive variances gives positive variances.
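The spectral point can be checked numerically. Below is a minimal NumPy sketch; the grid, kernel hyperparameters, and noise values are made up, and it illustrates only the linear algebra, not GPyTorch's fast-predictive-variance code. The eigenvectors of K still diagonalize K + sigma^2 I, with every eigenvalue shifted by sigma^2. By contrast, K + diag(r) is no longer diagonal in that basis, so its spectrum is not a simple shift of the spectrum of K.

```python
import numpy as np

# Made-up setup: RBF kernel (lengthscale 1, outputscale 1) on a 1-D grid.
n = 60
x = np.linspace(0, 6, n)
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
sigma2 = 0.01                        # homoscedastic noise variance
r = np.linspace(0.005, 1.5, n)       # heteroscedastic noise, one value per point

# Eigendecomposition of the noise-free kernel matrix.
evals, V = np.linalg.eigh(K)


def off_diagonal_norm(A):
    """Frobenius norm of A with its diagonal zeroed out."""
    return np.linalg.norm(A - np.diag(np.diag(A)))


# Rotate both noisy matrices into the eigenbasis of K.
homo = V.T @ (K + sigma2 * np.eye(n)) @ V
hetero = V.T @ (K + np.diag(r)) @ V

# K + sigma2*I: still diagonal in that basis, with every eigenvalue shifted by sigma2.
homo_shift_ok = np.allclose(np.diag(homo), evals + sigma2)
homo_offdiag = off_diagonal_norm(homo)
# K + diag(r): the eigenvectors of K no longer diagonalize it.
hetero_offdiag = off_diagonal_norm(hetero)
```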

To work around this, we could instead decompose $K$ and then use a QR decomposition to effectively get a root for $K^{-1}$. This will take a bit of work to implement. For now, is turning off fast_pred_var a feasible workaround, or do you anticipate having too much data?

@jacobrgardner
Member

cc @gpleiss on this one, since we've discussed decomposing $K$ instead of $K + \sigma^{2} I$ for LOVE before, but this is the first time we've had any real motivation for it.

@gpleiss gpleiss self-assigned this Oct 10, 2019
@gpleiss gpleiss added the stability When models return NaNs and stuff label Oct 10, 2019
@mgarort
Author

mgarort commented Oct 13, 2019

Hi @jacobrgardner

So sorry for the very delayed reply; a long holiday and a paper submission got in the way.

Turning off fast_pred_var would work great.

Thanks a lot again,

Miguel

@ishank-juneja
Copy link

ishank-juneja commented Nov 24, 2021

This unexpected behavior with gpytorch.settings.fast_pred_var() still seems to persist (the issue is still open, so I assume it hasn't been fixed). For instance, with the model, likelihood, and marginal log likelihood definitions below, I observe negative variances at test time.

import gpytorch

# Assumed shapes (from ard_num_dims=4 and the column indexing below):
# train_x is (n, 4) and train_y is (n, 2), one column per output


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=4))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood1 = gpytorch.likelihoods.GaussianLikelihood()
model1 = ExactGPModel(train_x, train_y[:, 0], likelihood1)

likelihood2 = gpytorch.likelihoods.GaussianLikelihood()
model2 = ExactGPModel(train_x, train_y[:, 1], likelihood2)
# Collect the submodels in an IndependentModelList, and the respective likelihoods in a LikelihoodList
model = gpytorch.models.IndependentModelList(model1, model2)
likelihood = gpytorch.likelihoods.LikelihoodList(model1.likelihood, model2.likelihood)

mll = gpytorch.mlls.SumMarginalLogLikelihood(likelihood, model)

I gather from this thread that the problem is challenging to fix, so I would encourage the maintainers to modify the GPyTorch tutorials and examples so that they either discuss the pitfalls of fast_pred_var() or don't use it at all.

For instance, I went through the tutorials/example code for exact GP regression, the multitask kernel, and independent model lists. All three use gpytorch.settings.fast_pred_var() in their testing/inference snippets, so I assumed it was simply part of the standard GPyTorch test-time idiom.

@wjmaddox
Collaborator

wjmaddox commented Nov 24, 2021

Do you mind sharing a reproducible example of this behavior?

I wasn't immediately able to produce significant differences when using the model definitions above.

@ishank-juneja
Copy link

ishank-juneja commented Nov 24, 2021

Would a colab notebook work?

Also, interestingly, the negative variances under gpytorch.settings.fast_pred_var() turn into reasonable (seemingly accurate) positive values if I download this same notebook as a .py file and run it in my local venv (Colab and my venv have identical gpytorch and torch versions). I am not sure why, but it could have to do with:

  1. Randomness in my data generation process: unlikely, because I fixed the random seeds and the training loss values locally and on Colab were identical.
  2. Hardware-specific optimizations used by gpytorch behaving differently on my local hardware vs. the Colab machines: this seems the most likely explanation, but I don't know enough to be more specific.

@ishank-juneja
Copy link

Another related issue: #1840


7 participants