Skip to content

Commit

Permalink
Fix indentation in ctc.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
maetshju committed Jan 19, 2021
1 parent 6e5fb17 commit bc94a16
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ function ctc_alpha(ŷ::AbstractArray, y)
α[1,1] = ŷ[blank, 1]
α[2,1] = ŷ[z′[2], 1]
for t=2:T
bound = max(1, U′ - 2(T - t) - 1)
bound = max(1, U′ - 2(T - t) - 1)
for u=bound:U′
if u == 1
α[u,t] = α[u, t-1]
else
α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])
# array bounds check and f(u) function from Eq. 7.9
if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
end
end
α[u,t] += ŷ[z′[u], t]
if u == 1
α[u,t] = α[u, t-1]
else
α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1])
# array bounds check and f(u) function from Eq. 7.9
if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u])
α[u,t] = logaddexp(α[u,t], α[u-2,t-1])
end
end
α[u,t] += ŷ[z′[u], t]
end
end
return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
Expand All @@ -80,18 +80,18 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out)

# start at T-1 so that β(T, u) = log(0) for all u < U′ - 1
for t=(T-1):-1:1
bound = min(U′, 2t)
bound = min(U′, 2t)
for u=bound:-1:1
if u == U′
β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
else
β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])
if u == U′
β[u,t] = ŷ[z′[u], t+1] + β[u, t+1]
else
β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1])

# array bounds check and g(u) function from Eq. 7.16
if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
end
end
# array bounds check and g(u) function from Eq. 7.16
if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2]
β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1])
end
end
end
end

Expand Down Expand Up @@ -132,7 +132,7 @@ for mathematical details.
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
end

0 comments on commit bc94a16

Please sign in to comment.