From 44a1deadf2f422ed63e72e62cc18566cc783edfd Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 16 Oct 2024 07:47:06 +0200 Subject: [PATCH 1/9] isleaf --- src/MLDataDevices.jl | 3 +++ src/public.jl | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index edf3b67..e9ef6c0 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -25,4 +25,7 @@ export get_device, get_device_type export DeviceIterator +### uncomment below when min supported julia version is >=1.11 +# public isleaf + end diff --git a/src/public.jl b/src/public.jl index 1dc1646..4111dba 100644 --- a/src/public.jl +++ b/src/public.jl @@ -347,7 +347,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) - Functors.isleaf(x) && return Adapt.adapt(D, x) + isleaf(x) && return Adapt.adapt(D, x) return Functors.fmap(D, x) end end @@ -380,3 +380,6 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end + + +isleaf(x) = Functors.isleaf(x) From 7202dbcc187dc276e8193731c965bf6d23733385 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 18 Oct 2024 20:22:15 +0200 Subject: [PATCH 2/9] exclude --- src/public.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/public.jl b/src/public.jl index 4111dba..379aea9 100644 --- a/src/public.jl +++ b/src/public.jl @@ -348,7 +348,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x) + return Functors.fmap(D, x, exclude=isleaf) end end end From 39a4dea304ffaffc602766f05e7442ec35dac90b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 18 Oct 2024 20:46:52 +0200 Subject: [PATCH 3/9] add tests and docs --- Project.toml | 2 ++ src/MLDataDevices.jl | 4 ++-- src/public.jl | 14 ++++++++++++++ test/misc_tests.jl | 19 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 41f3134..fa2489c 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,6 +49,7 @@ AMDGPU = "0.9.6, 1" Adapt = "4" CUDA = "5.2" ChainRulesCore = "1.23" +Compat = "4.16.0" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index e9ef6c0..108d8bf 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -4,6 +4,7 @@ using Adapt: Adapt using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using Compat: @compat abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end @@ -25,7 +26,6 @@ export get_device, get_device_type export DeviceIterator -### uncomment below when min supported julia version is >=1.11 -# public isleaf +@compat(public, (isleaf,)) end diff --git a/src/public.jl b/src/public.jl index 379aea9..7a4b5d5 100644 --- a/src/public.jl +++ b/src/public.jl @@ -381,5 +381,19 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end +""" + isleaf(x) -> Bool + +Returns `true` if `x` is a leaf node in the data structure. + +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior +when an object with nested structure containing the type is transferred to a device. +`Adapt.adapt_structure(::AbstractDevice, x::T)` or +`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during +data movement if `isleaf(x::T) == true`. + +If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. +""" isleaf(x) = Functors.isleaf(x) diff --git a/test/misc_tests.jl b/test/misc_tests.jl index f6ea454..a84b8be 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -157,3 +157,22 @@ end @test get_device(x) isa MLDataDevices.UnknownDevice @test get_device_type(x) <: MLDataDevices.UnknownDevice end + +@testset "isleaf" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x + end + Functors.@functor Tleaf + + MLDataDevices.isleaf(::Tleaf) = true + + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + @test cpu(t).x == 2 .* ones(2) +end From 24a6c97b6cc1904832fc564b82e9efcf6c7580d1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 18 Oct 2024 20:51:37 +0200 Subject: [PATCH 4/9] more tests --- test/misc_tests.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/misc_tests.jl b/test/misc_tests.jl index a84b8be..614cc20 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -167,12 +167,13 @@ end x end Functors.@functor Tleaf - MLDataDevices.isleaf(::Tleaf) = true - Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) cpu = cpu_device() t = Tleaf(ones(2)) - @test cpu(t).x == 2 .* ones(2) + t = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) end From 5c5976731db5a542e8dc4005f54b54d65fac9c1e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 19 Oct 2024 02:14:21 +0200 Subject: [PATCH 5/9] import functors --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 7fecc81..f098f6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ import Pkg using SafeTestsets, Test +using Functors: Functors const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) From ced209b257e3f30671ad90011c6132ce0e31c262 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 19 Oct 2024 02:27:04 +0200 Subject: [PATCH 6/9] fix test --- test/misc_tests.jl | 3 ++- test/runtests.jl | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 614cc20..3da82ff 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -3,6 +3,7 @@ using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using Functors: Functors @testset "Issues Patches" begin @testset "#10 patch" begin @@ -172,7 +173,7 @@ end cpu = cpu_device() t = Tleaf(ones(2)) - t = cpu(t) + y = cpu(t) @test y.x == 2 .* ones(2) y = cpu([(t,)]) @test y[1][1].x == 2 .* ones(2) diff --git a/test/runtests.jl b/test/runtests.jl index f098f6f..bdca4aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ import Pkg using SafeTestsets, Test -using Functors: Functors const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) @@ -24,7 +23,8 @@ end all_files = ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] file_names = BACKEND_GROUP == "all" ? all_files : - (BACKEND_GROUP == "cpu" ? [] : [BACKEND_GROUP * "_tests.jl"]) + BACKEND_GROUP ∈ ("cpu", "none") ? [] : + [BACKEND_GROUP * "_tests.jl"] @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) From ffa0fb8117c6cdf77677d3c336e6e55e3731805b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 19 Oct 2024 16:16:47 -0400 Subject: [PATCH 7/9] chore: reduce min compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa2489c..293c122 100644 --- a/Project.toml +++ b/Project.toml @@ -49,7 +49,7 @@ AMDGPU = "0.9.6, 1" Adapt = "4" CUDA = "5.2" ChainRulesCore = "1.23" -Compat = "4.16.0" +Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" From 8a7bedc32ed1ba823046c80debd113094224944e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 19 Oct 2024 16:20:23 -0400 Subject: [PATCH 8/9] chore: run formatter --- src/public.jl | 8 ++++---- test/misc_tests.jl | 4 ++-- test/runtests.jl | 3 +-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/public.jl b/src/public.jl index 7a4b5d5..281980e 100644 --- a/src/public.jl +++ b/src/public.jl @@ -348,7 +348,7 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) isleaf(x) && return Adapt.adapt(D, x) - return Functors.fmap(D, x, exclude=isleaf) + return Functors.fmap(D, x; exclude=isleaf) end end end @@ -384,10 +384,10 @@ end """ isleaf(x) -> Bool -Returns `true` if `x` is a leaf node in the data structure. +Returns `true` if `x` is a leaf node in the data structure. -Defining `MLDataDevices.isleaf(x::T) = true` for custom types -can be used to customize the behavior the data movement behavior +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior when an object with nested structure containing the type is transferred to a device. `Adapt.adapt_structure(::AbstractDevice, x::T)` or diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 3da82ff..942c2ff 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -163,9 +163,9 @@ end # Functors.isleaf fallback @test MLDataDevices.isleaf(rand(2)) @test !MLDataDevices.isleaf((rand(2),)) - + struct Tleaf - x + x::Any end Functors.@functor Tleaf MLDataDevices.isleaf(::Tleaf) = true diff --git a/test/runtests.jl b/test/runtests.jl index bdca4aa..f3f2596 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,8 +23,7 @@ end all_files = ["cuda_tests.jl", "amdgpu_tests.jl", "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] file_names = BACKEND_GROUP == "all" ? all_files : - BACKEND_GROUP ∈ ("cpu", "none") ? [] : - [BACKEND_GROUP * "_tests.jl"] + BACKEND_GROUP ∈ ("cpu", "none") ? [] : [BACKEND_GROUP * "_tests.jl"] @testset "$(file_name)" for file_name in file_names run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) From 015a326f517f4eeb3ad8e7c45dbd0b5ee5aac3c8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 19 Oct 2024 16:20:38 -0400 Subject: [PATCH 9/9] chore: bump version for release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 293c122..7f34fa4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.3.0" +version = "1.4.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"