From 9ff2b501b93aa63f79825bfe7fd41fa94c1db951 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 16 Oct 2024 07:47:06 +0200 Subject: [PATCH] isleaf --- src/MLDataDevices.jl | 3 +++ src/public.jl | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index edf3b67..e9ef6c0 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -25,4 +25,7 @@ export get_device, get_device_type export DeviceIterator +### uncomment below when min supported julia version is >=1.11 +# public isleaf + end diff --git a/src/public.jl b/src/public.jl index 178c6f9..1911415 100644 --- a/src/public.jl +++ b/src/public.jl @@ -337,7 +337,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) + isleaf(x) && return Adapt.adapt(D, x) return Functors.fmap(D, x) end end @@ -370,3 +370,6 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end + + +isleaf(x) = Functors.isleaf(x)