-
Notifications
You must be signed in to change notification settings - Fork 9
/
policy.py
82 lines (63 loc) · 2.53 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import sde
import util
from ipdb import set_trace as debug
def build(opt, dyn, direction):
    """Create the policy for ``direction`` and move it to ``opt.device``.

    The backbone name is looked up on ``opt`` under ``<direction>_net`` and
    decides how the policy wrapper is configured:

    * 'toy', 'Unet' and 'DGLSB' are wrapped with ``use_t_idx=True``;
    * 'ncsnpp' is wrapped with ``scale_by_g=True``.

    Args:
        opt: options object; ``<direction>_net`` and ``device`` are read here,
            the rest is forwarded to ``_build_net`` and the policy.
        dyn: dynamics object handed to the policy unchanged.
        direction: string prefix used to select the ``<direction>_net`` option.

    Returns:
        The constructed ``SchrodingerBridgePolicy``.
    """
    print(util.magenta("build {} policy...".format(direction)))

    name = getattr(opt, direction+'_net')
    backbone = _build_net(opt, name)

    # t_idx is handled internally in ncsnpp, so only the other nets need it.
    wants_t_idx = name in ('toy', 'Unet', 'DGLSB')
    wants_g_scaling = name == 'ncsnpp'

    policy = SchrodingerBridgePolicy(
        opt,
        direction,
        dyn,
        backbone,
        use_t_idx=wants_t_idx,
        scale_by_g=wants_g_scaling,
    )

    n_params = util.count_parameters(policy)
    print(util.red('number of parameters is {}'.format(n_params)))

    policy.to(opt.device)
    return policy
def _build_net(opt, net_name):
    """Construct the backbone network selected by ``net_name``.

    Each model module is imported inside its own branch, so only the
    requested architecture's module is loaded.

    Args:
        opt: options object. ``DSM_warmup`` is always read (it is passed to
            every builder as ``zero_out_last_layer``); ``model_configs`` is
            read for 'Unet' and 'ncsnpp'; ``sigma_min`` / ``sigma_max`` are
            read only when the sigma callback given to ncsnpp is invoked.
        net_name: one of 'toy', 'Unet', 'ncsnpp', 'DGLSB'.

    Returns:
        Whatever the selected model builder returns.

    Raises:
        RuntimeError: if ``net_name`` is not one of the supported names.
        AssertionError: if 'toy' is requested while
            ``util.is_toy_dataset(opt)`` is falsy.
    """
    def compute_sigma(t):
        # Evaluated lazily: opt.sigma_min / opt.sigma_max are read per call.
        return sde.compute_sigmas(t, opt.sigma_min, opt.sigma_max)

    zero_out_last_layer = opt.DSM_warmup

    if net_name == 'toy':
        assert util.is_toy_dataset(opt), "'toy' net requires a toy dataset"
        from models.toy_model.Toy import build_toy
        net = build_toy(zero_out_last_layer)
    elif net_name == 'Unet':
        from models.Unet.Unet import build_unet
        net = build_unet(opt.model_configs[net_name], zero_out_last_layer)
    elif net_name == 'ncsnpp':
        from models.ncsnpp.ncsnpp import build_ncsnpp
        net = build_ncsnpp(opt.model_configs[net_name], compute_sigma, zero_out_last_layer)
    elif net_name == 'DGLSB':
        from models.DGLSB.dglsb import build_dglsb
        net = build_dglsb(zero_out_last_layer)
    else:
        # Name the offending value so a config typo is diagnosable.
        raise RuntimeError(
            "unknown network name {!r}; expected one of "
            "'toy', 'Unet', 'ncsnpp', 'DGLSB'".format(net_name)
        )
    return net
class SchrodingerBridgePolicy(torch.nn.Module):
    """Time-conditioned policy that wraps a backbone network.

    ``forward`` normalises the time argument to one entry per batch element,
    optionally rescales it to a time index before calling the network, and
    optionally multiplies the network output by ``dyn.g`` of the time.

    Args:
        opt: options object; ``T`` and ``interval`` are read when
            ``use_t_idx`` is set.
        direction: stored as-is on the instance.
        dyn: dynamics object; ``dyn.g(t)`` is called when ``scale_by_g`` is set.
        net: backbone called as ``net(x, t)``; must expose
            ``zero_out_last_layer``.
        use_t_idx: if true, the network receives ``t / opt.T * opt.interval``
            instead of ``t``.
        scale_by_g: if true, the network output is multiplied by ``dyn.g(t)``.
    """

    # note: scale_by_g matters only for pre-trained model
    def __init__(self, opt, direction, dyn, net, use_t_idx=False, scale_by_g=True):
        super().__init__()
        self.opt = opt
        self.direction = direction
        self.dyn = dyn
        self.net = net
        self.use_t_idx = use_t_idx
        self.scale_by_g = scale_by_g

    @property
    def zero_out_last_layer(self):
        """Forward the wrapped network's ``zero_out_last_layer`` attribute."""
        return self.net.zero_out_last_layer

    def forward(self, x, t):
        """Evaluate the policy at state ``x`` and time ``t``.

        Args:
            x: input tensor whose first dimension is the batch.
            t: time tensor; after ``squeeze()`` it must be either a scalar
                (repeated for every batch element) or have one entry per
                batch element.

        Returns:
            The network output, multiplied by ``dyn.g(t)`` (reshaped to
            broadcast over the non-batch dimensions of ``x``) when
            ``scale_by_g`` is set.
        """
        # make sure t.shape = [batch]
        t = t.squeeze()
        if t.dim() == 0:
            t = t.repeat(x.shape[0])
        assert t.dim() == 1 and t.shape[0] == x.shape[0]

        # Only the network sees the rescaled time index. The original ``t``
        # is kept so that ``dyn.g`` below is always evaluated at the time
        # itself, whether or not ``use_t_idx`` is set.
        net_t = t / self.opt.T * self.opt.interval if self.use_t_idx else t

        out = self.net(x, net_t)

        # if the SB policy behaves as "Z" in FBSDE system,
        # the output should be scaled by the diffusion coefficient "g".
        if self.scale_by_g:
            g = self.dyn.g(t)
            # one leading batch dim, then singleton dims to broadcast over x
            g = g.reshape(x.shape[0], *([1] * (x.dim() - 1)))
            out = out * g

        return out