diff --git a/src/wrappers.jl b/src/wrappers.jl index f2ce2c6..7145ff9 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -32,32 +32,19 @@ end $(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum))) end -adapt_structure(to, A::LinearAlgebra.Adjoint) = - LinearAlgebra.adjoint(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Transpose) = - LinearAlgebra.transpose(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.LowerTriangular) = - LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) = - LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UpperTriangular) = - LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) = - LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Diagonal) = - LinearAlgebra.Diagonal(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Bidiagonal) = - LinearAlgebra.Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.Tridiagonal) = - LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) -adapt_structure(to, A::LinearAlgebra.SymTridiagonal) = - LinearAlgebra.SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) -adapt_structure(to, A::LinearAlgebra.Symmetric) = - LinearAlgebra.Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.Hermitian) = - LinearAlgebra.Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.UpperHessenberg) = - LinearAlgebra.UpperHessenberg(adapt(to, Base.parent(A))) +adapt_structure(to, A::Adjoint) = adjoint(adapt(to, Base.parent(A))) +adapt_structure(to, A::Transpose) = transpose(adapt(to, Base.parent(A))) +adapt_structure(to, A::LowerTriangular) = LowerTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UnitLowerTriangular) = UnitLowerTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UpperTriangular) = UpperTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UnitUpperTriangular) = UnitUpperTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::Diagonal) = Diagonal(adapt(to, Base.parent(A))) +adapt_structure(to, A::Bidiagonal) = Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo)) +adapt_structure(to, A::Tridiagonal) = Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) +adapt_structure(to, A::SymTridiagonal) = SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) +adapt_structure(to, A::Symmetric) = Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo)) +adapt_structure(to, A::Hermitian) = Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo)) +adapt_structure(to, A::UpperHessenberg) = UpperHessenberg(adapt(to, Base.parent(A))) # we generally don't support multiple layers of wrappers, but some occur often @@ -101,19 +88,19 @@ const WrappedArray{T,N,Src,Dst} = Union{ #Base.ReshapedArray{T,N,<:Src}, #Base.ReinterpretArray{T,N,<:Any,<:Src}, - LinearAlgebra.Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d - LinearAlgebra.Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst - LinearAlgebra.LowerTriangular{T,<:Dst}, - LinearAlgebra.UnitLowerTriangular{T,<:Dst}, - LinearAlgebra.UpperTriangular{T,<:Dst}, - LinearAlgebra.UnitUpperTriangular{T,<:Dst}, - LinearAlgebra.Diagonal{T,<:Src}, - LinearAlgebra.Bidiagonal{T,<:Src}, - LinearAlgebra.Tridiagonal{T,<:Src}, - LinearAlgebra.SymTridiagonal{T,<:Src}, - LinearAlgebra.Symmetric{T,<:Dst}, - LinearAlgebra.Hermitian{T,<:Dst}, - LinearAlgebra.UpperHessenberg{T,<:Dst}, + Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d + Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst + LowerTriangular{T,<:Dst}, + UnitLowerTriangular{T,<:Dst}, + UpperTriangular{T,<:Dst}, + UnitUpperTriangular{T,<:Dst}, + Diagonal{T,<:Src}, + Bidiagonal{T,<:Src}, + Tridiagonal{T,<:Src}, + SymTridiagonal{T,<:Src}, + Symmetric{T,<:Dst}, + Hermitian{T,<:Dst}, + UpperHessenberg{T,<:Dst}, WrappedReinterpretArray{T,N,<:Src}, WrappedReshapedArray{T,N,<:Src}, @@ -133,31 +120,31 @@ const WrappedArray{T,N,Src,Dst} = Union{ # accessors for extracting information about the wrapper type ndims(::Type{<:Base.LogicalIndex}) = 1 -ndims(::Type{<:LinearAlgebra.Adjoint}) = 2 -ndims(::Type{<:LinearAlgebra.Transpose}) = 2 -ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Bidiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.SymTridiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Symmetric}) = 2 -ndims(::Type{<:LinearAlgebra.Hermitian}) = 2 -ndims(::Type{<:LinearAlgebra.UpperHessenberg}) = 2 +ndims(::Type{<:Adjoint}) = 2 +ndims(::Type{<:Transpose}) = 2 +ndims(::Type{<:LowerTriangular}) = 2 +ndims(::Type{<:UnitLowerTriangular}) = 2 +ndims(::Type{<:UpperTriangular}) = 2 +ndims(::Type{<:UnitUpperTriangular}) = 2 +ndims(::Type{<:Diagonal}) = 2 +ndims(::Type{<:Bidiagonal}) = 2 +ndims(::Type{<:Tridiagonal}) = 2 +ndims(::Type{<:SymTridiagonal}) = 2 +ndims(::Type{<:Symmetric}) = 2 +ndims(::Type{<:Hermitian}) = 2 +ndims(::Type{<:UpperHessenberg}) = 2 ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar for T in [:(Base.LogicalIndex{<:Any,<:Src}), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}), - :(LinearAlgebra.Adjoint{<:Any,<:Src}), - :(LinearAlgebra.Transpose{<:Any,<:Src}), - :(LinearAlgebra.Diagonal{<:Any,<:Src}), - :(LinearAlgebra.Bidiagonal{<:Any,<:Src}), - :(LinearAlgebra.Tridiagonal{<:Any,<:Src}), - :(LinearAlgebra.SymTridiagonal{<:Any,<:Src}), + :(Adjoint{<:Any,<:Src}), + :(Transpose{<:Any,<:Src}), + :(Diagonal{<:Any,<:Src}), + :(Bidiagonal{<:Any,<:Src}), + :(Tridiagonal{<:Any,<:Src}), + :(SymTridiagonal{<:Any,<:Src}), :(WrappedReinterpretArray{<:Any,<:Any,<:Src}), :(WrappedReshapedArray{<:Any,<:Any,<:Src}), :(WrappedSubArray{<:Any,<:Any,<:Src})] diff --git a/test/runtests.jl b/test/runtests.jl index f0c3cd4..cb6121a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -201,10 +201,10 @@ end @testset "type information" begin - @test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2 + @test Adapt.ndims(Transpose{Float64,Array{Float64,1}}) == 2 @test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3 - @test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array + @test Adapt.parent(Transpose{Float64,Array{Float64,1}}) == Array @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array end