-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrbmapplygrads.py
76 lines (63 loc) · 2.51 KB
/
rbmapplygrads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
# RBMAPPLYGRADS applies momentum and learningrate and updates rbm weights
# Internal function used by rbmtrain
#
# INPUT
# rbm : rbm struct
# grads : dict of gradient arrays (entries listed below)
# grads.dw : w weights change normalized by minibatch size
# grads.db : bias of visible layer weight change norm by minibatch size
# grads.dc : bias of hidden layer weight change norm by minibatch size
# grads.du : class label layer weight change norm by minibatch size
# grads.dd : class label hidden bias weight change norm by minibatch size
# x : current minibatch
# ey : if classRBM one hot encoded class labels otherwise empty
# epoch : current epoch number
#
# OUTPUT
# rbm : rbm struct with updated weights, LR and momentum
def rbmapplygrads(rbm, grads, x, ey, epoch):
    """Apply regularization, learning rate and momentum to the RBM parameters.

    The gradients in ``grads`` are first penalized according to the L2, L1
    and sparsity settings stored on ``rbm``. They are then folded into the
    momentum terms (``vW``, ``vb``, ``vc`` and, for a classification RBM,
    ``vU``, ``vd``), and the momentum terms are added to the parameters.

    Args:
        rbm: Object exposing the attributes ``W``, ``b``, ``c``, ``vW``,
            ``vb``, ``vc``, ``L1``, ``L2``, ``sparsity``, ``classRBM``,
            ``curMomentum`` and ``curLR``; additionally ``U``, ``d``, ``vU``
            and ``vd`` when ``classRBM == 1``. It is modified in place.
        grads: Mapping with the keys ``'dw'``, ``'db'``, ``'dc'``, ``'dd'``
            and ``'du'`` holding NumPy arrays.
        x: Not used by this function; kept for interface compatibility.
        ey: Not used by this function; kept for interface compatibility.
        epoch: Not used by this function; kept for interface compatibility.

    Returns:
        The same ``rbm`` object, with updated parameters and momentum terms.
    """
    dw = grads['dw']
    db = grads['db']
    dc = grads['dc']
    dd = grads['dd']
    du = grads['du']

    is_class_rbm = rbm.classRBM == 1
    momentum = rbm.curMomentum
    lr = rbm.curLR

    # L2 regularization (weight decay proportional to the weights).
    if rbm.L2 > 0:
        dw = dw - rbm.L2 * rbm.W
        if is_class_rbm:
            du = du - rbm.L2 * rbm.U

    # L1 regularization: the penalty must be scaled by rbm.L1. The previous
    # code used rbm.L2 here, so an L1-only configuration had no effect.
    if rbm.L1 > 0:
        dw = dw - rbm.L1 * np.sign(rbm.W)
        if is_class_rbm:
            du = du - rbm.L1 * np.sign(rbm.U)

    # Sparsity penalty is subtracted from the bias gradients only.
    if rbm.sparsity > 0:
        db = db - rbm.sparsity
        dc = dc - rbm.sparsity
        if is_class_rbm:
            dd = dd - rbm.sparsity

    # Update the momentum terms of the regular weights and biases.
    rbm.vW = momentum * rbm.vW + lr * dw

    # The bias momentum terms are combined with column-shaped gradients:
    # a gradient whose first dimension is 1 is transposed, anything else is
    # reshaped to (n, 1).
    if dc.shape[0] == 1:
        rbm.vc = momentum * rbm.vc + lr * np.transpose(dc)
    else:
        rbm.vc = momentum * rbm.vc + np.reshape(lr * dc, (dc.shape[0], 1))

    # NOTE(review): vb is left untouched (no momentum decay) when db is empty
    # or all zeros. This is kept from the original - confirm it is intended.
    if db.any():
        rbm.vb = momentum * rbm.vb + np.reshape(lr * db, (db.shape[0], 1))

    rbm.W = rbm.W + rbm.vW
    rbm.b = rbm.b + rbm.vb
    rbm.c = rbm.c + rbm.vc

    # For a classification RBM also update the label weights U and bias d.
    if is_class_rbm:
        rbm.vU = momentum * rbm.vU + lr * du
        # reshape (not transpose) so that dd becomes an (n, 1) column.
        rbm.vd = momentum * rbm.vd + lr * np.reshape(dd, (dd.shape[0], 1))
        rbm.U = rbm.U + rbm.vU
        rbm.d = rbm.d + rbm.vd

    # TODO: implement the L2 norm constraint on rbm.W.
    return rbm