From 92d940a010d0b43d24cd95dd06d0c7e1adc05ebc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 12 Dec 2024 19:12:18 +0100 Subject: [PATCH] fixes --- test/layers/recurrent.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 9d7d3e73dc..73dae4ac65 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -190,7 +190,7 @@ end @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) + @test r(x[1])[1] ≈ r(x[1], zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) @@ -247,22 +247,22 @@ end h = randn(Float32, 5) out, state = r(x, h) @test out === state - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # initial states are zero @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x) ≈ r(x, zeros(Float32, 5)) + @test r(x)[1] ≈ r(x, zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # The input sequence has no batch dimension. x = rand(Float32, 3) h = randn(Float32, 5) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) end @testset "GRUv3" begin