-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrnn.gru.go
90 lines (71 loc) · 2.13 KB
/
rnn.gru.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package gortex
import "fmt"
import "github.com/vseledkin/gortex/assembler"
// GRU holds the trainable matrices of a gated recurrent unit together with
// a linear read-out applied to the hidden state.
//
// The fields are grouped by the part of Step that consumes them: for each
// group, W* is multiplied with the input x, U* with the previous hidden
// state, and B* is added as a bias.
type GRU struct {
	// Parameters of the gate zt, which blends the previous hidden state
	// with the candidate state.
	Wz, Uz, Bz *Matrix
	// Parameters of the gate rt, which scales the previous hidden state
	// before it enters the candidate state.
	Wr, Ur, Br *Matrix
	// Parameters of the candidate state ht.
	Wh, Uh, Bh *Matrix
	// Who maps the squashed hidden state to the output y.
	Who *Matrix
}
// ForgetGateTrick passes v and the weights of the zt-gate bias Bz to
// assembler.Sset (presumably setting every bias element to v; confirm
// against the assembler package). It does nothing when Bz is nil.
func (rnn *GRU) ForgetGateTrick(v float32) {
	if rnn.Bz == nil {
		return
	}
	assembler.Sset(v, rnn.Bz.W)
}
// MakeGRU allocates a GRU whose ten parameter matrices are all created by
// RandXavierMat.
//
// Input-side matrices (W*) are requested as (h_size, x_size), recurrent
// matrices (U*) as (h_size, h_size), biases (B*) as (h_size, 1) and the
// read-out matrix Who as (out_size, h_size).
func MakeGRU(x_size, h_size, out_size int) *GRU {
	// The initialisers run top to bottom, so the matrices are created in
	// this exact order.
	return &GRU{
		Wz: RandXavierMat(h_size, x_size),
		Uz: RandXavierMat(h_size, h_size),
		Bz: RandXavierMat(h_size, 1),

		Wr: RandXavierMat(h_size, x_size),
		Ur: RandXavierMat(h_size, h_size),
		Br: RandXavierMat(h_size, 1),

		Wh: RandXavierMat(h_size, x_size),
		Uh: RandXavierMat(h_size, h_size),
		Bh: RandXavierMat(h_size, 1),

		Who: RandXavierMat(out_size, h_size),
	}
}
// GetParameters returns a freshly built map of all ten parameter matrices.
// Each key is namespace followed by an underscore and the field name, e.g.
// namespace + "_Wz". The map values are the GRU's own matrix pointers, not
// copies.
func (rnn *GRU) GetParameters(namespace string) map[string]*Matrix {
	params := make(map[string]*Matrix, 10)

	params[namespace+"_Wz"] = rnn.Wz
	params[namespace+"_Uz"] = rnn.Uz
	params[namespace+"_Bz"] = rnn.Bz

	params[namespace+"_Wr"] = rnn.Wr
	params[namespace+"_Ur"] = rnn.Ur
	params[namespace+"_Br"] = rnn.Br

	params[namespace+"_Wh"] = rnn.Wh
	params[namespace+"_Uh"] = rnn.Uh
	params[namespace+"_Bh"] = rnn.Bh

	params[namespace+"_Who"] = rnn.Who

	return params
}
// SetParameters overwrites the weights of this GRU with the matrices found
// in parameters under the keys produced by GetParameters(namespace).
//
// All entries are validated before any weight is written, so a failed call
// leaves the model unchanged. An error is returned when a key is missing
// from parameters, when either the model's or the supplied matrix is nil,
// or when the two weight slices differ in length (copy would otherwise
// silently truncate to the shorter one).
//
// Progress is reported on standard output for every parameter looked up.
func (rnn *GRU) SetParameters(namespace string, parameters map[string]*Matrix) error {
	own := rnn.GetParameters(namespace)

	// Pass 1: validate every parameter. Map iteration order is random, so
	// copying while validating could leave the model half-updated on error.
	for k, v := range own {
		fmt.Printf("Look for %s parameters\n", k)
		m, ok := parameters[k]
		if !ok {
			return fmt.Errorf("Model geometry is not compatible, parameter %s is unknown", k)
		}
		fmt.Printf("Got %s parameters\n", k)
		if v == nil || m == nil {
			return fmt.Errorf("Model geometry is not compatible, parameter %s is not allocated", k)
		}
		if len(m.W) != len(v.W) {
			return fmt.Errorf("Model geometry is not compatible, parameter %s has %d weights, expected %d", k, len(m.W), len(v.W))
		}
	}

	// Pass 2: everything matched, so the copy cannot fail or truncate.
	for k, v := range own {
		copy(v.W, parameters[k].W)
	}
	return nil
}
// Step appends one GRU time-step to the graph g and returns the new hidden
// state h and the output y.
//
// With x the current input and h_prev the previous hidden state, the graph
// built here is:
//
//	zt = Sigmoid(Wz*x + Uz*h_prev + Bz)
//	rt = Sigmoid(Wr*x + Ur*h_prev + Br)
//	ht = Tanh(Wh*x + Uh*(rt . h_prev) + Bh)
//	h  = zt . h_prev + (1 - zt) . ht
//	y  = Who * Tanh(h)
//
// where "*" stands for g.Mul and "." for g.EMul. Graph operations are issued
// in the same order as they appear below.
func (rnn *GRU) Step(g *Graph, x, h_prev *Matrix) (h, y *Matrix) {
	// Gate zt.
	zInput := g.Mul(rnn.Wz, x)
	zRecur := g.Mul(rnn.Uz, h_prev)
	zt := g.Sigmoid(g.Add(g.Add(zInput, zRecur), rnn.Bz))

	// Gate rt.
	rInput := g.Mul(rnn.Wr, x)
	rRecur := g.Mul(rnn.Ur, h_prev)
	rt := g.Sigmoid(g.Add(g.Add(rInput, rRecur), rnn.Br))

	// Candidate state: the previous state is scaled by rt before the
	// recurrent product.
	cInput := g.Mul(rnn.Wh, x)
	gatedPrev := g.EMul(rt, h_prev)
	cRecur := g.Mul(rnn.Uh, gatedPrev)
	ht := g.Tanh(g.Add(g.Add(cInput, cRecur), rnn.Bh))

	// Blend previous and candidate state with zt.
	// A disabled variant wrapped this sum in instance normalization:
	//h = g.InstanceNormalization(g.Add(g.EMul(zt, h_prev), g.EMul(g.Sub(zt.OnesAs(), zt), ht)))
	kept := g.EMul(zt, h_prev)
	oneMinusZ := g.Sub(zt.OnesAs(), zt)
	fresh := g.EMul(oneMinusZ, ht)
	h = g.Add(kept, fresh)

	// Read-out from the squashed hidden state.
	y = g.Mul(rnn.Who, g.Tanh(h))
	return h, y
}