From 9fac54f3fe0fabedfddd3fcfe321090b7f859b91 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 6 Nov 2024 17:19:37 +0100 Subject: [PATCH] fix ambiguity + oneAPI --- lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl | 1 + lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl | 10 ++++++++-- lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl | 9 +++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index 9355b8171..90a5fe733 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -54,6 +54,7 @@ end # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) + function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device dev = MLDataDevices.get_device(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 0ca189ed6..0346745be 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,19 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,OneAPIDevice, ReactantDevice using ChainRules: OneElement -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index efe5f332e..fe0467a89 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,15 +1,20 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end