-
Notifications
You must be signed in to change notification settings - Fork 0
/
ANPRNN_test.py
76 lines (62 loc) · 3.01 KB
/
ANPRNN_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from data.GP_data_sampler import GPCurvesReader
from module.ANP_RNN import RecurrentAttentiveNeuralProcess as ANPRNN
from module.utils import compute_loss, comput_kl_loss, to_numpy, load_plot_data
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import os
def testing(data_test, model, test_batch=64, device=None):
    """Average the negated ``compute_loss`` of ``model`` over freshly sampled batches.

    Each iteration draws a new batch from ``data_test`` (targets exclude the
    context points and are sorted), runs the model on it and accumulates
    ``-compute_loss(mean, var, y_target)``. The accumulator is named as a
    log-likelihood, so ``compute_loss`` is presumably a negative log-likelihood.
    TODO confirm against module.utils.

    Args:
        data_test: sampler exposing ``generate_curves(include_context=..., sort=...)``.
            Its result must provide ``.query`` as ``((x_context, y_context), x_target)``
            and ``.y_target``.
        model: network called as ``model(x_context, y_context, x_target)`` and
            returning ``((mean, var), _, _)``.
        test_batch: number of batches to sample and average over.
        device: device the inputs are moved to. ``None`` means the device of the
            model's first parameter (CPU if the model has no parameters). This
            removes the old dependency on a script-level global ``device``.

    Returns:
        The mean of ``-loss.item()`` over the ``test_batch`` batches.

    Raises:
        ValueError: if ``test_batch`` is not positive (previously this surfaced
            as an UnboundLocalError).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    if test_batch <= 0:
        raise ValueError("test_batch must be a positive integer, got %r" % (test_batch,))
    total_ll = 0.0
    # Inference only: do not build an autograd graph for each evaluation batch.
    with torch.no_grad():
        for _ in tqdm(range(test_batch)):
            data = data_test.generate_curves(include_context=False, sort=True)
            (x_context, y_context), x_target = data.query
            (mean, var), _, _ = model(x_context.to(device), y_context.to(device), x_target.to(device))
            loss = compute_loss(mean, var, data.y_target.to(device))
            total_ll += -loss.item()
    return total_ll / test_batch
def plot_sample(dataset, model, device=None, model_name=None):
    """Plot one sampled curve with the model's prediction on a dense grid.

    One batch is drawn from ``dataset`` and only its first curve is drawn:
    - context points as red dots;
    - target points as black crosses;
    - the predicted mean on the grid ``x in [-2, 2)`` (step 0.01) as a blue line;
    - a shaded band of ``mean +/- 1.96 * var``.

    Args:
        dataset: sampler exposing ``generate_curves(include_context=..., sort=...)``
            whose result provides ``.query`` and ``.y_target``.
        model: network called as ``model(x_context, y_context, x_target)`` and
            returning ``((mean, var), _, _)``.
        device: device for the model inputs. ``None`` means the device of the
            model's first parameter (CPU if the model has no parameters).
        model_name: prefix of the mean-curve legend label. ``None`` means the
            module-level ``MODELNAME`` if the script defined it, otherwise the
            model's class name.

    Returns:
        The matplotlib Figure holding the plot.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    if model_name is None:
        # Keep the original label when run as a script (MODELNAME is set under __main__).
        model_name = globals().get('MODELNAME', type(model).__name__)

    # plt.subplots returns (figure, axes). The old code unpacked them in reverse
    # and so returned the Axes instead of the Figure.
    fig, ax = plt.subplots()
    model.eval()

    data = dataset.generate_curves(include_context=False, sort=True)
    (x_context, y_context), x_target = data.query
    # Dense evaluation grid of shape (batch, 400, 1). This assumes x inputs are
    # (batch, points, 1), as the [None, :, None] indexing implies. TODO confirm.
    x_grid = torch.arange(-2, 2, 0.01)[None, :, None].repeat([x_context.shape[0], 1, 1]).to(device)
    with torch.no_grad():
        (mean, var), _, _ = model(x_context.to(device), y_context.to(device), x_grid)

    ax.scatter(to_numpy(x_context[0]), to_numpy(y_context[0]), label='context points', c='red', s=15)
    ax.scatter(to_numpy(x_target[0]), to_numpy(data.y_target[0]), label='target points', marker='x', color='k')
    ax.plot(to_numpy(x_grid[0]), to_numpy(mean[0]), label=model_name + ' predicted mean', c='blue')
    # mean +/- 1.96 * sigma is the two-sided ~95% interval of a Gaussian.
    # NOTE(review): `var` is used directly as a standard deviation here. If the
    # model actually returns a variance, this should be `var.sqrt()`. Confirm
    # against ANP_RNN.
    half_width = 1.96 * var[0, :, 0]
    ax.fill_between(
        to_numpy(x_grid[0, :, 0]),
        to_numpy(mean[0, :, 0] - half_width),
        to_numpy(mean[0, :, 0] + half_width),
        color='blue',
        alpha=0.15,
    )
    ax.legend()
    plt.show()
    return fig
if __name__ == '__main__':
    # Script configuration. `device` and `MODELNAME` are module-level globals
    # that the helper functions in this file may read, so define them here.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    TESTING_ITERATIONS = 1024
    MAX_CONTEXT_POINT = 50
    MODELNAME = 'ANP_RNN'  # also selects the checkpoint file below
    kernel = 'EQ'  # 'EQ' or 'period'
    # Set True to also report the test log-likelihood (mean/std over 10 runs).
    RUN_LOGLIK_EVAL = False

    # GP curve sampler used for both evaluation and plotting.
    dataset = GPCurvesReader(kernel=kernel, batch_size=64, max_num_context=MAX_CONTEXT_POINT, device=device)

    # Build the model and restore trained weights.
    anprnn = ANPRNN(input_dim=1, latent_dim=32, output_dim=1, use_attention=True, use_rnn=True).to(device)
    checkpoint_path = os.path.join('saved_model', kernel + '_' + MODELNAME + '.pt')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    anprnn.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("successfully load %s module!" % MODELNAME)

    if RUN_LOGLIK_EVAL:
        run_scores = [testing(dataset, anprnn, TESTING_ITERATIONS) for _ in range(10)]
        print("for 10 runs, mean: %.4f, std:%.4f" % (np.mean(run_scores), np.std(run_scores)))

    fig = plot_sample(dataset, anprnn)
    # NOTE(review): nothing is written to disk here; the message is kept unchanged
    # for output compatibility.
    print("save plots!")