From 23c2351969eddd0f03bfe1ebf1079bedd9c7580b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 12 Jul 2019 12:25:10 +0200 Subject: [PATCH 1/2] Support for SubArray with indices, PermutedDimsArray and Tridiagonal. --- src/base.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/base.jl b/src/base.jl index 2ad0032..2fd26b4 100644 --- a/src/base.jl +++ b/src/base.jl @@ -29,7 +29,8 @@ using LinearAlgebra # RHS entries consist of a closure to reconstruct the wrapper, with as arguments # a wrapper instance and mutator function to apply to the inner array const wrappers = ( - :(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), parentindices(A)), + :(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))), + :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)), :(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))), :(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))), :(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))), @@ -37,9 +38,12 @@ const wrappers = ( :(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))), :(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))), :(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))), - :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)) + :(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)), + :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), ) +permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm + for (W, ctor) in wrappers mut = :(A -> adapt(to, A)) @eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut) From 334a3fdb969e452afbf4c3cb16f05b7a222899b6 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 12 Jul 2019 12:33:53 +0200 Subject: [PATCH 2/2] Add tests. --- src/base.jl | 2 +- test/runtests.jl | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/base.jl b/src/base.jl index 2fd26b4..5adb77f 100644 --- a/src/base.jl +++ b/src/base.jl @@ -31,6 +31,7 @@ using LinearAlgebra const wrappers = ( :(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)), + :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), :(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))), :(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))), :(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))), @@ -39,7 +40,6 @@ const wrappers = ( :(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))), :(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))), :(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)), - :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), ) permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm diff --git a/test/runtests.jl b/test/runtests.jl index 372e479..0383b46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,14 +45,26 @@ Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr)) @test adapt(CustomArray, (a=val.arr,)) == (a=val,) @test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:) -@test adapt(CustomArray, view(val.arr,:,:)) isa SubArray{<:Any,<:Any,<:CustomArray} +const inds = CustomArray{Int,1}([1,2]) +@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:) + +# NOTE: manual creation of PermutedDimsArray because permutedims collects +@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1)) # NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape` @test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2)) -@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) isa Base.ReshapedArray{<:Any,<:Any,<:CustomArray} using LinearAlgebra @test adapt(CustomArray, val.arr') == val' -@test adapt(CustomArray, val.arr') isa Adjoint{<:Any,<:CustomArray} + +@test adapt(CustomArray, transpose(val.arr)) == transpose(val) + +@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val) +@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val) +@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val) +@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val) + +@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val) +@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)