Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Define isleaf #84

Merged
merged 9 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.0"
version = "1.4.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1"
Adapt = "4"
CUDA = "5.2"
ChainRulesCore = "1.23"
Compat = "4.15"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10, 11"
Expand Down
3 changes: 3 additions & 0 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Adapt: Adapt
using Functors: Functors, fleaves
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random
using Compat: @compat

abstract type AbstractDevice <: Function end
abstract type AbstractCPUDevice <: AbstractDevice end
Expand All @@ -25,4 +26,6 @@ export get_device, get_device_type

export DeviceIterator

@compat(public, (isleaf,))

end
21 changes: 19 additions & 2 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
end
(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
function (D::$(ldev))(x)
Functors.isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x)
isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x; exclude=isleaf)
end
end
end
Expand Down Expand Up @@ -380,3 +380,20 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

"""
isleaf(x) -> Bool

Returns `true` if `x` is a leaf node in the data structure.

Defining `MLDataDevices.isleaf(x::T) = true` for custom types
can be used to customize the behavior the data movement behavior
when an object with nested structure containing the type is transferred to a device.

`Adapt.adapt_structure(::AbstractDevice, x::T)` or
`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during
data movement if `isleaf(x::T) == true`.

If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
"""
isleaf(x) = Functors.isleaf(x)
21 changes: 21 additions & 0 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ArrayInterface: parameterless_type
using ChainRulesTestUtils: test_rrule
using ReverseDiff, Tracker, ForwardDiff
using SparseArrays, FillArrays, Zygote, RecursiveArrayTools
using Functors: Functors

@testset "Issues Patches" begin
@testset "#10 patch" begin
Expand Down Expand Up @@ -157,3 +158,23 @@ end
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end

@testset "isleaf" begin
# Functors.isleaf fallback
@test MLDataDevices.isleaf(rand(2))
@test !MLDataDevices.isleaf((rand(2),))

struct Tleaf
x::Any
end
Functors.@functor Tleaf
MLDataDevices.isleaf(::Tleaf) = true
Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x))

cpu = cpu_device()
t = Tleaf(ones(2))
y = cpu(t)
@test y.x == 2 .* ones(2)
y = cpu([(t,)])
@test y[1][1].x == 2 .* ones(2)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end
all_files = ["cuda_tests.jl", "amdgpu_tests.jl",
"metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"]
file_names = BACKEND_GROUP == "all" ? all_files :
(BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"])
BACKEND_GROUP ∈ ("cpu", "none") ? [] : [BACKEND_GROUP * "_tests.jl"]
@testset "$(file_name)" for file_name in file_names
run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path))
--startup-file=no --code-coverage=user $(@__DIR__)/$file_name`)
Expand Down
Loading