From ea3246cd25a213dc4f8bfcdac0031f9cb5572381 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sun, 7 Jan 2024 20:15:49 -0800 Subject: [PATCH] Hotfix for new OneElement on GPU --- ext/FluxCUDAExt/functor.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index e8a89c5553..dc8649fff0 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -29,6 +29,10 @@ adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = # TODO: figure out the correct design for OneElement adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) +# Patch for GPU support until we can make OneElement smarter +if isdefined(Zygote.ChainRules, :OneElement) + adapt_storage(to::FluxCUDAAdaptor, x::Zygote.ChainRules.OneElement) = CUDA.cu(collect(x)) +end adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()