Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Implement missing LinearAlgebra wrappers and add support for uplo parameter" #70

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 35 additions & 45 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,24 @@ end
$(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum)))
end

adapt_structure(to, A::Adjoint) = adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::Transpose) = transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LowerTriangular) = LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitLowerTriangular) = UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UpperTriangular) = UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::UnitUpperTriangular) = UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::Diagonal) = Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::Bidiagonal) = Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo))
adapt_structure(to, A::Tridiagonal) = Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::SymTridiagonal) = SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev))
adapt_structure(to, A::Symmetric) = Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::Hermitian) = Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo))
adapt_structure(to, A::UpperHessenberg) = UpperHessenberg(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Transpose) =
LinearAlgebra.transpose(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.LowerTriangular) =
LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) =
LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UpperTriangular) =
LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) =
LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Diagonal) =
LinearAlgebra.Diagonal(adapt(to, Base.parent(A)))
adapt_structure(to, A::LinearAlgebra.Tridiagonal) =
LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du))
adapt_structure(to, A::LinearAlgebra.Symmetric) =
LinearAlgebra.Symmetric(adapt(to, Base.parent(A)))


# we generally don't support multiple layers of wrappers, but some occur often
Expand Down Expand Up @@ -88,19 +93,15 @@ const WrappedArray{T,N,Src,Dst} = Union{
#Base.ReshapedArray{T,N,<:Src},
#Base.ReinterpretArray{T,N,<:Any,<:Src},

Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d
Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst
LowerTriangular{T,<:Dst},
UnitLowerTriangular{T,<:Dst},
UpperTriangular{T,<:Dst},
UnitUpperTriangular{T,<:Dst},
Diagonal{T,<:Src},
Bidiagonal{T,<:Src},
Tridiagonal{T,<:Src},
SymTridiagonal{T,<:Src},
Symmetric{T,<:Dst},
Hermitian{T,<:Dst},
UpperHessenberg{T,<:Dst},
LinearAlgebra.Adjoint{T,<:Dst},
LinearAlgebra.Transpose{T,<:Dst},
LinearAlgebra.LowerTriangular{T,<:Dst},
LinearAlgebra.UnitLowerTriangular{T,<:Dst},
LinearAlgebra.UpperTriangular{T,<:Dst},
LinearAlgebra.UnitUpperTriangular{T,<:Dst},
LinearAlgebra.Diagonal{T,<:Dst},
LinearAlgebra.Tridiagonal{T,<:Dst},
LinearAlgebra.Symmetric{T,<:Dst},

WrappedReinterpretArray{T,N,<:Src},
WrappedReshapedArray{T,N,<:Src},
Expand All @@ -120,31 +121,20 @@ const WrappedArray{T,N,Src,Dst} = Union{

# accessors for extracting information about the wrapper type
ndims(::Type{<:Base.LogicalIndex}) = 1
ndims(::Type{<:Adjoint}) = 2
ndims(::Type{<:Transpose}) = 2
ndims(::Type{<:LowerTriangular}) = 2
ndims(::Type{<:UnitLowerTriangular}) = 2
ndims(::Type{<:UpperTriangular}) = 2
ndims(::Type{<:UnitUpperTriangular}) = 2
ndims(::Type{<:Diagonal}) = 2
ndims(::Type{<:Bidiagonal}) = 2
ndims(::Type{<:Tridiagonal}) = 2
ndims(::Type{<:SymTridiagonal}) = 2
ndims(::Type{<:Symmetric}) = 2
ndims(::Type{<:Hermitian}) = 2
ndims(::Type{<:UpperHessenberg}) = 2
ndims(::Type{<:LinearAlgebra.Adjoint}) = 2
ndims(::Type{<:LinearAlgebra.Transpose}) = 2
ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2
ndims(::Type{<:LinearAlgebra.Diagonal}) = 2
ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2
ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N

eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar

for T in [:(Base.LogicalIndex{<:Any,<:Src}),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}),
:(Adjoint{<:Any,<:Src}),
:(Transpose{<:Any,<:Src}),
:(Diagonal{<:Any,<:Src}),
:(Bidiagonal{<:Any,<:Src}),
:(Tridiagonal{<:Any,<:Src}),
:(SymTridiagonal{<:Any,<:Src}),
:(WrappedReinterpretArray{<:Any,<:Any,<:Src}),
:(WrappedReshapedArray{<:Any,<:Any,<:Src}),
:(WrappedSubArray{<:Any,<:Any,<:Src})]
Expand Down
77 changes: 34 additions & 43 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ macro test_adapt(to, src_expr, dst_expr, typ=nothing)
end
end

AnyCustomArray{T,N} = Union{CustomArray{T,N},WrappedArray{T,N,CustomArray,CustomArray{T,N}}}
AnyCustomVector{T} = AnyCustomArray{T,1}
AnyCustomMatrix{T} = AnyCustomArray{T,2}
AnyCustomArray{T,N} = Union{CustomArray,WrappedArray{T,N,CustomArray,CustomArray{T,N}}}


# basic adaption
Expand Down Expand Up @@ -130,82 +128,75 @@ end

@testset "array wrappers" begin

@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomMatrix
@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomArray
inds = CustomArray{Int,1}([1,2])
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomMatrix
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomArray

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomMatrix
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomArray

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomMatrix
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomArray

@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomVector
@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomArray

@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomMatrix
@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomMatrix
@test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray
end


## doubly-wrapped

@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomMatrix
@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomArray

@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomMatrix
@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomArray
@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomArray

@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomArray
@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray
@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray

@static if isdefined(Base, :NonReshapedReinterpretArray)
@test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomMatrix
@test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomMatrix
@test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomArray
@test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomArray
end


using LinearAlgebra

@test_adapt CustomArray mat.arr' mat' AnyCustomMatrix
@test_adapt CustomArray mat.arr' mat' AnyCustomArray

@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomMatrix
@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomArray

@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomMatrix
@test_adapt CustomArray Symmetric(mat.arr, :U) Symmetric(mat, :U) AnyCustomMatrix
@test_adapt CustomArray Symmetric(mat.arr, :L) Symmetric(mat, :L) AnyCustomMatrix
@test_adapt CustomArray Hermitian(mat.arr, :U) Hermitian(mat, :U) AnyCustomMatrix
@test_adapt CustomArray Hermitian(mat.arr, :L) Hermitian(mat, :L) AnyCustomMatrix
@test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomMatrix
@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomArray
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomArray
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray
@test_adapt CustomArray Symmetric(mat.arr) Symmetric(mat) AnyCustomArray

@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomMatrix
@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray

dl = CustomArray{Float64,1}(rand(2))
du = CustomArray{Float64,1}(rand(2))
d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Bidiagonal(d.arr, du.arr, :U) Bidiagonal(d, du, :U) AnyCustomMatrix
@test_adapt CustomArray Bidiagonal(d.arr, dl.arr, :L) Bidiagonal(d, dl, :L) AnyCustomMatrix
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomMatrix
@test_adapt CustomArray SymTridiagonal(d.arr, du.arr) SymTridiagonal(d, du) AnyCustomMatrix
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray

end


@testset "type information" begin
@test Adapt.ndims(Transpose{Float64,Array{Float64,1}}) == 2
@test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == 3
@test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2
@test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3

@test Adapt.parent(Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == Array
@test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
end


Expand Down