FixedNoiseGaussianLikelihood results in negative variance #864
Hi @mgarort -- Nothing immediately jumps out at me as wrong. Would you possibly be able to provide example data for which this happens, as well as the relevant model definitions?
Thanks for the quick reply. Here's a simple script that reproduces the error, together with the 3 files needed (x, y and noise r).
If it helps, you can also take a look at our convenience wrapper model for observed variances in BoTorch: https://github.com/pytorch/botorch/blob/master/botorch/models/gp_regression.py#L118-L127 There is also a PR open for most likely heteroskedastic GP fitting (which I'll probably get to in the near future): pytorch/botorch#250. Note that this uses a full heteroskedastic GP, where the noise model is itself a GP (rather than fixed variances).
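For readers following along, here is a minimal sketch of how that BoTorch wrapper can be used. The toy data, shapes, and the `fit_gpytorch_model` call are illustrative assumptions, not code taken from this thread:

```python
# Sketch (not from the thread): fitting BoTorch's FixedNoiseGP on observed variances.
# Assumes the FixedNoiseGP(train_X, train_Y, train_Yvar) constructor; X is (n, d),
# Y and Yvar are (n, 1).
import torch
from botorch.models import FixedNoiseGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 1)
train_Y = torch.sin(6 * train_X)
train_Yvar = 0.01 * torch.ones_like(train_Y)  # observed noise variances (made up)

model = FixedNoiseGP(train_X, train_Y, train_Yvar)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)

model.eval()
with torch.no_grad():
    posterior = model.posterior(torch.linspace(0, 1, 50).unsqueeze(-1))
    print(posterior.variance.min())  # should be non-negative
```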
Hi @Balandat, thanks! I'll take a look at the wrapper model in BoTorch. I think I have modelled the noise correctly according to the most likely heteroscedastic GP (although of course there's always room for error!). The vector `r` is the noise included with the files above.

Best,
Miguel
Here's my full attempt at an implementation of the most likely heteroscedastic GP.

```python
import torch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.constraints import Positive


def train_a_GP(model, train_x, train_y, likelihood, training_iter):
    """
    Simple utility function to train a Gaussian process (GP) model with Adam (following the examples in the docs).
    :param model: GP model
    :param train_x: tensor with training features X
    :param train_y: tensor with training targets Y
    :param likelihood: likelihood function
    :param training_iter: number of iterations to train
    :return: trained GP model, trained likelihood
    """
    # train the GP model for training_iter iterations
    model.train()
    likelihood.train()
    # Use the Adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=0.1)
    # "Loss" for GPs - the marginal log likelihood
    mll = ExactMarginalLogLikelihood(likelihood, model)
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f lengthscale: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
        ))
        optimizer.step()
    model.eval()
    likelihood.eval()
    return model, likelihood


class ExactGPModel(gpytorch.models.ExactGP):
    """
    Exact Gaussian process model (following the examples in the docs).
    """
    def __init__(self, train_x, train_y, likelihood):
        """
        Initializer function. Specifies the mean and the covariance functions.
        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param likelihood: likelihood function
        :return: None
        """
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x):
        """
        Forward method to evaluate the GP.
        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class hetGPModel():
    """
    Most likely heteroscedastic GP model.
    """
    def __init__(self, train_x, train_y, training_iter=100, het_fitting_iter=10, var_estimator_n=50):
        """
        Initializer function.
        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param training_iter: number of iterations to train GP1, GP2 and GP3
        :param het_fitting_iter: number of iterations to run the pseudo expectation maximization (EM) algorithm while refining GP3
        :param var_estimator_n: number of samples to estimate the variance at each training point
        :return: None
        """
        self.train_x = train_x
        self.train_y = train_y
        self.training_iter = training_iter
        self.var_estimator_n = var_estimator_n
        self.het_fitting_iter = het_fitting_iter
        self.final_GP = None
        self.final_lik = None
        self.final_r = None

    def predict(self, x):
        """
        Predict method to evaluate the GP.
        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        if self.final_GP is None:
            raise RuntimeError('hetGPModel needs to be trained before using it')
        return self.final_GP(x)

    def train_model(self):
        """
        Train the most likely heteroscedastic GP, in which one GP predicts the mean and another GP predicts the variance. This function
        corresponds to section '4. Optimization' in the original most likely heteroscedastic GP paper (Kersting et al. 2007).
        :return: None
        """
        # GP1: homoscedastic GP fit to the training data
        lik_1 = GaussianLikelihood()
        GP1 = ExactGPModel(self.train_x, self.train_y, lik_1)
        GP1, lik_1 = train_a_GP(GP1, self.train_x, self.train_y, lik_1, self.training_iter)
        for i in range(self.het_fitting_iter):
            # estimate the noise levels z (log of the empirical variances)
            z = torch.log(self.get_r_hat(GP1, lik_1))
            # GP2: fit the noise z at train_x
            lik_2 = GaussianLikelihood()
            GP2 = ExactGPModel(self.train_x, z, lik_2)
            GP2, lik_2 = train_a_GP(GP2, self.train_x, z, lik_2, self.training_iter)
            # GP3: heteroscedastic GP with the noise predicted by GP2 fixed at each training point
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                r_pred = lik_2(GP2(self.train_x))
                r = torch.exp(r_pred.mean)
            lik_3 = FixedNoiseGaussianLikelihood(noise=r, learn_additional_noise=False)
            GP3 = ExactGPModel(self.train_x, self.train_y, lik_3)
            GP3, lik_3 = train_a_GP(GP3, self.train_x, self.train_y, lik_3, self.training_iter)
            GP1 = GP3
            lik_1 = lik_3
        self.final_GP = GP3
        self.final_lik = lik_3
        self.final_r = r

    def get_r_hat(self, GP, likelihood):
        """
        Estimate the variance at each training point.
        :param GP: GP model that predicts the mean in the heteroscedastic GP model.
        :return: tensor r_hat with the estimated variances
        """
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            train_pred = likelihood(GP(self.train_x))
            # empirical noise estimate at each training point, averaged over
            # var_estimator_n samples from the current predictive distribution
            r_hat = torch.sum(0.5 * (self.train_y.reshape(1, -1) - train_pred.sample_n(self.var_estimator_n)) ** 2, dim=0) / self.var_estimator_n
        return r_hat
```
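For completeness, the class above would be used roughly as follows. This is a sketch on made-up toy data (not the attached x/y/r files), with the iteration counts chosen arbitrarily:

```python
import torch

# Toy data whose noise grows with x, so the heteroscedastic model has something to learn.
train_x = torch.linspace(0, 6, 100)
train_y = torch.sin(train_x) + 0.1 * train_x * torch.randn(100)

model = hetGPModel(train_x, train_y, training_iter=50, het_fitting_iter=5)
model.train_model()

test_x = torch.linspace(0, 6, 200)
with torch.no_grad():
    pred = model.predict(test_x)  # MultivariateNormal from the final GP3
    print(pred.mean.shape, pred.variance.min())
```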
Hi @Balandat and @jacobrgardner, I have tried to rewrite my heteroscedastic GP using the wrapper `FixedNoiseGP` from BoTorch. My code for the heteroscedastic GP with `FixedNoiseGP` is below. Thanks a lot in advance.

```python
import torch
import gpytorch
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.constraints import Positive
from botorch.models.gp_regression import FixedNoiseGP


def train_a_GP(model, train_x, train_y, likelihood, training_iter):
    """
    Simple utility function to train a Gaussian process (GP) model with Adam (following the examples in the docs).
    :param model: GP model
    :param train_x: tensor with training features X
    :param train_y: tensor with training targets Y
    :param likelihood: likelihood function
    :param training_iter: number of iterations to train
    :return: trained GP model, trained likelihood
    """
    # train the GP model for training_iter iterations
    model.train()
    likelihood.train()
    # Use the Adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},  # Includes GaussianLikelihood parameters
    ], lr=0.1)
    # "Loss" for GPs - the marginal log likelihood
    mll = ExactMarginalLogLikelihood(likelihood, model)
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f lengthscale: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
        ))
        optimizer.step()
    model.eval()
    likelihood.eval()
    return model, likelihood


class ExactGPModel(gpytorch.models.ExactGP):
    """
    Exact Gaussian process model (following the examples in the docs).
    """
    def __init__(self, train_x, train_y, likelihood):
        """
        Initializer function. Specifies the mean and the covariance functions.
        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param likelihood: likelihood function
        :return: None
        """
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x):
        """
        Forward method to evaluate the GP.
        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


class hetGPModel():
    """
    Most likely heteroscedastic GP model.
    """
    def __init__(self, train_x, train_y, training_iter=100, het_fitting_iter=10, var_estimator_n=50):
        """
        Initializer function.
        :param train_x: tensor with training features X
        :param train_y: tensor with training targets Y
        :param training_iter: number of iterations to train GP1, GP2 and GP3
        :param het_fitting_iter: number of iterations to run the pseudo expectation maximization (EM) algorithm while refining GP3
        :param var_estimator_n: number of samples to estimate the variance at each training point
        :return: None
        """
        self.train_x = train_x
        self.train_y = train_y
        self.training_iter = training_iter
        self.var_estimator_n = var_estimator_n
        self.het_fitting_iter = het_fitting_iter
        self.final_GP = None
        self.final_lik = None
        self.final_r_func = None

    def predict(self, x):
        """
        Predict method to evaluate the GP.
        :param x: tensor with features X on which to evaluate the GP.
        :return: MultivariateNormal
        """
        if self.final_GP is None:
            raise RuntimeError('hetGPModel needs to be trained before using it')
        return self.final_GP(x)

    def train_model(self):
        """
        Train the most likely heteroscedastic GP, in which one GP predicts the mean and another GP predicts the variance. This function
        corresponds to section '4. Optimization' in the original most likely heteroscedastic GP paper (Kersting et al. 2007).
        :return: None
        """
        # GP1: homoscedastic GP fit to the training data
        lik_1 = GaussianLikelihood()
        GP1 = ExactGPModel(self.train_x, self.train_y, lik_1)
        GP1, lik_1 = train_a_GP(GP1, self.train_x, self.train_y, lik_1, self.training_iter)
        for i in range(self.het_fitting_iter):
            # estimate the noise levels z (log of the empirical variances)
            z = torch.log(self.get_r_hat(GP1, lik_1))
            # GP2: fit the noise z at train_x
            lik_2 = GaussianLikelihood()
            GP2 = ExactGPModel(self.train_x, z, lik_2)
            GP2, lik_2 = train_a_GP(GP2, self.train_x, z, lik_2, self.training_iter)
            # GP3: heteroscedastic GP via BoTorch's FixedNoiseGP, with the noise predicted by GP2
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                r_func_pred = lik_2(GP2(self.train_x))
                r_func = torch.exp(r_func_pred.mean)
            lik_3 = GaussianLikelihood()
            GP3 = FixedNoiseGP(self.train_x.reshape(-1, 1), self.train_y.reshape(-1, 1), r_func.reshape(-1, 1))
            GP3, lik_3 = train_a_GP(GP3, self.train_x, self.train_y, lik_3, self.training_iter)
            GP1 = GP3
            lik_1 = lik_3
        self.final_GP = GP3
        self.final_lik = lik_3
        self.final_r_func = r_func

    def get_r_hat(self, GP, likelihood):
        """
        Estimate the variance at each training point.
        :param GP: GP model that predicts the mean in the heteroscedastic GP model.
        :return: tensor r_hat with the estimated variances
        """
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            train_pred = likelihood(GP(self.train_x))
            # empirical noise estimate at each training point, averaged over
            # var_estimator_n samples from the current predictive distribution
            r_hat = torch.sum(0.5 * (self.train_y.reshape(1, -1) - train_pred.sample_n(self.var_estimator_n)) ** 2, dim=0) / self.var_estimator_n
        return r_hat
```
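One thing that may be worth double checking in the code above (an observation, not something established in this thread): `FixedNoiseGP` builds its own `FixedNoiseGaussianLikelihood` internally, so inside `train_model` the marginal log likelihood could be formed from the model's own likelihood rather than a fresh `GaussianLikelihood`. A sketch of that adjustment, not a verified fix:

```python
# Inside train_model (sketch): reuse the likelihood that FixedNoiseGP constructs
# internally, and pass the inputs/targets in the shapes the model stores.
GP3 = FixedNoiseGP(self.train_x.reshape(-1, 1), self.train_y.reshape(-1, 1), r_func.reshape(-1, 1))
lik_3 = GP3.likelihood  # the FixedNoiseGaussianLikelihood created by FixedNoiseGP
GP3, lik_3 = train_a_GP(GP3, self.train_x.reshape(-1, 1), GP3.train_targets, lik_3, self.training_iter)
```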
Hey there - I had this problem a few weeks ago and solved it by adding a noise constraint to the likelihood (and similarly for the fixed-noise case).
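The inline snippets in that comment did not survive the copy; the pattern being described is presumably along these lines (a sketch with an assumed lower bound of 1e-4, not the commenter's exact code):

```python
import torch
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood

r = 0.01 * torch.ones(100)  # placeholder for the observed noise variances

# Learned-noise case: constrain the likelihood noise to stay above a floor.
likelihood = GaussianLikelihood(noise_constraint=GreaterThan(1e-4))

# Fixed-noise case: one simple option is to floor the observed variances themselves.
likelihood_fixed = FixedNoiseGaussianLikelihood(noise=r.clamp_min(1e-4), learn_additional_noise=False)
```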
Thanks for the tip. Unfortunately I've tried adding a noise constraint and I still get negative variances.
@mgarort I need sample data that produces this issue. When I run your code with toy data:

```python
train_x = torch.linspace(0, 6, 100)
train_y = torch.sin(train_x) + 0.01 * torch.randn(100)
```

I get positive variances.
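For anyone trying to reproduce this, the kind of check being described looks roughly like the following sketch. It reuses the `ExactGPModel` and `train_a_GP` definitions posted above; the fixed noise vector here is an assumption, not the attached data:

```python
import torch
import gpytorch
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood

# Toy data with a made-up fixed noise vector r.
train_x = torch.linspace(0, 6, 100)
train_y = torch.sin(train_x) + 0.01 * torch.randn(100)
r = 0.01 * torch.ones(100)

likelihood = FixedNoiseGaussianLikelihood(noise=r, learn_additional_noise=False)
model = ExactGPModel(train_x, train_y, likelihood)
model, likelihood = train_a_GP(model, train_x, train_y, likelihood, training_iter=50)

test_x = torch.linspace(0, 6, 200)
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    pred = model(test_x)
    print(pred.variance.min())  # negative values here are the reported bug
```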
I attached sample data and a script that results in negative variance in a zip a few posts back, but I guess it has ended up hidden under the subsequent posts! Here's the zip again. You just need to run the script in the same folder as the data files.

Thanks,
Miguel
@mgarort Okay, I'm pretty sure I know what's going on here. It's actually pretty technical. Basically, for fast predictive variances we decompose the training covariance matrix with a low-rank root decomposition (the LOVE machinery behind `fast_pred_var`), and when that approximation is not accurate enough, the term subtracted from the prior variance can overshoot it, which is exactly what shows up as a negative predictive variance. To work around this, we could instead decompose the matrix in a different way.
cc/ @gpleiss on this one actually, since we've discussed alternative decompositions before.
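For reference, the quantity being approximated is the standard exact-GP predictive covariance (textbook formula, not a quote from the comment above), where $\operatorname{diag}(r)$ is the fixed noise:

$$
\Sigma_{*} \;=\; K_{**} \;-\; K_{*X}\,\bigl(K_{XX} + \operatorname{diag}(r)\bigr)^{-1} K_{X*}
$$

If the low-rank approximation used by `fast_pred_var` overestimates the subtracted term, the computed variance can come out negative even though the exact quantity is non-negative.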
So sorry for the very delayed reply, a long holiday and a paper submission got in the way. Turning off `fast_pred_var` solved the problem.

Thanks a lot again,
Miguel
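In code, that fix amounts to dropping the `fast_pred_var` context manager from the prediction step, roughly as follows (a sketch using the model and test points from the earlier snippets):

```python
# Prediction without LOVE fast predictive variances: simply omit the
# gpytorch.settings.fast_pred_var() context manager.
model.eval()
likelihood.eval()
with torch.no_grad():
    pred = model(test_x)
    print(pred.variance.min())  # exact predictive variances
```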
This unexpected behavior with using `fast_pred_var` affected me as well. I gather from this thread that the problem is challenging to fix, so I would encourage the maintainers to modify the code under the GPyTorch tutorials and examples to not use `fast_pred_var`. For instance, I went through the tutorial/example code for ExactGP regression, the Multi Task Kernel, and independent model lists, all 3 of which use `fast_pred_var`.
Do you mind sharing a reproducible example of this behavior? I wasn't immediately able to produce significant differences when using the model definitions above.
Would a colab notebook work? Also, interestingly, the negative variances only show up under `fast_pred_var`.
Another related issue: #1840
Hi,
I am trying to train a most likely heteroscedastic GP (from "Most likely heteroscedastic GP regression", Kersting et al. 2007). To this end I am using the likelihood `FixedNoiseGaussianLikelihood`, setting the noise to positive values `r`, where `train_a_GP` is simply the training function included in my comments above (copied from a GPyTorch regression tutorial). However, when I try to obtain predictions, the variance of the `MultivariateNormal` returned seems to be negative. What am I doing wrong? Any help would be greatly appreciated.
Thanks a lot!
Miguel