From aabb6f1472c90ddbe0c90adbdb8a6d76e6645cff Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Tue, 5 Nov 2024 06:40:49 +0100 Subject: [PATCH 1/8] cl/oneelement --- lib/MLDataDevices/Project.toml | 2 ++ lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 13 +++++++++++++ lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 11 ++++++++--- lib/MLDataDevices/test/misc_tests.jl | 12 ++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 49b955621..4d4f67433 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -50,6 +51,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.9.6, 1" Adapt = "4.1" CUDA = "5.2" +ChainRules = "1.51" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl new file mode 100644 index 000000000..05976ad5f --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -0,0 +1,13 @@ +module MLDataDevicesChainRulesExt + +using Adapt: Adapt +using ChainRules: OneElement +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice + +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + +end \ No newline at end of file diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 1b705c582..efe5f332e 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,10 +1,15 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: AbstractDevice, CPUDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice using Zygote: OneElement -Adapt.adapt_structure(::CPUDevice, x::OneElement) = x -Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) +Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end + +end + diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 5ece810bf..3e4122dd8 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -241,3 +241,15 @@ end @test x_rd isa Reactant.ConcreteRArray{Bool, 2} end end + +@testset "Zygote and ChainRules OneElement" begin + # Issue #91 + using Zygote + cpu = cpu_device() + gpu = gpu_device() + + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] + @test g isa Vector{Float32} + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + @test g isa Matrix{Float32} +end From 1f53dccb33756868fd1be4ffc8bb89a998e1a68d Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Tue, 5 Nov 2024 06:45:22 +0100 Subject: [PATCH 2/8] cleanup --- lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 2 +- lib/MLDataDevices/test/misc_tests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 05976ad5f..0ca189ed6 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -10,4 +10,4 @@ for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end -end \ No newline at end of file +end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 3e4122dd8..05e98b6a2 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -243,7 +243,7 @@ end end @testset "Zygote and ChainRules OneElement" begin - # Issue #91 + # Issue #1016 using Zygote cpu = cpu_device() gpu = gpu_device() From a78aa29aed839c023f6dbd4522408c27ce09eb2d Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 6 Nov 2024 17:19:37 +0100 Subject: [PATCH 3/8] fix ambiguity + oneAPI --- lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 1 + lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 10 ++++++++-- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 9 +++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index 9355b8171..90a5fe733 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -54,6 +54,7 @@ end # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) + function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device dev = MLDataDevices.get_device(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 0ca189ed6..0346745be 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,19 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,OneAPIDevice, ReactantDevice using ChainRules: OneElement -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index efe5f332e..fe0467a89 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,15 +1,20 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end From 6a955656903b46f71b4081883e216427ba569fef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 6 Nov 2024 15:10:41 -0500 Subject: [PATCH 4/8] Apply suggestions from code review --- lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 6 +++--- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 0346745be..6396df51a 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,19 +1,19 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,OneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice using ChainRules: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end for Dev in (CUDADevice, AMDGPUDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) - @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) + @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index fe0467a89..3060562ee 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,12 +1,12 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end From bcf9f2e703badb13afbe4380d29c89be2fa279c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 6 Nov 2024 15:12:33 -0500 Subject: [PATCH 5/8] Apply suggestions from code review --- lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 10 +++------- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 10 +++------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 6396df51a..039058cff 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -6,14 +6,10 @@ using ChainRules: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) +for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) - @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) -end - -for Dev in (CUDADevice, AMDGPUDevice) - # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) - @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) + @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 3060562ee..9bec6a82f 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -6,14 +6,10 @@ using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) +for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) - @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) -end - -for Dev in (CUDADevice, AMDGPUDevice) - # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) - @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) + @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end From e35d643357a11d1eb951f58a8fa8da699f2323a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Nov 2024 12:16:52 -0500 Subject: [PATCH 6/8] chore: run the formatter --- .../ext/MLDataDevicesChainRulesExt.jl | 5 ++-- .../ext/MLDataDevicesZygoteExt.jl | 6 ++--- lib/MLDataDevices/test/misc_tests.jl | 24 +++++++++---------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 039058cff..eef457df1 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,14 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using ChainRules: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 9bec6a82f..53544a520 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,16 +1,16 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, + ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, - CUDADevice{Nothing}, AMDGPUDevice{Nothing}) + CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end - diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 05e98b6a2..2a22df370 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -222,6 +222,18 @@ end @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} end +@testset "Zygote and ChainRules OneElement #1016" begin + using Zygote + + cpu = cpu_device() + gpu = gpu_device() + + g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1, 2, 3])[1] + @test g isa Vector{Float32} + g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + @test g isa Matrix{Float32} +end + @testset "OneHotArrays" begin using OneHotArrays @@ -241,15 +253,3 @@ end @test x_rd isa Reactant.ConcreteRArray{Bool, 2} end end - -@testset "Zygote and ChainRules OneElement" begin - # Issue #1016 - using Zygote - cpu = cpu_device() - gpu = gpu_device() - - g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1] - @test g isa Vector{Float32} - g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1] - @test g isa Matrix{Float32} -end From 2015481d8d6f9e0941b00568e0f3e79193bac2e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Nov 2024 12:36:19 -0500 Subject: [PATCH 7/8] fix: scalar indexing issue --- lib/MLDataDevices/Project.toml | 1 + lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 4 ++-- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 4 ++-- lib/MLDataDevices/test/misc_tests.jl | 11 +++++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 4d4f67433..2eeec495a 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -32,6 +32,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesChainRulesExt = "ChainRules" MLDataDevicesChainRulesCoreExt = "ChainRulesCore" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index eef457df1..25b05c01d 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -5,12 +5,12 @@ using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDev ReactantDevice using ChainRules: OneElement -Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +Adapt.adapt_structure(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) - @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) + @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index 53544a520..66a363d55 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -5,12 +5,12 @@ using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDev ReactantDevice using Zygote: OneElement -Adapt.adapt_storage(::CPUDevice, x::OneElement) = x +Adapt.adapt_structure(::CPUDevice, x::OneElement) = x for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) - @eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) + @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 2a22df370..55265ada3 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -225,12 +225,15 @@ end @testset "Zygote and ChainRules OneElement #1016" begin using Zygote - cpu = cpu_device() - gpu = gpu_device() + cdev = cpu_device() + gdev = gpu_device() - g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1, 2, 3])[1] + g = only(Zygote.gradient(x -> cdev(2 .* gdev(x))[1], Float32[1, 2, 3])) @test g isa Vector{Float32} - g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9])[1] + + g = only(Zygote.gradient( + x -> cdev(gdev(x) * gdev(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9] + )) @test g isa Matrix{Float32} end From 53e05385bae79da2db1ad7a182fb114c77fb7847 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Nov 2024 12:36:37 -0500 Subject: [PATCH 8/8] chore: bump version --- lib/MLDataDevices/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 2eeec495a..9566c340d 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.5.1" +version = "1.5.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"