diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5bb1dbfe3c..59538d1772 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -10,7 +10,6 @@ steps: queue: "juliagpu" cuda: "*" env: - JULIA_CUDA_USE_BINARYBUILDER: "true" FLUX_TEST_CUDA: "true" FLUX_TEST_CPU: "false" timeout_in_minutes: 60 diff --git a/Project.toml b/Project.toml index f40a8cdab7..9c332a8fc4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.8" +version = "0.14.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl index e8a89c5553..dc8649fff0 100644 --- a/ext/FluxCUDAExt/functor.jl +++ b/ext/FluxCUDAExt/functor.jl @@ -29,6 +29,10 @@ adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = # TODO: figure out the correct design for OneElement adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) +# Patch for GPU support until we can make OneElement smarter +if isdefined(Zygote.ChainRules, :OneElement) + adapt_storage(to::FluxCUDAAdaptor, x::Zygote.ChainRules.OneElement) = CUDA.cu(collect(x)) +end adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()