diff --git a/src/public.jl b/src/public.jl index 4111dba..379aea9 100644 --- a/src/public.jl +++ b/src/public.jl @@ -348,7 +348,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x) + return Functors.fmap(D, x, exclude=isleaf) end end end