diff --git a/Project.toml b/Project.toml index cb02fb6..48632d0 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] AMDGPU = "0.9.6, 1" Adapt = "4.1" CUDA = "5.2" -ChainRules = "1.51.0" +ChainRules = "1.51" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" diff --git a/ext/MLDataDevicesZygoteExt.jl b/ext/MLDataDevicesZygoteExt.jl index 3678dbb..4796ef2 100644 --- a/ext/MLDataDevicesZygoteExt.jl +++ b/ext/MLDataDevicesZygoteExt.jl @@ -7,7 +7,7 @@ using Zygote: OneElement Adapt.adapt_storage(::CPUDevice, x::OneElement) = x for Dev in GPU_DEVICES - # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray), ... + # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) @eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x)) end