From bcec0e04d33ed13ba02b49b054ce931a8e10e9ec Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 18 Mar 2022 15:46:43 -0700 Subject: [PATCH 1/8] Add tests for missing LinearAlgebra wrappers Hermitian, UpperHessenberg, Bidiagonal, SymTridiagonal --- test/runtests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 11b4a6e..0c30ce8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -180,13 +180,18 @@ using LinearAlgebra @test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray @test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray @test_adapt CustomArray Symmetric(mat.arr) Symmetric(mat) AnyCustomArray +@test_adapt CustomArray Hermitian(mat.arr) Hermitian(mat) AnyCustomArray +@test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomArray @test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray dl = CustomArray{Float64,1}(rand(2)) du = CustomArray{Float64,1}(rand(2)) d = CustomArray{Float64,1}(rand(3)) +@test_adapt CustomArray Bidiagonal(d.arr, du.arr, :U) Bidiagonal(d, du, :U) AnyCustomArray +@test_adapt CustomArray Bidiagonal(d.arr, dl.arr, :L) Bidiagonal(d, dl, :L) AnyCustomArray @test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray +@test_adapt CustomArray SymTridiagonal(d.arr, du.arr) SymTridiagonal(d, du) AnyCustomArray end From 8c6bf9b585e026e2bddf4b97af94c90e56690508 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 18 Mar 2022 15:48:25 -0700 Subject: [PATCH 2/8] Add support for missing LinearAlgebra wrappers Hermitian, UpperHessenberg, Bidiagonal, SymTridiagonal Closes #46 --- src/wrappers.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/wrappers.jl b/src/wrappers.jl index 7bb0e89..3842019 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -46,10 +46,18 @@ adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) = LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A))) adapt_structure(to, A::LinearAlgebra.Diagonal) = LinearAlgebra.Diagonal(adapt(to, Base.parent(A))) +adapt_structure(to, A::LinearAlgebra.Bidiagonal) = + LinearAlgebra.Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo)) adapt_structure(to, A::LinearAlgebra.Tridiagonal) = LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) +adapt_structure(to, A::LinearAlgebra.SymTridiagonal) = + LinearAlgebra.SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) adapt_structure(to, A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(adapt(to, Base.parent(A))) +adapt_structure(to, A::LinearAlgebra.Hermitian) = + LinearAlgebra.Hermitian(adapt(to, Base.parent(A))) +adapt_structure(to, A::LinearAlgebra.UpperHessenberg) = + LinearAlgebra.UpperHessenberg(adapt(to, Base.parent(A))) # we generally don't support multiple layers of wrappers, but some occur often @@ -100,8 +108,12 @@ const WrappedArray{T,N,Src,Dst} = Union{ LinearAlgebra.UpperTriangular{T,<:Dst}, LinearAlgebra.UnitUpperTriangular{T,<:Dst}, LinearAlgebra.Diagonal{T,<:Dst}, + LinearAlgebra.Bidiagonal{T,<:Dst}, LinearAlgebra.Tridiagonal{T,<:Dst}, + LinearAlgebra.SymTridiagonal{T,<:Dst}, LinearAlgebra.Symmetric{T,<:Dst}, + LinearAlgebra.Hermitian{T,<:Dst}, + LinearAlgebra.UpperHessenberg{T,<:Dst}, WrappedReinterpretArray{T,N,<:Src}, WrappedReshapedArray{T,N,<:Src}, @@ -128,7 +140,12 @@ ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 +ndims(::Type{<:LinearAlgebra.Bidiagonal}) = 2 ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 +ndims(::Type{<:LinearAlgebra.SymTridiagonal}) = 2 +ndims(::Type{<:LinearAlgebra.Symmetric}) = 2 +ndims(::Type{<:LinearAlgebra.Hermitian}) = 2 +ndims(::Type{<:LinearAlgebra.UpperHessenberg}) = 2 ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar From f57fd8655f5f5110f3199ec051865128b40ed898 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 18 Mar 2022 15:49:47 -0700 Subject: [PATCH 3/8] Add tests for Symmetric and Hermitian uplo --- test/runtests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0c30ce8..d233881 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -179,8 +179,10 @@ using LinearAlgebra @test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomArray @test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray @test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray -@test_adapt CustomArray Symmetric(mat.arr) Symmetric(mat) AnyCustomArray -@test_adapt CustomArray Hermitian(mat.arr) Hermitian(mat) AnyCustomArray +@test_adapt CustomArray Symmetric(mat.arr, :U) Symmetric(mat, :U) AnyCustomArray +@test_adapt CustomArray Symmetric(mat.arr, :L) Symmetric(mat, :L) AnyCustomArray +@test_adapt CustomArray Hermitian(mat.arr, :U) Hermitian(mat, :U) AnyCustomArray +@test_adapt CustomArray Hermitian(mat.arr, :L) Hermitian(mat, :L) AnyCustomArray @test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomArray @test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray From 5f3a6f11e8ade7e07fdbeeb28f66576e2e108436 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Fri, 18 Mar 2022 15:50:45 -0700 Subject: [PATCH 4/8] Add support for Symmetric and Hermitian uplo --- src/wrappers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index 3842019..1a286bd 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -53,9 +53,9 @@ adapt_structure(to, A::LinearAlgebra.Tridiagonal) = adapt_structure(to, A::LinearAlgebra.SymTridiagonal) = LinearAlgebra.SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) adapt_structure(to, A::LinearAlgebra.Symmetric) = - LinearAlgebra.Symmetric(adapt(to, Base.parent(A))) + LinearAlgebra.Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo)) adapt_structure(to, A::LinearAlgebra.Hermitian) = - LinearAlgebra.Hermitian(adapt(to, Base.parent(A))) + LinearAlgebra.Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo)) adapt_structure(to, A::LinearAlgebra.UpperHessenberg) = LinearAlgebra.UpperHessenberg(adapt(to, Base.parent(A))) From 99bcb01a8dd794a3d4bc73931b9c5c60be24824d Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 24 Mar 2022 00:20:23 -0700 Subject: [PATCH 5/8] Be specific about dimensionality in tests Highlights inconsistent usage of Src/Dst in implementation --- test/runtests.jl | 76 +++++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d233881..f0c3cd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,7 +41,9 @@ macro test_adapt(to, src_expr, dst_expr, typ=nothing) end end -AnyCustomArray{T,N} = Union{CustomArray,WrappedArray{T,N,CustomArray,CustomArray{T,N}}} +AnyCustomArray{T,N} = Union{CustomArray{T,N},WrappedArray{T,N,CustomArray,CustomArray{T,N}}} +AnyCustomVector{T} = AnyCustomArray{T,1} +AnyCustomMatrix{T} = AnyCustomArray{T,2} # basic adaption @@ -128,72 +130,72 @@ end @testset "array wrappers" begin -@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomArray +@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:) AnyCustomMatrix inds = CustomArray{Int,1}([1,2]) -@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomArray +@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:) AnyCustomMatrix # NOTE: manual creation of PermutedDimsArray because permutedims collects -@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomArray +@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1)) AnyCustomMatrix # NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape` -@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomArray +@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) AnyCustomMatrix -@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomArray +@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) AnyCustomVector -@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomArray +@test_adapt CustomArray reinterpret(Int64,mat.arr) reinterpret(Int64,mat) AnyCustomMatrix @static if isdefined(Base, :NonReshapedReinterpretArray) - @test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomArray + @test_adapt CustomArray reinterpret(reshape,Int64,mat.arr) reinterpret(reshape,Int64,mat) AnyCustomMatrix end ## doubly-wrapped -@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomArray +@test_adapt CustomArray reinterpret(Int64,view(mat.arr,:,:)) reinterpret(Int64,view(mat,:,:)) AnyCustomMatrix -@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomArray -@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomArray -@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomArray +@test_adapt CustomArray reshape(view(mat.arr,:,:), (2,2)) reshape(view(mat,:,:), (2,2)) AnyCustomMatrix +@test_adapt CustomArray reshape(reinterpret(Int64,mat.arr), (2,2)) reshape(reinterpret(Int64,mat), (2,2)) AnyCustomMatrix +@test_adapt CustomArray reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)) reshape(reinterpret(Int64,view(mat,:,:)), (2,2)) AnyCustomMatrix -@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomArray -@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomArray -@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomArray -@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomArray -@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomArray -@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomArray +@test_adapt CustomArray view(reinterpret(Int64,mat.arr), :, :) view(reinterpret(Int64,mat), :, :) AnyCustomMatrix +@test_adapt CustomArray view(reinterpret(Int64,view(mat.arr,:,:)), :, :) view(reinterpret(Int64,view(mat,:,:)), :, :) AnyCustomMatrix +@test_adapt CustomArray view(Base.ReshapedArray(mat.arr,(2,2),()), :, :) view(reshape(mat, (2,2)), :, :) AnyCustomMatrix +@test_adapt CustomArray view(reshape(view(mat.arr,:,:), (2,2)), :, :) view(reshape(view(mat,:,:), (2,2)), :, :) AnyCustomMatrix +@test_adapt CustomArray view(reshape(reinterpret(Int64,mat.arr), (2,2)), :, :) view(reshape(reinterpret(Int64,mat), (2,2)), :, :) AnyCustomMatrix +@test_adapt CustomArray view(reshape(reinterpret(Int64,view(mat.arr,:,:)), (2,2)), :, :) view(reshape(reinterpret(Int64,view(mat,:,:)), (2,2)), :, :) AnyCustomMatrix @static if isdefined(Base, :NonReshapedReinterpretArray) - @test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomArray - @test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomArray - @test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomArray + @test_adapt CustomArray reinterpret(reshape,Int64,view(mat.arr,:,:)) reinterpret(reshape,Int64,view(mat,:,:)) AnyCustomMatrix + @test_adapt CustomArray view(reinterpret(reshape,Int64,mat.arr), :, :) view(reinterpret(reshape,Int64,mat), :, :) AnyCustomMatrix + @test_adapt CustomArray view(reinterpret(reshape,Int64,view(mat.arr,:,:)), :, :) view(reinterpret(reshape,Int64,view(mat,:,:)), :, :) AnyCustomMatrix end using LinearAlgebra -@test_adapt CustomArray mat.arr' mat' AnyCustomArray +@test_adapt CustomArray mat.arr' mat' AnyCustomMatrix -@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomArray +@test_adapt CustomArray transpose(mat.arr) transpose(mat) AnyCustomMatrix -@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomArray -@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomArray -@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomArray -@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomArray -@test_adapt CustomArray Symmetric(mat.arr, :U) Symmetric(mat, :U) AnyCustomArray -@test_adapt CustomArray Symmetric(mat.arr, :L) Symmetric(mat, :L) AnyCustomArray -@test_adapt CustomArray Hermitian(mat.arr, :U) Hermitian(mat, :U) AnyCustomArray -@test_adapt CustomArray Hermitian(mat.arr, :L) Hermitian(mat, :L) AnyCustomArray -@test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomArray +@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat) AnyCustomMatrix +@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat) AnyCustomMatrix +@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat) AnyCustomMatrix +@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat) AnyCustomMatrix +@test_adapt CustomArray Symmetric(mat.arr, :U) Symmetric(mat, :U) AnyCustomMatrix +@test_adapt CustomArray Symmetric(mat.arr, :L) Symmetric(mat, :L) AnyCustomMatrix +@test_adapt CustomArray Hermitian(mat.arr, :U) Hermitian(mat, :U) AnyCustomMatrix +@test_adapt CustomArray Hermitian(mat.arr, :L) Hermitian(mat, :L) AnyCustomMatrix +@test_adapt CustomArray UpperHessenberg(mat.arr) UpperHessenberg(mat) AnyCustomMatrix -@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomArray +@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec) AnyCustomMatrix dl = CustomArray{Float64,1}(rand(2)) du = CustomArray{Float64,1}(rand(2)) d = CustomArray{Float64,1}(rand(3)) -@test_adapt CustomArray Bidiagonal(d.arr, du.arr, :U) Bidiagonal(d, du, :U) AnyCustomArray -@test_adapt CustomArray Bidiagonal(d.arr, dl.arr, :L) Bidiagonal(d, dl, :L) AnyCustomArray -@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomArray -@test_adapt CustomArray SymTridiagonal(d.arr, du.arr) SymTridiagonal(d, du) AnyCustomArray +@test_adapt CustomArray Bidiagonal(d.arr, du.arr, :U) Bidiagonal(d, du, :U) AnyCustomMatrix +@test_adapt CustomArray Bidiagonal(d.arr, dl.arr, :L) Bidiagonal(d, dl, :L) AnyCustomMatrix +@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du) AnyCustomMatrix +@test_adapt CustomArray SymTridiagonal(d.arr, du.arr) SymTridiagonal(d, du) AnyCustomMatrix end From b8e580957b44ce88e955606d96dc9863698e265e Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 24 Mar 2022 00:25:17 -0700 Subject: [PATCH 6/8] Correct usage of Src/Dst --- src/wrappers.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index 1a286bd..f2ce2c6 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -101,16 +101,16 @@ const WrappedArray{T,N,Src,Dst} = Union{ #Base.ReshapedArray{T,N,<:Src}, #Base.ReinterpretArray{T,N,<:Any,<:Src}, - LinearAlgebra.Adjoint{T,<:Dst}, - LinearAlgebra.Transpose{T,<:Dst}, + LinearAlgebra.Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d + LinearAlgebra.Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst LinearAlgebra.LowerTriangular{T,<:Dst}, LinearAlgebra.UnitLowerTriangular{T,<:Dst}, LinearAlgebra.UpperTriangular{T,<:Dst}, LinearAlgebra.UnitUpperTriangular{T,<:Dst}, - LinearAlgebra.Diagonal{T,<:Dst}, - LinearAlgebra.Bidiagonal{T,<:Dst}, - LinearAlgebra.Tridiagonal{T,<:Dst}, - LinearAlgebra.SymTridiagonal{T,<:Dst}, + LinearAlgebra.Diagonal{T,<:Src}, + LinearAlgebra.Bidiagonal{T,<:Src}, + LinearAlgebra.Tridiagonal{T,<:Src}, + LinearAlgebra.SymTridiagonal{T,<:Src}, LinearAlgebra.Symmetric{T,<:Dst}, LinearAlgebra.Hermitian{T,<:Dst}, LinearAlgebra.UpperHessenberg{T,<:Dst}, @@ -152,6 +152,12 @@ eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar for T in [:(Base.LogicalIndex{<:Any,<:Src}), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}), + :(LinearAlgebra.Adjoint{<:Any,<:Src}), + :(LinearAlgebra.Transpose{<:Any,<:Src}), + :(LinearAlgebra.Diagonal{<:Any,<:Src}), + :(LinearAlgebra.Bidiagonal{<:Any,<:Src}), + :(LinearAlgebra.Tridiagonal{<:Any,<:Src}), + :(LinearAlgebra.SymTridiagonal{<:Any,<:Src}), :(WrappedReinterpretArray{<:Any,<:Any,<:Src}), :(WrappedReshapedArray{<:Any,<:Any,<:Src}), :(WrappedSubArray{<:Any,<:Any,<:Src})] From b6858e05470e9e045d13dc4c5199b481aa18b547 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 24 Mar 2022 00:26:30 -0700 Subject: [PATCH 7/8] Drop redundant LinearAlgebra qualifier --- src/wrappers.jl | 103 +++++++++++++++++++++-------------------------- test/runtests.jl | 4 +- 2 files changed, 47 insertions(+), 60 deletions(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index f2ce2c6..7145ff9 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -32,32 +32,19 @@ end $(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum))) end -adapt_structure(to, A::LinearAlgebra.Adjoint) = - LinearAlgebra.adjoint(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Transpose) = - LinearAlgebra.transpose(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.LowerTriangular) = - LinearAlgebra.LowerTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UnitLowerTriangular) = - LinearAlgebra.UnitLowerTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UpperTriangular) = - LinearAlgebra.UpperTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.UnitUpperTriangular) = - LinearAlgebra.UnitUpperTriangular(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Diagonal) = - LinearAlgebra.Diagonal(adapt(to, Base.parent(A))) -adapt_structure(to, A::LinearAlgebra.Bidiagonal) = - LinearAlgebra.Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.Tridiagonal) = - LinearAlgebra.Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) -adapt_structure(to, A::LinearAlgebra.SymTridiagonal) = - LinearAlgebra.SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) -adapt_structure(to, A::LinearAlgebra.Symmetric) = - LinearAlgebra.Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.Hermitian) = - LinearAlgebra.Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo)) -adapt_structure(to, A::LinearAlgebra.UpperHessenberg) = - LinearAlgebra.UpperHessenberg(adapt(to, Base.parent(A))) +adapt_structure(to, A::Adjoint) = adjoint(adapt(to, Base.parent(A))) +adapt_structure(to, A::Transpose) = transpose(adapt(to, Base.parent(A))) +adapt_structure(to, A::LowerTriangular) = LowerTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UnitLowerTriangular) = UnitLowerTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UpperTriangular) = UpperTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::UnitUpperTriangular) = UnitUpperTriangular(adapt(to, Base.parent(A))) +adapt_structure(to, A::Diagonal) = Diagonal(adapt(to, Base.parent(A))) +adapt_structure(to, A::Bidiagonal) = Bidiagonal(adapt(to, A.dv), adapt(to, A.ev), Symbol(A.uplo)) +adapt_structure(to, A::Tridiagonal) = Tridiagonal(adapt(to, A.dl), adapt(to, A.d), adapt(to, A.du)) +adapt_structure(to, A::SymTridiagonal) = SymTridiagonal(adapt(to, A.dv), adapt(to, A.ev)) +adapt_structure(to, A::Symmetric) = Symmetric(adapt(to, Base.parent(A)), Symbol(A.uplo)) +adapt_structure(to, A::Hermitian) = Hermitian(adapt(to, Base.parent(A)), Symbol(A.uplo)) +adapt_structure(to, A::UpperHessenberg) = UpperHessenberg(adapt(to, Base.parent(A))) # we generally don't support multiple layers of wrappers, but some occur often @@ -101,19 +88,19 @@ const WrappedArray{T,N,Src,Dst} = Union{ #Base.ReshapedArray{T,N,<:Src}, #Base.ReinterpretArray{T,N,<:Any,<:Src}, - LinearAlgebra.Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d - LinearAlgebra.Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst - LinearAlgebra.LowerTriangular{T,<:Dst}, - LinearAlgebra.UnitLowerTriangular{T,<:Dst}, - LinearAlgebra.UpperTriangular{T,<:Dst}, - LinearAlgebra.UnitUpperTriangular{T,<:Dst}, - LinearAlgebra.Diagonal{T,<:Src}, - LinearAlgebra.Bidiagonal{T,<:Src}, - LinearAlgebra.Tridiagonal{T,<:Src}, - LinearAlgebra.SymTridiagonal{T,<:Src}, - LinearAlgebra.Symmetric{T,<:Dst}, - LinearAlgebra.Hermitian{T,<:Dst}, - LinearAlgebra.UpperHessenberg{T,<:Dst}, + Adjoint{T,<:Src}, # The adjoint/transpose of a Vector has shape 1xN, so is a 2d + Transpose{T,<:Src}, # wrapper around a 1d array, hence use Src not Dst + LowerTriangular{T,<:Dst}, + UnitLowerTriangular{T,<:Dst}, + UpperTriangular{T,<:Dst}, + UnitUpperTriangular{T,<:Dst}, + Diagonal{T,<:Src}, + Bidiagonal{T,<:Src}, + Tridiagonal{T,<:Src}, + SymTridiagonal{T,<:Src}, + Symmetric{T,<:Dst}, + Hermitian{T,<:Dst}, + UpperHessenberg{T,<:Dst}, WrappedReinterpretArray{T,N,<:Src}, WrappedReshapedArray{T,N,<:Src}, @@ -133,31 +120,31 @@ const WrappedArray{T,N,Src,Dst} = Union{ # accessors for extracting information about the wrapper type ndims(::Type{<:Base.LogicalIndex}) = 1 -ndims(::Type{<:LinearAlgebra.Adjoint}) = 2 -ndims(::Type{<:LinearAlgebra.Transpose}) = 2 -ndims(::Type{<:LinearAlgebra.LowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitLowerTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.UnitUpperTriangular}) = 2 -ndims(::Type{<:LinearAlgebra.Diagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Bidiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Tridiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.SymTridiagonal}) = 2 -ndims(::Type{<:LinearAlgebra.Symmetric}) = 2 -ndims(::Type{<:LinearAlgebra.Hermitian}) = 2 -ndims(::Type{<:LinearAlgebra.UpperHessenberg}) = 2 +ndims(::Type{<:Adjoint}) = 2 +ndims(::Type{<:Transpose}) = 2 +ndims(::Type{<:LowerTriangular}) = 2 +ndims(::Type{<:UnitLowerTriangular}) = 2 +ndims(::Type{<:UpperTriangular}) = 2 +ndims(::Type{<:UnitUpperTriangular}) = 2 +ndims(::Type{<:Diagonal}) = 2 +ndims(::Type{<:Bidiagonal}) = 2 +ndims(::Type{<:Tridiagonal}) = 2 +ndims(::Type{<:SymTridiagonal}) = 2 +ndims(::Type{<:Symmetric}) = 2 +ndims(::Type{<:Hermitian}) = 2 +ndims(::Type{<:UpperHessenberg}) = 2 ndims(::Type{<:WrappedArray{<:Any,N}}) where {N} = N eltype(::Type{<:WrappedArray{T}}) where {T} = T # every wrapper has a T typevar for T in [:(Base.LogicalIndex{<:Any,<:Src}), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:Src}), - :(LinearAlgebra.Adjoint{<:Any,<:Src}), - :(LinearAlgebra.Transpose{<:Any,<:Src}), - :(LinearAlgebra.Diagonal{<:Any,<:Src}), - :(LinearAlgebra.Bidiagonal{<:Any,<:Src}), - :(LinearAlgebra.Tridiagonal{<:Any,<:Src}), - :(LinearAlgebra.SymTridiagonal{<:Any,<:Src}), + :(Adjoint{<:Any,<:Src}), + :(Transpose{<:Any,<:Src}), + :(Diagonal{<:Any,<:Src}), + :(Bidiagonal{<:Any,<:Src}), + :(Tridiagonal{<:Any,<:Src}), + :(SymTridiagonal{<:Any,<:Src}), :(WrappedReinterpretArray{<:Any,<:Any,<:Src}), :(WrappedReshapedArray{<:Any,<:Any,<:Src}), :(WrappedSubArray{<:Any,<:Any,<:Src})] diff --git a/test/runtests.jl b/test/runtests.jl index f0c3cd4..cb6121a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -201,10 +201,10 @@ end @testset "type information" begin - @test Adapt.ndims(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == 2 + @test Adapt.ndims(Transpose{Float64,Array{Float64,1}}) == 2 @test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3 - @test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array + @test Adapt.parent(Transpose{Float64,Array{Float64,1}}) == Array @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array end From d29a1e8e90c5b648b6e60db580e26242238babe2 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 24 Mar 2022 00:28:07 -0700 Subject: [PATCH 8/8] Make WrappedSubArray tests more interesting --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index cb6121a..7534635 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -202,10 +202,10 @@ end @testset "type information" begin @test Adapt.ndims(Transpose{Float64,Array{Float64,1}}) == 2 - @test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == 3 + @test Adapt.ndims(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == 3 @test Adapt.parent(Transpose{Float64,Array{Float64,1}}) == Array - @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array + @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,4}}) == Array end