diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index edf3b67..e9ef6c0 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -25,4 +25,7 @@ export get_device, get_device_type export DeviceIterator +### uncomment below when min supported julia version is >=1.11 +# public isleaf + end diff --git a/src/public.jl b/src/public.jl index 178c6f9..1911415 100644 --- a/src/public.jl +++ b/src/public.jl @@ -337,7 +337,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) + isleaf(x) && return Adapt.adapt(D, x) return Functors.fmap(D, x) end end @@ -370,3 +370,6 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end + + +isleaf(x) = Functors.isleaf(x)