forked from tmbdev/clstm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-lstm.py
34 lines (34 loc) · 952 Bytes
/
test-lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/python
from numpy import *
from pylab import *
from numpy.random import rand
import clstm
# Smoke test for clstm: train a small LSTM ("lstm1", 1 input, 4 hidden,
# 2 outputs) to reproduce its binary input delayed by one time step, then
# check the trained net on fresh random sequences.

# Probability that any given input step is a 1.
DENSITY = 0.3
# Maximum allowed absolute error between prediction and delayed target.
TOLERANCE = 0.1


def make_delay_sample(n, density=DENSITY):
    """Return a random 0/1 sequence and its one-step-delayed copy.

    Parameters:
        n: sequence length (number of time steps).
        density: probability that each input step is 1.

    Returns:
        (xs, ys): two float32 arrays of length ``n``. ``xs`` holds the random
        0/1 input; ``ys`` is ``xs`` shifted right by one step with ys[0] = 0.
    """
    xs = array(rand(n) < density, 'f')
    ys = roll(xs, 1)
    # roll() wraps the last element to the front; a delay has no history there.
    ys[0] = 0
    return xs, ys


net = clstm.make_net_init("lstm1", "ninput=1:nhidden=4:noutput=2")
# NOTE(review): second argument is presumably momentum -- confirm in clstm.
net.setLearningRate(1e-4, 0.9)
N = 20          # time steps per sequence
ntrain = 30000  # training sequences
ntest = 1000    # test sequences

print("training 1:4:2 network to learn delay")
for _ in range(ntrain):
    xs, ys = make_delay_sample(N)
    # Two-column targets: column 0 is 1 - delayed input, column 1 is the
    # delayed input itself.
    targets = array([1 - ys, ys], 'f').T.copy()
    # Shapes (N, 1, 1) / (N, 2, 1) -- presumably (time, features, batch);
    # confirm against clstm's sequence layout.
    net.inputs.aset(xs.reshape(N, 1, 1))
    net.forward()
    # The output deltas fed to backward() are (target - current output).
    net.outputs.dset(targets.reshape(N, 2, 1) - net.outputs.array())
    net.backward()
    clstm.sgd_update(net)

print("testing %d random instances" % ntest)
maxerr = 0.0
for _ in range(ntest):
    xs, ys = make_delay_sample(N)
    net.inputs.aset(xs.reshape(N, 1, 1))
    net.forward()
    # Output column 1 (index 0 on the last axis) is trained to match ys.
    preds = net.outputs.array()[:, 1, 0]
    err = amax(abs(ys - preds))
    # Explicit raise rather than `assert` so the check still runs under -O;
    # `not (err < TOLERANCE)` also fails on NaN, like the original assert.
    if not (err < TOLERANCE):
        raise AssertionError(err)
    maxerr = maximum(err, maxerr)
print("OK %s" % maxerr)