-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path02-gpu-model.jl
187 lines (151 loc) · 3.74 KB
/
02-gpu-model.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
using Flux
using BSON
using LinearAlgebra
using Statistics
using Random
using Zygote: @adjoint
using Zygote: @nograd
using Zygote
using ProgressBars
using Distributions
using CUDA
include("ctc.jl")
include("ctc-gpu.jl")
# Global RNG seed for reproducible noise augmentation and shuffling.
Random.seed!(1)
# Directory containing the BSON training files (one utterance per file).
const TRAINDIR = "train"
const EPOCHS = 100
const BATCH_SIZE = 1
# Per-epoch accumulator of batch losses; reset by main() each epoch and
# appended to by addToGlobalLoss() from inside loss().
losses = []
# Bidirectional LSTM: 26 input features per frame, 100 hidden units each
# direction; the Dense layer maps the 200-dim concatenation to 62 classes
# (61 phoneme labels + 1 blank — presumably TIMIT; confirm against data).
forward = LSTM(26, 100)
backward = LSTM(26, 100)
output = Dense(200, 62)
# Additive Gaussian noise distribution used for input augmentation.
const NOISE = Normal(0, 0.6)
# Bidirectional RNN forward pass. `x` is a sequence of per-timestep feature
# vectors; the forward LSTM consumes it left-to-right, the backward LSTM
# right-to-left, their hidden states are concatenated per timestep and each
# concatenation is projected through the output layer. Returns a sequence of
# per-timestep class-score vectors.
function m(x)
    fwdStates = [forward(frame) for frame in x]
    bwdStates = reverse([backward(frame) for frame in reverse(x)])
    combined = [vcat(f, b) for (f, b) in zip(fwdStates, bwdStates)]
    return [output(h) for h in combined]
end
# CTC loss for one utterance: augment the input with noise, reset the
# recurrent state, run the model, and evaluate CTC on the GPU.
# Side effect: records the scalar loss in the global `losses` list via
# addToGlobalLoss (whose Zygote adjoint is a no-op, so the recording is
# invisible to the gradient).
function loss(x, y)
# Noise injection is wrapped in @nograd, so it is treated as data.
x = addNoise(x)
# Clear LSTM hidden states so utterances do not leak into each other.
Flux.reset!((forward, backward))
yhat = m(x)
# Stack the per-timestep score vectors into a (classes × time) matrix.
yhat = reduce(hcat, yhat)
# ctc-gpu implementation expects the scores on the device.
l = ctc(CuArray(yhat), y)
addToGlobalLoss(l)
return l
end
# Data augmentation: add Gaussian noise drawn from `NOISE` to every
# 26-dimensional frame of the sequence. Marked @nograd so Zygote does not
# differentiate through the noise injection. The input is deep-copied first,
# so the caller's arrays are left untouched.
@nograd function addNoise(x)
    noisy = deepcopy(x)
    for frame in noisy
        frame .+= rand(NOISE, 26)
    end
    return noisy
