Skip to content

Commit

Permalink
Merge #1473
Browse files Browse the repository at this point in the history
1473: Fix RNN tests on GPU r=DhairyaLGandhi a=jeremiedb

Fix for RNN on CUDA, as discussed in #1367 .

Co-authored-by: jeremie.db <[email protected]>
  • Loading branch information
bors[bot] and jeremiedb authored Jan 22, 2021
2 parents 5483a12 + 0b147d8 commit 2c0dcb2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions test/cuda/curnn.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Flux, CUDA, Test
using Flux: pullback

@testset for R in [RNN, GRU, LSTM]
m = R(10, 5) |> gpu
Expand All @@ -9,7 +8,7 @@ using Flux: pullback
θ = gradient(() -> sum(m(x)), params(m))
@test x isa CuArray
@test θ[m.cell.Wi] isa CuArray
@test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
@test collect(m̄[].cell.Wi) == collect(θ[m.cell.Wi])
end

@testset "RNN" begin
Expand All @@ -34,9 +33,9 @@ end
cum̄, cux̄ = cuback(gpu(ȳ))

@test collect(cux̄)
@test_broken m̄[].cell[].Wi collect(cum̄[].cell[].Wi)
@test_broken m̄[].cell[].Wh collect(cum̄[].cell[].Wh)
@test_broken m̄[].cell[].b collect(cum̄[].cell[].b)
@test m̄[].cell.Wi collect(cum̄[].cell.Wi)
@test m̄[].cell.Wh collect(cum̄[].cell.Wh)
@test m̄[].cell.b collect(cum̄[].cell.b)
if m̄[].state isa Tuple
for (x, cx) in zip(m̄[].state, cum̄[].state)
@test x collect(cx)
Expand Down

0 comments on commit 2c0dcb2

Please sign in to comment.