-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrnn.vanilla.outputless.go
45 lines (37 loc) · 1.15 KB
/
rnn.vanilla.outputless.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package gortex
import "fmt"
// OutputlessRNN is a vanilla recurrent cell that exposes only its hidden
// state: it holds the input and recurrent weights plus a bias, and has no
// output projection of its own.
type OutputlessRNN struct {
	// Wxh is applied to the input and Whh to the previous hidden state.
	Wxh, Whh *Matrix
	// Bias is added to the hidden pre-activation at every step.
	Bias *Matrix
}
// MakeOutputlessRNN allocates a cell for inputs of size x_size and a hidden
// state of size h_size. Wxh is created as RandXavierMat(h_size, x_size),
// Whh as RandXavierMat(h_size, h_size) and Bias as RandXavierMat(h_size, 1);
// note that the bias is randomly initialized too, not zeroed.
func MakeOutputlessRNN(x_size, h_size int) *OutputlessRNN {
	// Keep the generation order Wxh, Whh, Bias so any shared random source
	// is consumed exactly as before.
	inputWeights := RandXavierMat(h_size, x_size)
	recurrentWeights := RandXavierMat(h_size, h_size)
	bias := RandXavierMat(h_size, 1)
	return &OutputlessRNN{
		Wxh:  inputWeights,
		Whh:  recurrentWeights,
		Bias: bias,
	}
}
// GetParameters returns the cell's three matrices keyed by
// namespace+"_Wxh", namespace+"_Whh" and namespace+"_Bias".
// The values are the cell's own pointers, not copies, so mutating them
// mutates the cell.
func (rnn *OutputlessRNN) GetParameters(namespace string) map[string]*Matrix {
	params := make(map[string]*Matrix, 3)
	params[namespace+"_Wxh"] = rnn.Wxh
	params[namespace+"_Whh"] = rnn.Whh
	params[namespace+"_Bias"] = rnn.Bias
	return params
}
// SetParameters loads this cell's weights from parameters, looking up the
// same keys that GetParameters(namespace) produces.
//
// Every expected key is validated before anything is assigned, so when an
// error is returned the cell is left unchanged. An entry is rejected if it is
// missing, nil, or holds a W of a different length than the cell's current W.
// On success each cell matrix's W is assigned directly from the provided
// matrix (if W is a slice, the storage is shared rather than copied).
func (rnn *OutputlessRNN) SetParameters(namespace string, parameters map[string]*Matrix) error {
	own := rnn.GetParameters(namespace)

	// Validate everything first: map iteration order is random, so assigning
	// while validating would leave an arbitrary subset overwritten on failure.
	for k, v := range own {
		m, ok := parameters[k]
		if !ok || m == nil {
			return fmt.Errorf("model geometry is not compatible, parameter %q is missing", k)
		}
		if len(m.W) != len(v.W) {
			return fmt.Errorf("model geometry is not compatible, parameter %q has %d weights, want %d",
				k, len(m.W), len(v.W))
		}
	}

	for k, v := range own {
		v.W = parameters[k].W
	}
	return nil
}
// Step adds one time step of the cell to the computation graph g and returns
// the resulting hidden state:
//
//	h = tanh(Wxh*x + Whh*h_prev + Bias)
//
// The graph operations are issued in the order Mul(Wxh, x), Mul(Whh, h_prev),
// Add, Add(Bias), Tanh.
func (rnn *OutputlessRNN) Step(g *Graph, x, h_prev *Matrix) (h *Matrix) {
	fromInput := g.Mul(rnn.Wxh, x)
	fromHidden := g.Mul(rnn.Whh, h_prev)
	preActivation := g.Add(fromInput, fromHidden)
	preActivation = g.Add(preActivation, rnn.Bias)
	return g.Tanh(preActivation)
}