# ctc-gpu.jl
# GPU implementation of the CTC loss:
# a port of the GPU kernels from Baidu's C++ warp-ctc package
# GitHub: https://github.com/baidu-research/warp-ctc/
# paper: https://arxiv.org/pdf/1512.02595.pdf
using Flux
using Statistics
using CUDA
const MAX_THREADS = 256
# Numerically stable log(exp(p1) + exp(p2)), used inside the kernels.
function log_plus_f(p1, p2)
    isinf(p1) && return p2
    isinf(p2) && return p1
    if p1 < p2
        p1, p2 = p2, p1
    end
    return p1 + CUDA.log(1 + CUDA.exp(p2 - p1))
end
# Count how many adjacent elements of `A` are equal.
function countRepeats(A)
    repeats = 0
    for (i, elem) in enumerate(A)
        if i > 1 && A[i] == A[i-1]
            repeats += 1
        end
    end
    return repeats
end
function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel)
    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = length(labelsWithBlanks)

    if L + repeats > T
        return nothing
    end

    labels = labelsWithBlanks

    # Corner-case checking
    start = (L + repeats <= T) ? 0 : 1
    last = S > 1 ? 2 : 1

    # Fill in first column (time step)
    i = tid
    while i <= last - start
        alpha[start+i, 1] = probs[labels[start+i], 1]
        i += blockDim().x
    end
    sync_threads()

    # Fill in coefficients for each time step
    for t = 2:T
        # Corner-case checking
        if tid == 1 && !(1 < S - 2*(T-t) - 1)
            if start == 0
                alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1]
            elseif start == 1
                alpha[1, t] = alpha[1, t-1]
            end
        end
        sync_threads()

        # Fill in coefficients for each label class in the target output sequence;
        # each thread will process the calculations for one class
        idx = tid + 1
        while idx <= S
            prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1])
            if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2]
                prevSum = log_plus_f(prevSum, alpha[idx-2, t-1])
            end
            if idx < S - 2*(T-t) - 1
                alpha[idx, t] = -Inf32
            else
                alpha[idx, t] = prevSum + probs[labels[idx], t]
            end
            idx += blockDim().x
        end
        sync_threads()
    end
    return nothing
end
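# For reference (not in the original file): in log space, the recursion the
# kernel above implements over the blank-augmented label sequence l′ is
#
#     α_t(s) = logaddexp(α_{t-1}(s), α_{t-1}(s-1) [, α_{t-1}(s-2)]) + log p_t(l′_s),
#
# where the α_{t-1}(s-2) term is added only when l′_s is not the blank and
# l′_s ≠ l′_{s-2}, matching the branch on `labels[idx]` above.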
function computeBetasAndGradKernel(probs, labelSize, uttLength,
                                   repeatsInLabel, labelsWithBlanks,
                                   alphas, beta, output, accum,
                                   grad, blankLabel)
    tid = threadIdx().x
    L = labelSize
    T = uttLength
    S = 2*L + 1
    repeats = repeatsInLabel
    labels = labelsWithBlanks

    if (L+repeats) > T
        return nothing
    end

    # Corner-case checking
    start = S > 1 ? S-2 : 0
    last = L + repeats < T ? S : S-1
    sync_threads()

    i = tid
    # Calculate coefficients for last column (time step)
    # then determine alpha and beta product
    while i <= last - start + 1
        beta[i+start, T] = 0
        output[i+start, T] = beta[i+start, T] + alphas[i+start, T]
        i += blockDim().x
    end
    sync_threads()

    # Fill in `accum` for last column (time step)
    if tid == 1
        for i = 1:S
            labelIdx = labels[i]
            accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T])
        end
    end
    sync_threads()

    # Fill in `grad` for last column (time step)
    idx = tid
    while idx <= size(grad, 1)
        s = -Inf32
        for i = 1:S
            s = log_plus_f(s, output[i, T])
        end
        # ∂L/∂a (where a is activation before logsoftmax)
        grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s)
        idx += blockDim().x
    end
    sync_threads()

    # Fill in the rest of the coefficients
    t = T-1
    while t >= 1
        if t < T
            idx = tid
            # while idx <= S-1
            while idx <= S
                nextSum = beta[idx, t+1] + probs[labels[idx], t+1]
                if idx < S
                    nextSum = log_plus_f(nextSum,
                        beta[idx+1, t+1] + probs[labels[idx+1], t+1])
                end
                if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2]
                    nextSum = log_plus_f(nextSum,
                        beta[idx+2, t+1] + probs[labels[idx+2], t+1])
                end
                if idx > 2*t
                    beta[idx, t] = -Inf32
                else
                    beta[idx, t] = nextSum
                end
                idx += blockDim().x
            end
            sync_threads()

            if tid == 1 && last == S
                beta[S, t] = beta[S, t] + probs[blankLabel, t+1]
            end
            sync_threads()

            idx = tid
            while idx <= S
                output[idx, t] = alphas[idx, t] + beta[idx, t]
                idx += blockDim().x
            end
            sync_threads()
        end
        sync_threads()

        # Calculate accumulated alpha-beta products for each label class for
        # each time step; used in calculating gradients
        if tid == 1
            for i = 1:S
                labelIdx = labels[i]
                accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t])
            end
        end
        sync_threads()

        idx = tid
        # Calculate gradients
        while idx <= size(grad, 1)
            s = -Inf32
            for i = 1:S
                s = log_plus_f(s, output[i, t])
            end
            # ∂L/∂a (where a is activation before logsoftmax)
            grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s)
            idx += blockDim().x
        end
        sync_threads()

        t -= 1
        sync_threads()
    end
    return nothing
end
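# The driver below relies on two helpers, `F` and `logsum`, that are not
# defined in this file. The following is only a minimal sketch of both,
# assuming `F` is the usual CTC label-collapsing map (merge repeated labels,
# then drop blanks) and `logsum` is a host-side log-sum-exp over a vector of
# log-probabilities; the original package may define them differently.
function F(A, blank)
    # Merge adjacent repeats, then drop blanks (the CTC collapsing map).
    z = eltype(A)[]
    prev = blank
    for curr in A
        if curr != prev && curr != blank
            push!(z, curr)
        end
        prev = curr
    end
    return z
end

function logsum(a::AbstractArray)
    # Numerically stable log(sum(exp, a)) on the host.
    s = -Inf32
    for x in a
        isinf(x) && continue
        s = isinf(s) ? x : max(s, x) + log1p(exp(-abs(s - x)))
    end
    return s
end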
# Entry points: `ctc` returns the scalar loss; `ctc_` is the helper that
# returns per-time-step losses and gradients.
ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean
ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean
ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean
ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))
function ctc_(ŷ::CuArray, y)
    ŷ = logsoftmax(ŷ)
    blank = size(ŷ, 1)
    labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)]
    z = F(labels, blank)            # collapsed target label sequence
    z′ = [blank]                    # target sequence with blanks interleaved
    for label in z
        push!(z′, label)
        push!(z′, blank)
    end
    T = size(ŷ, 2)
    U′ = 2*length(z) + 1
    alphas = CUDA.fill(log(zero(ŷ[1])), U′, T)
    betas = CUDA.fill(log(zero(ŷ[1])), U′, T)
    output = CUDA.fill(log(zero(ŷ[1])), U′, T)
    nRepeats = countRepeats(labels)
    nThreads = min(U′, MAX_THREADS)

    @cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank)

    grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ))
    accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ))

    @cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank)

    ls = collect(output)
    ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)])

    ŷ = alphas = betas = output = accum = nothing

    return ls, grads
end
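# Hypothetical usage sketch (not part of the original file): `ctc` expects
# unnormalised activations of size (number of symbols + 1 blank row) × time
# steps on the GPU, and a one-hot target matrix of the same size. For example,
# with 4 symbols plus a blank (row 5) over 10 frames:
#
#     ŷ = CuArray(randn(Float32, 5, 10))
#     y = zeros(Float32, 5, 10)
#     for t in 1:10
#         y[rand(1:4), t] = 1f0      # one target symbol per frame
#     end
#     loss = ctc(ŷ, y)               # scalar CTC loss
#     ls, grads = ctc_(ŷ, y)         # per-time-step losses and ∂loss/∂activations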