-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
46 lines (35 loc) · 1.73 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gpytorch
import numpy as np
from rich.console import Console
import torch
from gpytorch.models import ExactGP
# Extend the gpytorch.models.ExactGP class.
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact Gaussian-process model with a constant mean and a scaled kernel.

    The covariance module is a ``ScaleKernel`` around a base kernel
    (``RQKernel`` by default). ``forward`` adds a small diagonal jitter to
    the prior covariance for numerical stability.
    """

    # Diagonal jitter added to the prior covariance in ``forward``.
    jitter = 1e-3

    def __init__(self, train_x, train_y, likelihood, kernel=None):
        """Build the model.

        Args:
            train_x: Training inputs, passed through to ``ExactGP``.
            train_y: Training targets, passed through to ``ExactGP``.
            likelihood: GPyTorch likelihood (``load_from_file`` uses a
                ``GaussianLikelihood``).
            kernel: Base covariance kernel. ``None`` (the default) creates a
                fresh ``RQKernel``. A kernel that is not already a
                ``ScaleKernel`` is wrapped in one, so the default case keeps
                the ``ScaleKernel(RQKernel())`` module layout and its
                state-dict keys.
        """
        super().__init__(train_x, train_y, likelihood)
        if kernel is None:
            # Build a new kernel per instance; a kernel created as a default
            # argument would be one module shared by every model.
            kernel = gpytorch.kernels.RQKernel()
        if not isinstance(kernel, gpytorch.kernels.ScaleKernel):
            kernel = gpytorch.kernels.ScaleKernel(kernel)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = kernel

    def forward(self, x):
        """Return the GP prior ``MultivariateNormal`` evaluated at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        # Add some jitter to the covariance matrix for numerical stability.
        covar_x = covar_x.add_jitter(self.jitter)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    @classmethod
    def load_from_file(cls, file_path, device=None):
        """Rebuild a model from a saved checkpoint.

        The checkpoint must be a mapping with ``"train_x"``, ``"train_y"`` and
        ``"model"`` (the model's state dict) entries; the kernel is rebuilt
        as ``ScaleKernel(RQKernel())`` with a ``GaussianLikelihood``.

        Args:
            file_path: Anything ``torch.load`` accepts (path or file object).
            device: Device for the training data and model parameters.
                Defaults to ``"cuda"`` when available, otherwise ``"cpu"``.

        Returns:
            An ``ExactGPModel`` with restored parameters, placed on ``device``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 2.6 (weights_only=True by default) this may
        # reject checkpoints holding e.g. numpy arrays — confirm the format.
        # map_location lets CUDA-saved checkpoints load on CPU-only hosts.
        checkpoint = torch.load(file_path, map_location=device)
        kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RQKernel())
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        # as_tensor accepts tensors, arrays and lists, and avoids the warning
        # torch.tensor emits when copy-constructing from an existing tensor.
        train_x = torch.as_tensor(
            checkpoint["train_x"], dtype=torch.float32, device=device
        )
        train_y = torch.as_tensor(
            checkpoint["train_y"], dtype=torch.float32, device=device
        )
        model = cls(train_x, train_y, likelihood, kernel)
        model.load_state_dict(checkpoint["model"])
        # Parameters are created on the CPU; move them next to the training
        # data so forward passes do not mix devices.
        return model.to(device)