diff --git a/Project.toml b/Project.toml index 41f3134..fa2489c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1" Adapt = "4" CUDA = "5.2" ChainRulesCore = "1.23" +Compat = "4.16.0" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index e9ef6c0..108d8bf 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -4,6 +4,7 @@ using Adapt: Adapt using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using Compat: @compat abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end @@ -25,7 +26,6 @@ export get_device, get_device_type export DeviceIterator -### uncomment below when min supported julia version is >=1.11 -# public isleaf +@compat(public, (isleaf,)) end diff --git a/src/public.jl b/src/public.jl index 379aea9..7a4b5d5 100644 --- a/src/public.jl +++ b/src/public.jl @@ -381,5 +381,19 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end +""" + isleaf(x) -> Bool + +Returns `true` if `x` is a leaf node in the data structure. + +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior +when an object with nested structure containing the type is transferred to a device. +`Adapt.adapt_structure(::AbstractDevice, x::T)` or +`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during +data movement if `isleaf(x::T) == true`. + +If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. +""" isleaf(x) = Functors.isleaf(x) diff --git a/test/misc_tests.jl b/test/misc_tests.jl index f6ea454..a84b8be 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -157,3 +157,22 @@ end @test get_device(x) isa MLDataDevices.UnknownDevice @test get_device_type(x) <: MLDataDevices.UnknownDevice end + +@testset "isleaf" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x + end + Functors.@functor Tleaf + + MLDataDevices.isleaf(::Tleaf) = true + + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + @test cpu(t).x == 2 .* ones(2) +end