Skip to content

Commit

Permalink
fix ambiguity + oneAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 6, 2024
1 parent 71123bd commit 9fac54f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ end

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)

function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
dev = MLDataDevices.get_device(x)
Expand Down
10 changes: 8 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
module MLDataDevicesChainRulesExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,OneAPIDevice, ReactantDevice
using ChainRules: OneElement
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x
for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice)

for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

for Dev in (CUDADevice, AMDGPUDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

end
9 changes: 7 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
module MLDataDevicesZygoteExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice
using Zygote: OneElement

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice)
for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

for Dev in (CUDADevice, AMDGPUDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x))
end

end

0 comments on commit 9fac54f

Please sign in to comment.