end
# Load every BSON utterance file from `dataDir`, convert each feature matrix
# into a sequence of per-timestep Float32 vectors, and normalize all features
# by the global per-dimension mean and standard deviation.
# Returns `(Xs, Ys)`: a list of frame sequences and a list of label matrices.
function readData(dataDir)
    fnames = readdir(dataDir)
    # Fixed-seed shuffle so the caller's train/validation split is
    # reproducible across runs.
    shuffle!(MersenneTwister(4), fnames)
    Xs = []
    Ys = []
    for fname in fnames
        BSON.@load joinpath(dataDir, fname) x y
        # One Float32 vector per row (timestep) of the feature matrix.
        x = [Float32.(x[i, :]) for i in 1:size(x, 1)]
        push!(Xs, x)
        push!(Ys, Array(y'))
    end
    # Flatten once and reuse for both statistics (the original recomputed the
    # full concatenation for mean and again for std). `mu`/`sigma` are
    # per-dimension vectors; `m` is avoided as a name so the model function
    # `m` is not shadowed.
    flat = reduce(vcat, Xs)
    mu = mean(flat)
    sigma = std(flat)
    for (i, x) in enumerate(Xs)
        Xs[i] = [(frame .- mu) ./ sigma for frame in x]
    end
    return (Xs, Ys)
end
"""
    lev(s, t)

Compute the Levenshtein (edit) distance between the sequences `s` and `t`:
the minimum number of single-element insertions, deletions and substitutions
needed to turn `s` into `t`.
"""
function lev(s, t)
    m = length(s)
    n = length(t)
    # d[i+1, j+1] = edit distance between s[1:i] and t[1:j].
    # zeros(Int, …) allocates the Int matrix directly; the original built a
    # Float64 matrix and converted it.
    d = zeros(Int, m + 1, n + 1)
    for i in 2:(m + 1)
        d[i, 1] = i - 1
    end
    for j in 2:(n + 1)
        d[1, j] = j - 1
    end
    for j in 2:(n + 1), i in 2:(m + 1)
        substitutionCost = s[i - 1] == t[j - 1] ? 0 : 1
        d[i, j] = min(d[i - 1, j] + 1,                    # deletion
                      d[i, j - 1] + 1,                    # insertion
                      d[i - 1, j - 1] + substitutionCost) # substitution
    end
    return d[m + 1, n + 1]
end
"""
collapse(seq)
Collapses `seq` into a version that has symbols collapsed into one repetition and removes all instances
of the blank symbol.
"""
function collapse(seq)
s = [x for x in seq if x != 62]
if isempty(s) return s end
s = [seq[1]]
for ch in seq[2:end]
if ch != s[end] && ch != 62
push!(s, ch)
end
end
return s
end
"""
per(x, y)
Compute the phoneme error rate of the model for input `x` and target `y`. The phoneme error rate
is defined as the Levenshtein distance between the labeling produced by running `x` through
the model and the target labeling in `y`, all divided by the length of the target labeling
in `y`
"""
function per(x, y)
Flux.reset!((forward, backward))
yhat = m(x)
yhat = reduce(hcat, yhat)
yhat = mapslices(argmax, yhat, dims=1) |> vec |> collapse
y = mapslices(argmax, y, dims=1) |> vec |> collapse
return lev(yhat, y) / length(y)
end
# Record one batch loss in the global `losses` accumulator so the training
# loop can report a mean loss per epoch.
function addToGlobalLoss(x)
global losses
push!(losses, x)
end
# Custom Zygote adjoint: the recording still happens on the forward pass,
# but the pullback contributes nothing, so the bookkeeping is invisible to
# gradient computation.
@adjoint function addToGlobalLoss(x)
addToGlobalLoss(x)
return nothing, () -> nothing
end
# Entry point: load and normalize the data, split off a validation set, and
# train for EPOCHS epochs, reporting validation PER and mean loss each epoch.
function main()
    println("Loading files")
    Xs, Ys = readData(TRAINDIR)
    # The original unpacked this zip back into Xs/Ys and re-zipped the same
    # pair; that redundant round-trip is removed.
    data = collect(zip(Xs, Ys))
    println("Beginning training")
    # Hold out the first 186 utterances (order fixed by readData's seeded
    # shuffle) for validation; train on the remainder.
    valData = data[1:186]
    data = data[187:end]
    opt = Momentum(1e-4)
    for i in 1:EPOCHS
        global losses
        # Reset the per-epoch loss accumulator that loss() appends to.
        losses = []
        println("Beginning epoch $i/$EPOCHS")
        Flux.train!(loss, Flux.params((forward, backward, output)), ProgressBar(data), opt)
        println("Calculating PER...")
        p = mean(map(x -> per(x...), valData))
        println("PER: $(p*100)")
        println("Mean loss: ", mean(losses))
    end
end
# Run the full training pipeline when the file is executed as a script.
main()