Skip to content

Commit

Permalink
Merge pull request #935 from SciML/ChrisRackauckas-patch-2-1
Browse files Browse the repository at this point in the history
Fix mixed GPU CPU adjoint test by switching to Lux
  • Loading branch information
ChrisRackauckas authored Nov 12, 2023
2 parents d14e346 + 7b39938 commit 3a5c9e7
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/gpu/mixed_gpu_cpu_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
using SciMLSensitivity, OrdinaryDiffEq
using Lux, CUDA, Test, Zygote, Random, LinearAlgebra, ComponentArrays

# Scalar indexing into a CuArray is almost always an accidental host<->device
# round-trip; disallow it so such bugs fail loudly in the tests below.
CUDA.allowscalar(false)

# Constant 2x2 coefficient matrix living on the GPU.
H = CuArray(rand(Float32, 2, 2))

# Small neural network whose output parameterizes the ODE right-hand side.
ann = Lux.Chain(Lux.Dense(1, 4, tanh))
rng = Random.default_rng()
p, st = Lux.setup(rng, ann)
# Flatten the nested parameter NamedTuple into a single vector-like object so
# it can be differentiated through and passed as the ODE parameter `p`.
p = ComponentArray(p)
# Lux states are passed explicitly; bind as a typed global constant so the
# RHS closure below stays type-stable.
const _st = st

# ODE right-hand side: the network maps the scalar time `t` to a 4-vector,
# which is reshaped into a 2x2 matrix (moved to the GPU to match `H` and `x`)
# and applied to the state. NOTE(review): the network runs on the CPU while
# `H` and `x` are CuArrays — that mixed placement is exactly what this test
# exercises.
function func(x, p, t)
    return CuArray(reshape(first(ann([t], p, _st)), 2, 2)) * H * x
end

# Random GPU-resident initial condition for the 2-state ODE.
x0 = CuArray(rand(Float32, 2))
Expand All @@ -30,10 +33,7 @@ function cost(p)
end

# Differentiate the loss with respect to the flattened network parameters.
grad = Zygote.gradient(cost, p)[1]
# With a ComponentArray parameterization the structured per-index checks used
# for the old Flux.destructure layout no longer apply; it suffices to verify
# the gradient is nonzero (i.e. sensitivities actually flowed through the
# mixed GPU/CPU adjoint).
@test !iszero(grad)

###
# https://github.com/SciML/SciMLSensitivity.jl/issues/632
Expand Down

0 comments on commit 3a5c9e7

Please sign in to comment.