diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 230d04add3..ff1183e99f 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -46,19 +46,19 @@ function ctc_alpha(ŷ::AbstractArray, y) α[1,1] = ŷ[blank, 1] α[2,1] = ŷ[z′[2], 1] for t=2:T - bound = max(1, U′ - 2(T - t) - 1) + bound = max(1, U′ - 2(T - t) - 1) for u=bound:U′ - if u == 1 - α[u,t] = α[u, t-1] - else - α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) - - # array bounds check and f(u) function from Eq. 7.9 - if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) - α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) - end - end - α[u,t] += ŷ[z′[u], t] + if u == 1 + α[u,t] = α[u, t-1] + else + α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) + + # array bounds check and f(u) function from Eq. 7.9 + if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) + α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) + end + end + α[u,t] += ŷ[z′[u], t] end end return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ) @@ -80,18 +80,18 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out) # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 - bound = min(U′, 2t) + bound = min(U′, 2t) for u=bound:-1:1 - if u == U′ - β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] - else - β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) + if u == U′ + β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] + else + β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) - # array bounds check and g(u) function from Eq. 7.16 - if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] - β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) - end - end + # array bounds check and g(u) function from Eq. 7.16 + if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] + β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) + end + end end end @@ -132,7 +132,7 @@ for mathematical details. ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss @adjoint function ctc_loss(ŷ, y) - out = ctc_alpha(ŷ, y) - ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) - return out.loss, ctc_loss_pullback + out = ctc_alpha(ŷ, y) + ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) + return out.loss, ctc_loss_pullback end