-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathpcc_model.py
executable file
·143 lines (116 loc) · 4.82 KB
/
pcc_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
from networks import load_config
from torch import nn
# Module-level side effect: from the moment this module is imported, newly
# created floating-point tensors default to float64. torch.set_default_dtype is
# process-global, so this also affects code outside this module.
torch.set_default_dtype(torch.float64)
# torch.manual_seed(0)  # uncomment to make the sampling in PCC.reparam reproducible
class PCC(nn.Module):
    """PCC latent model bundling encoder, decoder, forward and backward dynamics.

    The four sub-network constructors come from ``load_config(env)``. As used
    in ``forward``, ``encode``/``back_dynamics`` return distribution-like
    objects exposing ``.mean`` and ``.stddev``, and ``transition`` returns a
    tuple whose first element is such a distribution over the next latent.
    """

    def __init__(self, armotized, x_dim, z_dim, u_dim, env="planar"):
        """Build the sub-networks for configuration ``env``.

        Args:
            armotized: Forwarded unchanged to the dynamics constructor. The
                (misspelled) name is kept for caller compatibility; presumably
                it means "amortized" -- confirm in ``networks``.
            x_dim: Observation size passed to the encoder/decoder constructors.
            z_dim: Latent size passed to every sub-network constructor.
            u_dim: Control size passed to the dynamics constructors.
            env: Configuration name handed to ``load_config``.
        """
        super().__init__()
        enc, dec, dyn, back_dyn = load_config(env)
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.u_dim = u_dim
        self.armotized = armotized
        self.encoder = enc(x_dim, z_dim)
        self.decoder = dec(z_dim, x_dim)
        self.dynamics = dyn(armotized, z_dim, u_dim)
        self.backward_dynamics = back_dyn(z_dim, u_dim, x_dim)

    def encode(self, x):
        """Return the encoder output for observation ``x`` (P(z_t | x_t))."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder output for latent ``z``."""
        return self.decoder(z)

    def transition(self, z, u):
        """Return the forward-dynamics output for latent ``z`` and control ``u``.

        ``forward`` relies on the first element of the result being the
        distribution over the next latent state.
        """
        return self.dynamics(z, u)

    def back_dynamics(self, z, u, x):
        """Return Q(z_t | z^_t+1, u_t, x_t) from the backward-dynamics network."""
        return self.backward_dynamics(z, u, x)

    def reparam(self, mean, std):
        """Draw a reparameterised sample ``mean + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is a standard deviation (not a log-variance); gradients flow
        through both ``mean`` and ``std``.
        """
        epsilon = torch.randn_like(std)
        return mean + torch.mul(epsilon, std)

    def forward(self, x, u, x_next):
        """Compute every distribution/sample needed by the PCC training loss.

        Args:
            x: Current observation x_t.
            u: Control u_t.
            x_next: Next observation x_{t+1}.

        Returns:
            ``(p_x_next, q_z_backward, p_z, q_z_next, z_next, p_z_next, z_p,
            u, p_x, p_x_next_determ)``; presumably consumed by the loss
            function defined elsewhere.
        """
        # Prediction and consistency terms (1st and 3rd).
        q_z_next = self.encode(x_next)  # Q(z^_t+1 | x_t+1)
        z_next = self.reparam(q_z_next.mean, q_z_next.stddev)  # sample z^_t+1
        p_x_next = self.decode(z_next)  # P(x_t+1 | z^_t+1)

        # 2nd term.
        q_z_backward = self.back_dynamics(z_next, u, x)  # Q(z_t | z^_t+1, u_t, x_t)
        p_z = self.encode(x)  # P(z_t | x_t)

        # 4th term.
        z_q = self.reparam(q_z_backward.mean, q_z_backward.stddev)  # sample from Q(z_t | z^_t+1, u_t, x_t)
        p_z_next, _, _ = self.transition(z_q, u)  # P(z^_t+1 | z_t, u_t)

        # Additional VAE loss: encode x_t, sample, reconstruct x_t. The sample
        # must come from the encoder P(z_t | x_t) (p_z), not from p_z_next,
        # which is a distribution over z_{t+1} and would make p_x a
        # next-step prediction instead of a reconstruction.
        z_p = self.reparam(p_z.mean, p_z.stddev)  # sample from P(z_t | x_t)
        p_x = self.decode(z_p)  # for the additional VAE loss

        # Additional deterministic loss: push the encoder mean through the
        # transition and decode the mean next latent (no sampling).
        mu_z_next_determ = self.transition(p_z.mean, u)[0].mean
        p_x_next_determ = self.decode(mu_z_next_determ)
        return p_x_next, q_z_backward, p_z, q_z_next, z_next, p_z_next, z_p, u, p_x, p_x_next_determ

    def predict(self, x, u):
        """Reconstruct x_t and predict x_{t+1} by sampling through the model.

        Uses the same interfaces as ``forward``: the encoder output and the
        first element of ``transition``'s result expose ``.mean``/``.stddev``.

        Returns:
            ``(x_recon, x_next_pred)``: decoder outputs for a sample of z_t and
            for a sample of the predicted z_{t+1}.
        """
        q_z = self.encode(x)
        z = self.reparam(q_z.mean, q_z.stddev)
        x_recon = self.decode(z)
        # Index [0] exactly like forward's deterministic branch does.
        p_z_next = self.transition(z, u)[0]
        z_next = self.reparam(p_z_next.mean, p_z_next.stddev)
        x_next_pred = self.decode(z_next)
        return x_recon, x_next_pred
