diff --git a/Project.toml b/Project.toml index 41f3134..7f34fa4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.4.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1" Adapt = "4" CUDA = "5.2" ChainRulesCore = "1.23" +Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index edf3b67..108d8bf 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -4,6 +4,7 @@ using Adapt: Adapt using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using Compat: @compat abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end @@ -25,4 +26,6 @@ export get_device, get_device_type export DeviceIterator +@compat(public, (isleaf,)) + end diff --git a/src/public.jl b/src/public.jl index 1dc1646..281980e 100644 --- a/src/public.jl +++ b/src/public.jl @@ -347,8 +347,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x) + isleaf(x) && return Adapt.adapt(D, x) + return Functors.fmap(D, x; exclude=isleaf) end end end @@ -380,3 +380,20 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end + +""" + isleaf(x) -> Bool + +Returns `true` if `x` is a leaf node in the data structure. + +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior +when an object with nested structure containing the type is transferred to a device. + +`Adapt.adapt_structure(::AbstractDevice, x::T)` or +`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during +data movement if `isleaf(x::T) == true`. + +If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. +""" +isleaf(x) = Functors.isleaf(x) diff --git a/test/misc_tests.jl b/test/misc_tests.jl index f6ea454..942c2ff 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -3,6 +3,7 @@ using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using Functors: Functors @testset "Issues Patches" begin @testset "#10 patch" begin @@ -157,3 +158,23 @@ end @test get_device(x) isa MLDataDevices.UnknownDevice @test get_device_type(x) <: MLDataDevices.UnknownDevice end + +@testset "isleaf" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x::Any + end + Functors.@functor Tleaf + MLDataDevices.isleaf(::Tleaf) = true + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + y = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 7fecc81..f3f2596 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,7 @@ end all_files = ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] file_names = BACKEND_GROUP == "all" ? all_files : - (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) + BACKEND_GROUP ∈ ("cpu", "none") ? [] : [BACKEND_GROUP * "_tests.jl"] @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`)