diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl
index 8dba7c0e57..63a5f93ada 100644
--- a/test/cuda/curnn.jl
+++ b/test/cuda/curnn.jl
@@ -1,5 +1,4 @@
 using Flux, CUDA, Test
-using Flux: pullback
 
 @testset for R in [RNN, GRU, LSTM]
   m = R(10, 5) |> gpu
@@ -9,7 +8,7 @@ using Flux: pullback
   θ = gradient(() -> sum(m(x)), params(m))
   @test x isa CuArray
   @test θ[m.cell.Wi] isa CuArray
-  @test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
+  @test collect(m̄[].cell.Wi) == collect(θ[m.cell.Wi])
 end
 
 @testset "RNN" begin
@@ -34,9 +33,9 @@ end
     cum̄, cux̄ = cuback(gpu(ȳ))
 
     @test x̄ ≈ collect(cux̄)
-    @test_broken m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
-    @test_broken m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
-    @test_broken m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
+    @test m̄[].cell.Wi ≈ collect(cum̄[].cell.Wi)
+    @test m̄[].cell.Wh ≈ collect(cum̄[].cell.Wh)
+    @test m̄[].cell.b ≈ collect(cum̄[].cell.b)
     if m̄[].state isa Tuple
       for (x, cx) in zip(m̄[].state, cum̄[].state)
         @test x ≈ collect(cx)