diff --git a/test/gpu/mixed_gpu_cpu_adjoint.jl b/test/gpu/mixed_gpu_cpu_adjoint.jl index ed061bc62..42eab336d 100644 --- a/test/gpu/mixed_gpu_cpu_adjoint.jl +++ b/test/gpu/mixed_gpu_cpu_adjoint.jl @@ -1,6 +1,8 @@ using SciMLSensitivity, OrdinaryDiffEq using Lux, LuxCUDA, Test, Zygote, Random, LinearAlgebra, ComponentArrays +const gdev = gpu_device() + CUDA.allowscalar(false) H = CuArray(rand(Float32, 2, 2)) @@ -42,10 +44,10 @@ grad = Zygote.gradient(cost, p)[1] rng = MersenneTwister(1234) m = 32 n = 16 -Z = randn(rng, Float32, (n, m)) |> gpu +Z = randn(rng, Float32, (n, m)) |> gdev 𝒯 = 2.0f0 Δτ = 1.0f-1 -ca_init = [zeros(1); ones(m)] |> gpu +ca_init = [zeros(1); ones(m)] |> gdev function f(ca, Z, t) a = ca[2:end] @@ -54,7 +56,7 @@ function f(ca, Z, t) Ka_unit = Z' * w_unit z_unit = dot(abs.(Ka_unit), a_unit) aKa_over_z = a .* Ka_unit / z_unit - [sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu + [sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gdev end function c(Z)