Patch summary: MLDataDevices 1.5.1 -> 1.5.2, OneElement handling (test set references #1016)

  Project.toml
      Bumps the package version and registers ChainRules as a weak dependency
      (compat "1.51") that triggers the new MLDataDevicesChainRulesExt extension.

  ext/MLDataDevicesCUDAExt.jl
      Whitespace only: inserts a blank line between the two adapt_storage methods.

  ext/MLDataDevicesChainRulesExt.jl (new file)
      Defines Adapt.adapt_structure for ChainRules.OneElement. On CPUDevice the
      value is returned unchanged; for each listed accelerator device type it is
      materialised with collect(x) and then adapted to the device. The methods are
      generated per device type with @eval; the in-code comment gives the reason as
      avoiding a method ambiguity with adapt_storage(::CUDADevice, ::AbstractArray).

  ext/MLDataDevicesZygoteExt.jl
      Replaces the single adapt_structure(::AbstractDevice, ::OneElement) method
      with the same per-device @eval loop used in the ChainRules extension.

  test/misc_tests.jl
      Adds a test set that differentiates through cpu_device()/gpu_device() round
      trips and checks that the gradients are a Vector{Float32} and a Matrix{Float32}.

NOTE(review): the original text had lost its line breaks. Blank context lines and
leading indentation below were reconstructed from the hunk line counts and should
be confirmed against the repository before this patch is applied.

diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml
index 49b955621..9566c340d 100644
--- a/lib/MLDataDevices/Project.toml
+++ b/lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.5.1"
+version = "1.5.2"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
@@ -31,6 +32,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
 [extensions]
 MLDataDevicesAMDGPUExt = "AMDGPU"
 MLDataDevicesCUDAExt = "CUDA"
+MLDataDevicesChainRulesExt = "ChainRules"
 MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
 MLDataDevicesFillArraysExt = "FillArrays"
 MLDataDevicesGPUArraysExt = "GPUArrays"
@@ -50,6 +52,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
 AMDGPU = "0.9.6, 1"
 Adapt = "4.1"
 CUDA = "5.2"
+ChainRules = "1.51"
 ChainRulesCore = "1.23"
 Compat = "4.15"
 FillArrays = "1"
diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl
index 9355b8171..90a5fe733 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl
@@ -54,6 +54,7 @@ end
 
 # Device Transfer
 Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
+
 function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
     old_dev = CUDA.device()  # remember the current device
     dev = MLDataDevices.get_device(x)
diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
new file mode 100644
index 000000000..25b05c01d
--- /dev/null
+++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
@@ -0,0 +1,16 @@
+module MLDataDevicesChainRulesExt
+
+using Adapt: Adapt
+using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
+                     ReactantDevice
+using ChainRules: OneElement
+
+Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
+
+for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
+    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
+    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
+    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
+end
+
+end
diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
index 1b705c582..66a363d55 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
@@ -1,10 +1,16 @@
 module MLDataDevicesZygoteExt
 
 using Adapt: Adapt
-using MLDataDevices: AbstractDevice, CPUDevice
+using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
+                     ReactantDevice
 using Zygote: OneElement
 
 Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
-Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
+
+for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
+    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
+    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
+    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
+end
 
 end
diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl
index 5ece810bf..55265ada3 100644
--- a/lib/MLDataDevices/test/misc_tests.jl
+++ b/lib/MLDataDevices/test/misc_tests.jl
@@ -222,6 +222,21 @@ end
     @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
 end
 
+@testset "Zygote and ChainRules OneElement #1016" begin
+    using Zygote
+
+    cdev = cpu_device()
+    gdev = gpu_device()
+
+    g = only(Zygote.gradient(x -> cdev(2 .* gdev(x))[1], Float32[1, 2, 3]))
+    @test g isa Vector{Float32}
+
+    g = only(Zygote.gradient(
+        x -> cdev(gdev(x) * gdev(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9]
+    ))
+    @test g isa Matrix{Float32}
+end
+
 @testset "OneHotArrays" begin
     using OneHotArrays
 