diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 3cc272fd3..96bc0fd0c 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -48,6 +49,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.9.6, 1" Adapt = "4.1" CUDA = "5.2" +ChainRules = "1.51" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl new file mode 100644 index 000000000..05976ad5f --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -0,0 +1,13 @@ +module MLDataDevicesChainRulesExt + +using Adapt: Adapt +using ChainRules: OneElement +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice + +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + +end \ No newline at end of file diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 1b705c582..efe5f332e 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,10 +1,15 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: AbstractDevice, CPUDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice using Zygote: OneElement -Adapt.adapt_structure(::CPUDevice, x::OneElement) = x -Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end + +end + diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 28275d3b7..f07084f76 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -219,3 +219,15 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end + +@testset "Zygote and ChainRules OneElement" begin + # Issue #91 + using Zygote + cpu = cpu_device() + gpu = gpu_device() + + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] + @test g isa Vector{Float32} + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + @test g isa Matrix{Float32} +end \ No newline at end of file