Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
isleaf
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 18, 2024
1 parent 17bc9aa commit 44a1dea
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ export get_device, get_device_type

export DeviceIterator

### uncomment below when min supported julia version is >=1.11
# public isleaf

end
5 changes: 4 additions & 1 deletion src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
end
(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
function (D::$(ldev))(x)
Functors.isleaf(x) && return Adapt.adapt(D, x)
isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x)
end
end
Expand Down Expand Up @@ -380,3 +380,6 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end


isleaf(x) = Functors.isleaf(x)

0 comments on commit 44a1dea

Please sign in to comment.