diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl index 9355b8171..90a5fe733 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -54,6 +54,7 @@ end # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) + function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device dev = MLDataDevices.get_device(x) diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 0ca189ed6..0346745be 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,13 +1,19 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice,OneAPIDevice, ReactantDevice using ChainRules: OneElement -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) + +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index efe5f332e..fe0467a89 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,15 +1,20 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice +using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x -for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, ReactantDevice) +for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, OneAPIDevice, ReactantDevice) # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end +for Dev in (CUDADevice, AMDGPUDevice) + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray) + @eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x)) +end + end