Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement missing LinearAlgebra wrappers and add support for uplo parameter #51

Merged
merged 8 commits into from
Oct 23, 2023
80 changes: 45 additions & 35 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,19 @@ end
$(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum)))
end

adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Transpose) =
LinearAlgebra.transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Diagonal) =
LinearAlgebra.Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::LinearAlgebra.Symmetric) =
LinearAlgebra.Symmetric(adapt(to, Base.parent(A)))
adapt_structure(to, A::Adjoint) = adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::Transpose) = transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LowerTriangular) = LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitLowerTriangular) = UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UpperTriangular) = UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitUpperTriangular) = UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::Diagonal) = Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::Bidiagonal) = Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo))
adapt_structure(to, A::Tridiagonal) = Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::SymTridiagonal) = SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev))
adapt_structure(to, A::Symmetric) = Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::Hermitian) = Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::UpperHessenberg) = UpperHessenberg(adapt(to, Base.parent(A)))


# we generally don't support multiple layers of wrappers, but some occur often
Expand Down Expand Up @@ -93,15 +88,19 @@ const WrappedArray{T,N,Src,Dst} = Union{
#Base.ReshapedArray{T,N,<:Src},
#Base.ReinterpretArray{T,N,<:Any,<:Src},

LinearAlgebra.Adjoint{T,<:Dst},
LinearAlgebra.Transpose{T,<:Dst},
LinearAlgebra.LowerTriangular{T,<:Dst},
LinearAlgebra.UnitLowerTriangular{T,<:Dst},
LinearAlgebra.UpperTriangular{T,<:Dst},
LinearAlgebra.UnitUpperTriangular{T,<:Dst},
LinearAlgebra.Diagonal{T,<:Dst},
LinearAlgebra.Tridiagonal{T,<:Dst},
LinearAlgebra.Symmetric{T,<:Dst},
Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d
Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst
LowerTriangular{T,<:Dst},
UnitLowerTriangular{T,<:Dst},
UpperTriangular{T,<:Dst},
UnitUpperTriangular{T,<:Dst},
Diagonal{T,<:Src},
Bidiagonal{T,<:Src},
Tridiagonal{T,<:Src},
SymTridiagonal{T,<:Src},
Symmetric{T,<:Dst},
Hermitian{T,<:Dst},
UpperHessenberg{T,<:Dst},

WrappedReinterpretArray{T,N,<:Src},
WrappedReshapedArray{T,N,<:Src},
Expand All @@ -121,20 +120,31 @@ const WrappedArray{T,N,Src,Dst} = Union{

# accessors for extracting information about the wrapper type
ndims(::Type{<:Base.LogicalIndex}) = 1
ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
ndims(::Type{<:LinearAlgebra.Transpose}) = 2
ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
ndims(::Type{<:Adjoint}) = 2
ndims(::Type{<:Transpose}) = 2
ndims(::Type{<:LowerTriangular}) = 2
ndims(::Type{<:UnitLowerTriangular}) = 2
ndims(::Type{<:UpperTriangular}) = 2
ndims(::Type{<:UnitUpperTriangular}) = 2
ndims(::Type{<:Diagonal}) = 2
ndims(::Type{<:Bidiagonal}) = 2
ndims(::Type{<:Tridiagonal}) = 2
ndims(::Type{<:SymTridiagonal}) = 2
ndims(::Type{<:Symmetric}) = 2
ndims(::Type{<:Hermitian}) = 2
ndims(::Type{<:UpperHessenberg}) = 2
ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar

for T in [:(Base.LogicalIndex{<:Any,<:Src}),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}),
:(Adjoint{<:Any,<:Src}),
:(Transpose{<:Any,<:Src}),
:(Diagonal{<:Any,<:Src}),
:(Bidiagonal{<:Any,<:Src}),
:(Tridiagonal{<:Any,<:Src}),
:(SymTridiagonal{<:Any,<:Src}),
:(WrappedReinterpretArray{<:Any,<:Any,<:Src}),
:(WrappedReshapedArray{<:Any,<:Any,<:Src}),
:(WrappedSubArray{<:Any,<:Any,<:Src})]
Expand Down
77 changes: 43 additions & 34 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ macro test_adapt(to, src_expr, dst_expr, typ=nothing)
end
end

AnyCustomArray{T,N} = Union{CustomArray,WrappedArray{T,N,CustomArray,CustomArray{T,N}}}
AnyCustomArray{T,N} = Union{CustomArray{T,N},WrappedArray{T,N,CustomArray,CustomArray{T,N}}}
AnyCustomVector{T} = AnyCustomArray{T,1}
AnyCustomMatrix{T} = AnyCustomArray{T,2}


# basic adaption
Expand Down Expand Up @@ -128,75 +130,82 @@ end

@testset "array wrappers" begin

@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomArray
@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomMatrix
inds = CustomArray{Int,1}([1,2])
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomArray
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomMatrix

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomArray
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomMatrix

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomArray
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomMatrix

@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomArray
@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomVector

@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray
@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomMatrix

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomMatrix
end


## doubly-wrapped

@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomArray
@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomMatrix

@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomMatrix

@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomArray
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomMatrix

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomArray
@test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomMatrix
end


using LinearAlgebra

@test_adapt CustomArray mat.arr' mat' AnyCustomArray
@test_adapt CustomArray mat.arr' mat' AnyCustomMatrix

@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomArray
@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomMatrix

@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomArray
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomArray
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray
@test_adapt CustomArray Symmetric(mat.arr) Symmetric(mat) AnyCustomArray
@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray Symmetric(mat.arr, :U) Symmetric(mat, :U) AnyCustomMatrix
@test_adapt CustomArray Symmetric(mat.arr, :L) Symmetric(mat, :L) AnyCustomMatrix
@test_adapt CustomArray Hermitian(mat.arr, :U) Hermitian(mat, :U) AnyCustomMatrix
@test_adapt CustomArray Hermitian(mat.arr, :L) Hermitian(mat, :L) AnyCustomMatrix
@test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomMatrix

@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray
@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomMatrix

dl = CustomArray{Float64,1}(rand(2))
du = CustomArray{Float64,1}(rand(2))
d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray
@test_adapt CustomArray Bidiagonal(d.arr, du.arr, :U) Bidiagonal(d, du, :U) AnyCustomMatrix
@test_adapt CustomArray Bidiagonal(d.arr, dl.arr, :L) Bidiagonal(d, dl, :L) AnyCustomMatrix
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomMatrix
@test_adapt CustomArray SymTridiagonal(d.arr, du.arr) SymTridiagonal(d, du) AnyCustomMatrix

end


@testset "type information" begin
@test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3
@test Adapt.ndims(Transpose{Float64,Array{Float64,1}}) == 2
@test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == 3

@test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
@test Adapt.parent(Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == Array
end


Expand Down