Skip to content

Commit

Permalink
Hotfix for new OneElement on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir authored Jan 8, 2024
1 parent 1af3f4d commit ea3246c
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ext/FluxCUDAExt/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =

# TODO: figure out the correct design for OneElement
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
# Patch for GPU support until we can make OneElement smarter
if isdefined(Zygote.ChainRules, :OneElement)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.ChainRules.OneElement) = CUDA.cu(collect(x))
end

adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
Expand Down

0 comments on commit ea3246c

Please sign in to comment.