Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
isleaf
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 16, 2024
1 parent 71ed455 commit 9ff2b50
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ export get_device, get_device_type

export DeviceIterator

### uncomment below when min supported julia version is >=1.11
# public isleaf

end
5 changes: 4 additions & 1 deletion src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
end
(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
function (D::$(ldev))(x)
Functors.isleaf(x) && return Adapt.adapt(D, x)
isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x)
end
end
Expand Down Expand Up @@ -370,3 +370,6 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end


isleaf(x) = Functors.isleaf(x)

0 comments on commit 9ff2b50

Please sign in to comment.