diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 039058cff..eef457df1 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,14 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using ChainRules: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 9bec6a82f..53544a520 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,16 +1,16 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end - diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 42d27cf00..2a22df370 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -222,15 +222,15 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end -@testset "Zygote and ChainRules OneElement" begin - # Issue #1016 +@testset "Zygote and ChainRules OneElement #1016" begin using Zygote + cpu = cpu_device() gpu = gpu_device() - g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1, 2, 3])[1] @test g isa Vector{Float32} - g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9])[1] @test g isa Matrix{Float32} end