From e9a2ed795dc1b4c357ebaf5c5c8b1753896dd235 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 22 Oct 2024 14:54:51 +0200 Subject: [PATCH] fix: handle bitstypes and wrapped arrays in isleaf (#88) * bitstype and wrapped arrays * fixes * fix import * bound * cleanup * chore: fix min version of LinearAlgebra * chore: run formatter --------- Co-authored-by: Avik Pal Co-authored-by: Avik Pal --- Project.toml | 4 ++- src/MLDataDevices.jl | 1 + src/public.jl | 3 +++ test/misc_tests.jl | 59 +++++++++++++++++++++++++++++++++----------- 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 7f34fa4..c85cb0d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.4.0" +version = "1.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -53,6 +54,7 @@ Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" +LinearAlgebra = "1.10" MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index 108d8bf..c837887 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -5,6 +5,7 @@ using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random using Compat: @compat +using LinearAlgebra: Transpose, Adjoint abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end diff --git a/src/public.jl b/src/public.jl index 281980e..104a424 100644 --- a/src/public.jl +++ b/src/public.jl @@ -397,3 +397,6 @@ data movement if `isleaf(x::T) == true`. If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. """ isleaf(x) = Functors.isleaf(x) + +isleaf(::AbstractArray{T}) where {T} = isbitstype(T) +isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 942c2ff..9bec386 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -160,21 +160,50 @@ end end @testset "isleaf" begin - # Functors.isleaf fallback - @test MLDataDevices.isleaf(rand(2)) - @test !MLDataDevices.isleaf((rand(2),)) + @testset "basics" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x::Any + end + Functors.@functor Tleaf + MLDataDevices.isleaf(::Tleaf) = true + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + y = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) + end + + @testset "shared parameters" begin + # from + x = rand(1) + m = (; a=x, b=x') + count = Ref(0) + mcopy = Functors.fmap(m; exclude=MLDataDevices.isleaf) do x + count[] += 1 + return copy(x) + end + @test count[] == 1 + @test mcopy.a === mcopy.b' + end - struct Tleaf - x::Any + @testset "bitstypes and wrapped types" begin + struct BitsType + x::Int32 + y::Float64 + end + + for x in [1.0, 'a', BitsType(1, 2.0)] + @test MLDataDevices.isleaf([x]) + @test !MLDataDevices.isleaf([x]') + @test !MLDataDevices.isleaf(transpose([x])) + @test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2))) + end end - Functors.@functor Tleaf - MLDataDevices.isleaf(::Tleaf) = true - Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) - - cpu = cpu_device() - t = Tleaf(ones(2)) - y = cpu(t) - @test y.x == 2 .* ones(2) - y = cpu([(t,)]) - @test y[1][1].x == 2 .* ones(2) end