From 53206ad0396ff8560dd6ca91fff7e18a35c3e990 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 27 Oct 2024 10:31:12 +0100 Subject: [PATCH] fix oneelement --- Project.toml | 3 +++ ext/MLDataDevicesChainRulesExt.jl | 13 +++++++++++++ ext/MLDataDevicesZygoteExt.jl | 10 +++++++--- test/misc_tests.jl | 12 ++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 ext/MLDataDevicesChainRulesExt.jl diff --git a/Project.toml b/Project.toml index 68d4325..cb02fb6 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -30,6 +31,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesChainRulesExt = "ChainRules" MLDataDevicesChainRulesCoreExt = "ChainRulesCore" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" @@ -48,6 +50,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.9.6, 1" Adapt = "4.1" CUDA = "5.2" +ChainRules = "1.51.0" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" diff --git a/ext/MLDataDevicesChainRulesExt.jl b/ext/MLDataDevicesChainRulesExt.jl new file mode 100644 index 0000000..8cd469d --- /dev/null +++ b/ext/MLDataDevicesChainRulesExt.jl @@ -0,0 +1,13 @@ +module MLDataDevicesChainRulesExt + +using Adapt: Adapt +using ChainRules: OneElement +using MLDataDevices: GPU_DEVICES, CPUDevice + +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +for Dev in GPU_DEVICES + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + +end diff --git a/ext/MLDataDevicesZygoteExt.jl b/ext/MLDataDevicesZygoteExt.jl index 1b705c5..3678dbb 100644 --- a/ext/MLDataDevicesZygoteExt.jl +++ b/ext/MLDataDevicesZygoteExt.jl @@ -1,10 +1,14 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: AbstractDevice, CPUDevice +using MLDataDevices: CPUDevice, GPU_DEVICES using Zygote: OneElement -Adapt.adapt_structure(::CPUDevice, x::OneElement) = x -Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x + +for Dev in GPU_DEVICES + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray), ... + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end end diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 28275d3..415a610 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -219,3 +219,15 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end + +@testset "Zygote and ChainRules OneElement" begin + # Issue #91 + using Zygote + cpu = cpu_device() + gpu = gpu_device() + + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] + @test g isa Vector{Float32} + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + @test g isa Matrix{Float32} +end