From 3ece4adc94f4ef4dc8f570108a5e923855cbb5d1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 12 Dec 2024 19:12:18 +0100 Subject: [PATCH] fixes --- test/layers/recurrent.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 9382091a44..c66c4aa34f 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -190,7 +190,7 @@ end @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) + @test r(x[1])[1] ≈ r(x[1], zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) @@ -247,22 +247,22 @@ end h = randn(Float32, 5) out, state = r(x, h) @test out === state - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # initial states are zero @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x) ≈ r(x, zeros(Float32, 5)) + @test r(x)[1] ≈ r(x, zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # The input sequence has no batch dimension. x = rand(Float32, 3) h = randn(Float32, 5) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) end @testset "GRUv3" begin