-- Seq2Seq.hs: sequence-to-sequence model with attention, written with TypedFlow.
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}
module Main where
import TypedFlow
import TypedFlow.Python
-- | Create an LSTM cell whose parameters are registered under the given name.
--
-- The cell maps inputs of size @x@ to outputs of size @n@ and carries two
-- state vectors of size @n@. A dropout layer (drop probability 0.05) is
-- placed in front of the cell, and a second set of dropouts (same
-- probability) is applied to the cell through 'onStates'.
mkLSTM
  :: ∀ n x w. (KnownNat x, KnownNat n, KnownBits w)
  => String -- parameter name for the LSTM weights
  -> Gen
       (RnnCell w '[ '[n], '[n]] (Tensor '[x] (Flt w)) (Tensor '[n] (Flt w)))
mkLSTM pName = do
  cellParams <- parameterDefault pName
  inputDropout <- mkDropout (DropProb 0.05)
  stateDropout <- mkDropouts (DropProb 0.05)
  let droppedCell = onStates stateDropout (lstm cellParams)
  return (timeDistribute inputDropout .-. droppedCell)
-- | Build the encoder half of the model.
--
-- Registers an embedding table (@prefix ++ "embs"@) and one LSTM
-- (@prefix ++ "lstm1"@), then returns a function that takes the source
-- length and the source token ids and runs the embedded tokens through the
-- LSTM with 'runRnn', starting from an all-zero state. The result is the
-- pair produced by 'runRnn': an 'HTV' of two @lstmSize@ vectors and an
-- @[n, lstmSize]@ tensor (presumably the final state and the per-step
-- outputs; confirm against the TypedFlow documentation).
encoder
  :: forall (lstmSize :: Nat) (vocSize :: Nat) (n :: Nat) w.
     (KnownNat lstmSize, KnownNat vocSize, KnownNat n, KnownBits w)
  => String -- prefix for the parameter names
  -> Gen
       (  T '[] Int32 -- length
       -> Tensor '[n] Int32
       -> ( HTV (Flt w) '[ '[lstmSize], '[lstmSize] ]
          , Tensor '[n, lstmSize] (Flt w)
          )
       )
encoder prefix = do
  embeddingTable <- parameterDefault (prefix ++ "embs")
  cell <- mkLSTM (prefix ++ "lstm1")
  let embed = timeDistribute (embedding @vocSize @vocSize embeddingTable)
  return $ \len input ->
    runRnn (iterateWithCull len (embed .-. cell)) (repeatT zeros, input)
-- | Build the decoder half of the model.
--
-- Registers a dense projection (@prefix ++ "proj"@), one LSTM
-- (@prefix ++ "lstm1"@), an embedding table (@"embs"@) and an attention
-- weight matrix (@prefix ++ "att1"@, initialised with 'glorotUniform').
-- The returned function takes, in order: the length passed to the
-- attention, the tensor attended over, the initial LSTM state and the
-- target input token ids. It embeds the target tokens, feeds them through
-- the attentive LSTM and the projection, and returns the second component
-- of the 'simpleRnn' result.
decoder
  :: forall (lstmSize :: Nat) (n :: Nat) (outVocabSize :: Nat) (d :: Nat) w.
     (KnownNat lstmSize, KnownNat d, KnownNat outVocabSize, KnownNat n, KnownBits w)
  => String -- prefix for the parameter names
  -> Gen
       (  T '[] Int32 -- length
       -> T '[n, d] (Flt w) -- todo: consider a larger size for the output string
       -> HTV (Flt w) '[ '[lstmSize], '[lstmSize] ]
       -> Tensor '[n] Int32
       -> Tensor '[n, outVocabSize] (Flt w)
       )
decoder prefix = do
  -- Note: for an intra-language translation the embeddings can be shared easily.
  projection <- parameterDefault (prefix ++ "proj")
  cell <- mkLSTM (prefix ++ "lstm1")
  -- NOTE(review): unlike the other parameters here, this name does not use
  -- `prefix`; confirm whether that is intentional.
  embeddingTable <- parameterDefault "embs"
  attnInit <- glorotUniform
  attnWeights <- parameter' (prefix ++ "att1") attnInit
  return $ \lens hs thoughtVectors targetInput ->
    let -- NOTE: attention on the left-part of the input.
        attn = uniformAttn (multiplicativeScoring attnWeights) lens hs
        embed =
          timeDistribute (embedding @outVocabSize @outVocabSize embeddingTable)
        project = timeDistribute (dense projection)
        pipeline = embed .-. attentiveWithFeedback attn cell .-. project
        -- The state is extended with an extra all-zero leading component.
        initialState = F zeros :* thoughtVectors
        (_sFinal, outFinal) = simpleRnn pipeline (initialState, targetInput)
     in outFinal
-- | The complete sequence-to-sequence model.
--
-- Creates an encoder (LSTM size 256, parameter prefix @"enc"@) and a decoder
-- (prefix @"dec"@). The resulting function reads the five placeholders,
-- encodes @src_in@ (with @src_len@), decodes @tgt_in@ from the encoder
-- results, and scores the decoder output against @tgt_out@ with
-- 'timedCategorical', using @tgt_weights@ as its first argument.
seq2seq
  :: forall (vocSize :: Nat) (n :: Nat). (KnownNat vocSize, KnownNat n)
  => Gen
       (  Placeholders
            '[ '("tgt_weights", '[n], Float32)
             , '("src_in", '[n], Int32)
             , '("src_len", '[], Int32)
             , '("tgt_in", '[n], Int32)
             , '("tgt_out", '[n], Int32)
             ]
       -> ModelOutput Float32 '[n, vocSize] '[]
       )
