-
Notifications
You must be signed in to change notification settings - Fork 2
/
example_autograd.py
75 lines (63 loc) · 2.32 KB
/
example_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import matplotlib.pyplot as plt
import numpy as np
import wbml.out
import wbml.plot
import stheno.autograd
import lab.autograd as B
from varz.autograd import Vars, minimise_l_bfgs_b
from olmm import model, objective, predict, project
from data import load
# Load the data, which are Pandas data frames.
locs, data = load()

# Convert to NumPy. `x_data` is the data-frame index as a column vector and
# `y_data` has one column per output (see `p` below); `locs` presumably holds
# one location per output -- confirm against `data.load`.
locs = locs.to_numpy()
x_data = data.index.to_numpy()[:, None]
y_data = data.to_numpy()

# Inputs for two-months ahead predictions: a column vector running from 1 up
# to (but excluding) 60 past the largest observed input.
x_pred = np.arange(1, x_data.max() + 60, dtype=np.float64)[:, None]

# Normalise every output to zero mean and unit standard deviation. The
# statistics keep shape (1, p) so they broadcast over the rows of `y_data`.
data_mean = np.mean(y_data, axis=0, keepdims=True)
data_scale = np.std(y_data, axis=0, keepdims=True)
y_data_norm = (y_data - data_mean) / data_scale

# Model parameters:
n = data.shape[0]  # Number of data points
p = data.shape[1]  # Number of outputs
m = 10  # Number of latent processes

# Learn: optimise the variables in `vs` by minimising `objective` on the
# normalised data with L-BFGS-B for 200 iterations.
vs = Vars(np.float64)
minimise_l_bfgs_b(lambda vs_: objective(vs_, m, x_data, y_data_norm, locs),
                  vs=vs,
                  trace=True,
                  iters=200)
wbml.out.kv('Learned spatial scales', vs['scales'])

# Predict at `x_pred`. The model was fit on `y_data_norm`, so the observation
# predictions come back in normalised units.
lat_preds, obs_preds = predict(vs, m, x_data, y_data_norm, locs, x_pred)

# Undo normalisation: the inverse of `(y - mean) / scale` is
# `x * scale + mean`, applied to every array in output i's prediction tuple.
obs_preds = [tuple(x * data_scale[0, i] + data_mean[0, i] for x in tup)
             for i, tup in enumerate(obs_preds)]
# Plot the first four latent processes. Each panel shows the projected data
# (blue) against the latent predictive mean (solid green) and its lower/upper
# curves (dashed green). The title reports S[k] as a percentage of sum(S)
# together with the display string of the k-th entry returned by `model`.
plt.figure(figsize=(15, 5))
y_proj, _, S, _ = B.to_numpy(project(vs, m, y_data_norm, locs))
xs, _, _ = model(vs, m)
S_total = np.sum(S)
for k in range(4):
    plt.subplot(2, 2, k + 1)
    pred_mean, pred_lower, pred_upper = lat_preds[k]
    share = 100 * S[k] / S_total
    plt.title(f'Latent Process {k + 1} (${share:.1f}\\%$) \n'
              f'{xs[k].display(wbml.out.format)}')
    plt.plot(x_data, y_proj[k], c='tab:blue')
    plt.plot(x_pred, pred_mean, c='tab:green')
    # Lower bound first, then upper, both dashed.
    for bound in (pred_lower, pred_upper):
        plt.plot(x_pred, bound, c='tab:green', ls='--')
    wbml.plot.tweak(legend=False)
# Plot four randomly chosen outputs, in ascending column order. Each panel
# shows the observed data (blue) against the output's predictive mean (solid
# green) and its lower/upper curves (dashed green); then display all figures.
plt.figure(figsize=(10, 5))
chosen_outputs = sorted(np.random.permutation(p)[:4])
for panel, j in enumerate(chosen_outputs, start=1):
    plt.subplot(2, 2, panel)
    pred_mean, pred_lower, pred_upper = obs_preds[j]
    plt.title(data.columns[j])
    plt.plot(x_data, y_data[:, j], c='tab:blue')
    plt.plot(x_pred, pred_mean, c='tab:green')
    # Lower bound first, then upper, both dashed.
    for bound in (pred_lower, pred_upper):
        plt.plot(x_pred, bound, c='tab:green', ls='--')
    wbml.plot.tweak(legend=False)
plt.show()