diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 3d7d53a486..3ad7428601 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -307,10 +307,10 @@ end end @testset "Recurrence" begin + x = rand(Float32, 2, 3, 4) for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)] cell = rnn.cell rec = Recurrence(cell) - x = rand(Float32, 2, 3, 4) @test rec(x) ≈ rnn(x) end end