seq2seq = do
  runEncoder <- encoder @256 @vocSize "enc"
  runDecoder <- decoder "dec"
  return $
    \(PHT tgtWeights :* PHT srcIn :* PHT srcLen :* PHT tgtIn :* PHT tgtOut :* Unit) ->
      let (VecPair s1 s2, encoded) = runEncoder srcLen srcIn
          -- The source length is reused as the length given to the decoder.
          predictions = runDecoder srcLen encoded (VecPair s1 s2) tgtIn
       in timedCategorical tgtWeights predictions tgtOut
-- | Compile the model (type arguments: 256 for 'compileGen', 15 and 22 for
-- 'seq2seq') and write the generated code to @s2s.py@.
main :: IO ()
main =
  generateFile "s2s.py" $
    compileGen @256 options (stateless <$> seq2seq @15 @22)
  where
    -- Default options, except that the maximum gradient norm is set to 5.
    options = defaultOptions {maxGradientNorm = Just 5}
-- >>> main
-- Parameters (total 889041):
-- decatt1: T [256,256] tf.float32
-- embs: T [15,15] tf.float32
-- declstm1_o_bias: T [256] tf.float32
-- declstm1_o_w: T [527,256] tf.float32
-- declstm1_c_bias: T [256] tf.float32
-- declstm1_c_w: T [527,256] tf.float32
-- declstm1_i_bias: T [256] tf.float32
-- declstm1_i_w: T [527,256] tf.float32
-- declstm1_f_bias: T [256] tf.float32
-- declstm1_f_w: T [527,256] tf.float32
-- decproj_bias: T [15] tf.float32
-- decproj_w: T [256,15] tf.float32
-- enclstm1_o_bias: T [256] tf.float32
-- enclstm1_o_w: T [271,256] tf.float32
-- enclstm1_c_bias: T [256] tf.float32
-- enclstm1_c_w: T [271,256] tf.float32
-- enclstm1_i_bias: T [256] tf.float32
-- enclstm1_i_w: T [271,256] tf.float32
-- enclstm1_f_bias: T [256] tf.float32
-- enclstm1_f_w: T [271,256] tf.float32
-- encembs: T [15,15] tf.float32
-- Local Variables:
-- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl")
-- End: