Commit

fixes
CarloLucibello committed Dec 12, 2024
1 parent e5e0896 commit 3ece4ad
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions test/layers/recurrent.jl
@@ -190,7 +190,7 @@ end
     @test Flux.initialstates(r) ≈ zeros(Float32, 5)

     # no initial state same as zero initial state
-    @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
+    @test r(x[1])[1] ≈ r(x[1], zeros(Float32, 5))[1]

     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
@@ -247,22 +247,22 @@ end
     h = randn(Float32, 5)
     out, state = r(x, h)
     @test out === state
-    test_gradients(r, x, h)
+    test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1]))

     # initial states are zero
     @test Flux.initialstates(r) ≈ zeros(Float32, 5)

     # no initial state same as zero initial state
-    @test r(x) ≈ r(x, zeros(Float32, 5))
+    @test r(x)[1] ≈ r(x, zeros(Float32, 5))[1]

     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
-    test_gradients(r, x, h)
+    test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1]))

     # The input sequence has no batch dimension.
     x = rand(Float32, 3)
     h = randn(Float32, 5)
-    test_gradients(r, x, h)
+    test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1]))
 end

 @testset "GRUv3" begin

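For context, a minimal sketch (not part of the commit) of the pattern the updated tests rely on, assuming a recurrent cell such as GRUCell whose forward call returns an (output, state) tuple, as `out, state = r(x, h)` in the diff suggests: the loss keyword passed to test_gradients selects the output with [1] before averaging, so the loss stays a scalar.

using Flux, Statistics

# Hypothetical setup mirroring the test above; assumes a Flux version in which
# recurrent cells return an (output, state) tuple from their forward call.
r = GRUCell(3 => 5)
x = rand(Float32, 3, 4)    # input features × batch
h = randn(Float32, 5, 4)   # hidden size × batch

# The loss passed to test_gradients: pick the output (first element of the
# returned tuple) before taking the mean, so the objective is a scalar.
loss(m, x, h) = mean(m(x, h)[1])

grads = Flux.gradient(loss, r, x, h)   # scalar loss, so gradients are well defined

Indexing with [1] keeps the loss closure agnostic to what the cell returns alongside the output; only the output enters the scalar objective being differentiated.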