# def reparam(mean, logvar):
# sigma = (logvar / 2).exp()
# epsilon = torch.randn_like(sigma)
# return mean + torch.mul(epsilon, sigma)
# def jacobian_1(dynamics, z, u):
# """
# compute the jacobian of F(z,u) w.r.t z, u
# """
# z_dim, u_dim = z.size(1), u.size(1)
# z, u = z.squeeze().repeat(z_dim, 1), u.squeeze().repeat(z_dim, 1)
# z = z.detach().requires_grad_(True)
# u = u.detach().requires_grad_(True)
# z_next, _, _, _ = dynamics(z, u)
# grad_inp = torch.eye(z_dim)
# A = torch.autograd.grad(z_next, z, grad_inp, retain_graph=True)[0]
# B = torch.autograd.grad(z_next, u, grad_inp, retain_graph=True)[0]
# return A, B
# def jacobian_2(dynamics, z, u):
# """
# compute the jacobian of F(z,u) w.r.t z, u
# """
# z_dim, u_dim = z.size(1), u.size(1)
# z = z.detach().requires_grad_(True)
# u = u.detach().requires_grad_(True)
# z_next, _, _, _ = dynamics(z, u)
# A = torch.empty(size=(z_dim, z_dim))
# B = torch.empty(size=(z_dim, u_dim))
# for i in range(A.size(0)): # for each row
# grad_inp = torch.zeros(size=(1, A.size(0)))
# grad_inp[0][i] = 1
# A[i] = torch.autograd.grad(z_next, z, grad_inp, retain_graph=True)[0]
# for i in range(B.size(0)): # for each row
# grad_inp = torch.zeros(size=(1, B.size(0)))
# grad_inp[0][i] = 1
# B[i] = torch.autograd.grad(z_next, u, grad_inp, retain_graph=True)[0]
# return A, B
# enc, dec, dyn, back_dyn = load_config('planar')
# dynamics = dyn(armotized=False, z_dim=2, u_dim=2)
# dynamics.eval()
# import torch.optim as optim
# optimizer = optim.Adam(dynamics.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=0.001)
# z = torch.randn(size=(1, 2))
# z.requires_grad = True
# u = torch.randn(size=(1, 2))
# u.requires_grad = True
# eps_z = torch.normal(0.0, 0.1, size=z.size())
# eps_u = torch.normal(0.0, 0.1, size=u.size())
# mean, logvar, _, _ = dynamics(z, u)
# grad_z = torch.autograd.grad(mean, z, grad_outputs=eps_z, retain_graph=True, create_graph=True)
# grad_u = torch.autograd.grad(mean, u, grad_outputs=eps_u, retain_graph=True, create_graph=True)
# print ('AAA')
# print (grad_z, grad_u)
# A, B = jacobian_1(dynamics, z, u)
# grad_z, grad_u = eps_z.mm(A), eps_u.mm(B)
# print ('BBBB')
# print (grad_z, grad_u)
# A, B = jacobian_1(dynamics, z, u)
# grad_z, grad_u = eps_z.mm(A), eps_u.mm(B)
# print ('BBBB')
# print (grad_z, grad_u)