From f9eda789e88ad2a4e6c542711e374b5af0925d39 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 13 Dec 2023 13:34:08 +0100 Subject: [PATCH] Remove ndims/eltype, and simplify parent type queries. Instead of returning the raw typename, provide the full type so that callers can use type variables, if required. Also provide a version that fully unpeels the array wrapper. --- Project.toml | 2 +- src/wrappers.jl | 68 +++++++++++++++++++++++++----------------------- test/runtests.jl | 13 +++++---- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/Project.toml b/Project.toml index 64cf2a0..4908a68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Adapt" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.7.2" +version = "4.0.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/wrappers.jl b/src/wrappers.jl index ca29c03..77dbc63 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -5,26 +5,26 @@ using LinearAlgebra permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm -export WrappedArray +export WrappedArray, parent_type, unwrap_type adapt_structure(to, A::SubArray) = - SubArray(adapt(to, Base.parent(A)), adapt(to, parentindices(A))) + SubArray(adapt(to, parent(A)), adapt(to, parentindices(A))) function adapt_structure(to, A::PermutedDimsArray) perm = permutation(A) iperm = invperm(perm) - A′ = adapt(to, Base.parent(A)) - PermutedDimsArray{Base.eltype(A′),Base.ndims(A′),perm,iperm,typeof(A′)}(A′) + A′ = adapt(to, parent(A)) + PermutedDimsArray{eltype(A′),ndims(A′),perm,iperm,typeof(A′)}(A′) end adapt_structure(to, A::Base.ReshapedArray) = - Base.reshape(adapt(to, Base.parent(A)), size(A)) + Base.reshape(adapt(to, parent(A)), size(A)) @static if isdefined(Base, :NonReshapedReinterpretArray) adapt_structure(to, A::Base.NonReshapedReinterpretArray) = - Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) + Base.reinterpret(eltype(A), adapt(to, parent(A))) adapt_structure(to, A::Base.ReshapedReinterpretArray) = - Base.reinterpret(reshape, Base.eltype(A), adapt(to, Base.parent(A))) + Base.reinterpret(reshape, eltype(A), adapt(to, parent(A))) else adapt_structure(to, A::Base.ReinterpretArray) = - Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) + Base.reinterpret(eltype(A), adapt(to, parent(A))) end @eval function adapt_structure(to, A::Base.LogicalIndex{T}) where T # prevent re-calculating the count of booleans during LogicalIndex construction @@ -33,23 +33,23 @@ end end adapt_structure(to, A::LinearAlgebra.Adjoint) = - LinearAlgebra.adjoint(adapt(to, Base.parent(A))) + LinearAlgebra.adjoint(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.Transpose) = - LinearAlgebra.transpose(adapt(to, Base.parent(A))) + LinearAlgebra.transpose(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.LowerTriangular) = - LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A))) + LinearAlgebra.LowerTriangular(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) = - LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A))) + LinearAlgebra.UnitLowerTriangular(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.UpperTriangular) = - LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A))) + LinearAlgebra.UpperTriangular(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) = - LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A))) + LinearAlgebra.UnitUpperTriangular(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.Diagonal) = - LinearAlgebra.Diagonal(adapt(to, Base.parent(A))) + LinearAlgebra.Diagonal(adapt(to, parent(A))) adapt_structure(to, A::LinearAlgebra.Tridiagonal) = LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) adapt_structure(to, A::LinearAlgebra.Symmetric) = - LinearAlgebra.Symmetric(adapt(to, Base.parent(A))) + LinearAlgebra.Symmetric(adapt(to, parent(A))) # we generally don't support multiple layers of wrappers, but some occur often @@ -119,20 +119,14 @@ const WrappedArray{T,N,Src,Dst} = Union{ # `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}` for dispatch. # https://github.com/JuliaLang/julia/pull/31563 -# accessors for extracting information about the wrapper type -ndims(::Type{<:Base.LogicalIndex}) = 1 -ndims(::Type{<:LinearAlgebra.Adjoint}) = 2 -ndims(::Type{<:LinearAlgebra.Transpose}) = 2 -ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Symmetric}) = 2 -ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N - -eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar + +""" + parent_type(W::Type{<:WrappedArray}) + +Return the parent type of a wrapped array type. This is the type of the array that is +wrapped by the wrapper, e.g. `parent_type(SubArray{Int, 1, Matrix{Int}}) == Matrix{Int}`. +""" +parent_type(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst for T in [:(Base.LogicalIndex{<:Any,<:Src}), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}), @@ -140,7 +134,17 @@ for T in [:(Base.LogicalIndex{<:Any,<:Src}), :(WrappedReshapedArray{<:Any,<:Any,<:Src}), :(WrappedSubArray{<:Any,<:Any,<:Src})] @eval begin - parent(::Type{<:$T}) where {Src} = Src.name.wrapper + parent_type(::Type{<:$T}) where {Src} = Src end end -parent(::Type{<:WrappedArray{<:Any,<:Any,<:Any,Dst}}) where {Dst} = Dst.name.wrapper + + +""" + unwrap_type(W::Type{<:WrappedArray}) + +Fully unwrap a wrapped array type, i.e., returns the `parent_type` until the result is no +longer a wrapped array type. This is useful for accessing properties of the innermost +array type. +""" +unwrap_type(W::Type{<:WrappedArray}) = unwrap_type(parent_type(W)) +unwrap_type(W::Type{<:AbstractArray}) = W diff --git a/test/runtests.jl b/test/runtests.jl index 65a5b24..55e913c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,6 @@ struct CustomArray{T,N} <: AbstractArray{T,N} arr::Array{T,N} end -CustomArray(x::Array{T,N}) where {T,N} = CustomArray{T,N}(x) Adapt.adapt_storage(::Type{<:CustomArray}, xs::Array) = CustomArray(xs) Base.size(x::CustomArray, y...) = size(x.arr, y...) @@ -58,7 +57,6 @@ AnyCustomArray{T,N} = Union{CustomArray,WrappedArray{T,N,CustomArray,CustomArray struct Wrapper{T} arr::T end -Wrapper(x::T) where T = Wrapper{T}(x) Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr)) @test_adapt CustomArray Wrapper(mat.arr) Wrapper(mat) @@ -192,12 +190,13 @@ end @testset "type information" begin - @test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2 - @test Adapt.ndims(LinearAlgebra.Symmetric{Float64,Matrix{Float64}}) == 2 - @test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3 + # single wrapping + @test parent_type(Transpose{Int,Array{Int,1}}) == Array{Int,1} + @test parent_type(Transpose{Int,Transpose{Int,Array{Int,1}}}) == Transpose{Int,Array{Int,1}} - @test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array - @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array + # double wrapping + @test unwrap_type(Transpose{Int,Array{Int,1}}) == Array{Int,1} + @test unwrap_type(Transpose{Int,Transpose{Int,Array{Int,1}}}) == Array{Int,1} end