Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix oneelement
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 27, 2024
1 parent ece7ba2 commit 53206ad
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Expand All @@ -30,6 +31,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesChainRulesExt = "ChainRules"
MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
Expand All @@ -48,6 +50,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
AMDGPU = "0.9.6, 1"
Adapt = "4.1"
CUDA = "5.2"
ChainRules = "1.51.0"
ChainRulesCore = "1.23"
Compat = "4.15"
FillArrays = "1"
Expand Down
13 changes: 13 additions & 0 deletions ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module MLDataDevicesChainRulesExt

using Adapt: Adapt
using ChainRules: OneElement
using MLDataDevices: GPU_DEVICES, CPUDevice

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x
for Dev in GPU_DEVICES
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

end
10 changes: 7 additions & 3 deletions ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
module MLDataDevicesZygoteExt

using Adapt: Adapt
using MLDataDevices: AbstractDevice, CPUDevice
using MLDataDevices: CPUDevice, GPU_DEVICES
using Zygote: OneElement

Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for Dev in GPU_DEVICES
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray), ...
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

end
12 changes: 12 additions & 0 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,15 @@ end

@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
end

@testset "Zygote and ChainRules OneElement" begin
# Issue #91
using Zygote
cpu = cpu_device()
gpu = gpu_device()

g = Zygote.gradient(x -> cpu(2 .* gpu(x))[1], Float32[1,2,3])[1]
@test g isa Vector{Float32}
g = Zygote.gradient(x -> cpu(gpu(x) * gpu(x))[1,2], Float32[1 2 3; 4 5 6; 7 8 9])[1]
@test g isa Matrix{Float32}
end

0 comments on commit 53206ad

Please sign in to comment.