Skip to content

Commit

Permalink
cl/oneelement
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 5, 2024
1 parent 8b87c2b commit 62033dc
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
2 changes: 2 additions & 0 deletions lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Expand Down Expand Up @@ -48,6 +49,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
AMDGPU = "0.9.6, 1"
Adapt = "4.1"
CUDA = "5.2"
ChainRules = "1.51"
ChainRulesCore = "1.23"
Compat = "4.15"
FillArrays = "1"
Expand Down
13 changes: 13 additions & 0 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module MLDataDevicesChainRulesExt

using Adapt: Adapt
using ChainRules: OneElement
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x
for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

end
11 changes: 8 additions & 3 deletions lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
module MLDataDevicesZygoteExt

using Adapt: Adapt
using MLDataDevices: AbstractDevice, CPUDevice
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice
using Zygote: OneElement

Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

end

12 changes: 12 additions & 0 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,15 @@ end

@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
end

@testset "Zygote and ChainRules OneElement" begin
# Issue #91
using Zygote
cpu = cpu_device()
gpu = gpu_device()

g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1]
@test g isa Vector{Float32}
g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1]
@test g isa Matrix{Float32}
end

0 comments on commit 62033dc

Please sign in to comment.