""" Architect controls architecture of cell by computing gradients of alphas """
import copy
import torch


class Architect():
    """ Compute gradients of alphas """
    def __init__(self, net, w_momentum, w_weight_decay):
        """
        Args:
            net
            w_momentum: weights momentum
            w_weight_decay: weights weight decay
        """
        self.net = net
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay

    def virtual_step(self, trn_X, trn_y, xi, w_optim):
        """
        Compute unrolled weight w' (virtual step)

        Step process:
            1) forward
            2) calc loss
            3) compute gradient (by backprop)
            4) update gradient

        Args:
            xi: learning rate for virtual gradient step (same as weights lr)
            w_optim: weights optimizer
        """
        # make virtual model
        v_net = copy.deepcopy(self.net)

        # forward & calc loss
        loss = v_net.loss(trn_X, trn_y)  # L_trn(w)

        # compute gradient
        gradients = torch.autograd.grad(loss, v_net.weights())

        # do virtual step (update gradient)
        # below operations do not need gradient tracking
        with torch.no_grad():
            # the optimizer state is keyed by the original parameter objects, not by
            # value, so the original network's weights have to be iterated alongside
            # the copies
            for rw, w, g in zip(self.net.weights(), v_net.weights(), gradients):
                m = w_optim.state[rw].get('momentum_buffer', 0.) * self.w_momentum
                w -= xi * (m + g + self.w_weight_decay * w)

        return v_net

    def unrolled_backward(self, trn_X, trn_y, val_X, val_y, xi, w_optim):
        """ Compute unrolled loss and backward its gradients
        Args:
            xi: learning rate for virtual gradient step (same as net lr)
            w_optim: weights optimizer - for virtual step
        """
        # do virtual step (calc w`)
        unrolled_net = self.virtual_step(trn_X, trn_y, xi, w_optim)

        # calc unrolled loss
        loss = unrolled_net.loss(val_X, val_y)  # L_val(w`)

        # compute gradient
        loss.backward()
        dalpha = [v.grad for v in unrolled_net.alphas()]  # dalpha { L_val(w`, alpha) }
        dw = [v.grad for v in unrolled_net.weights()]     # dw` { L_val(w`, alpha) }

        hessian = self.compute_hessian(dw, trn_X, trn_y)

        # update final gradient = dalpha - xi*hessian
        with torch.no_grad():
            for alpha, da, h in zip(self.net.alphas(), dalpha, hessian):
                alpha.grad = da - xi * h

    def compute_hessian(self, dw, trn_X, trn_y):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm

        # w+ = w + eps*dw`
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d
        loss = self.net.loss(trn_X, trn_y)
        dalpha_pos = torch.autograd.grad(loss, self.net.alphas())  # dalpha { L_trn(w+) }

        # w- = w - eps*dw`
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p -= 2. * eps * d
        loss = self.net.loss(trn_X, trn_y)
        dalpha_neg = torch.autograd.grad(loss, self.net.alphas())  # dalpha { L_trn(w-) }

        # recover w
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
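

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; this file comes from a fork of
# khanrc/pt.darts). It shows how the Architect above is typically driven in a
# DARTS search loop, alternating an architecture (alpha) step on validation
# data with a weight step on training data. The names `model`, `alpha_optim`,
# `w_optim`, `lr`, `train_loader` and `valid_loader` are assumptions, not part
# of this module; `model` is expected to expose the loss()/weights()/alphas()
# methods that Architect relies on.
#
#   architect = Architect(model, w_momentum, w_weight_decay)
#   for (trn_X, trn_y), (val_X, val_y) in zip(train_loader, valid_loader):
#       # alpha step: unrolled gradient of the validation loss
#       alpha_optim.zero_grad()
#       architect.unrolled_backward(trn_X, trn_y, val_X, val_y, xi=lr, w_optim=w_optim)
#       alpha_optim.step()
#
#       # w step: ordinary update of the real network on the training loss
#       w_optim.zero_grad()
#       loss = model.loss(trn_X, trn_y)
#       loss.backward()
#       w_optim.step()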