diff --git a/src/base.jl b/src/base.jl index 2fd26b4..5adb77f 100644 --- a/src/base.jl +++ b/src/base.jl @@ -31,6 +31,7 @@ using LinearAlgebra const wrappers = ( :(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))), :(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)), + :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), :(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))), :(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))), :(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))), @@ -39,7 +40,6 @@ const wrappers = ( :(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))), :(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))), :(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)), - :(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), ) permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm diff --git a/test/runtests.jl b/test/runtests.jl index 372e479..0383b46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,14 +45,26 @@ Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr)) @test adapt(CustomArray, (a=val.arr,)) == (a=val,) @test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:) -@test adapt(CustomArray, view(val.arr,:,:)) isa SubArray{<:Any,<:Any,<:CustomArray} +const inds = CustomArray{Int,1}([1,2]) +@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:) + +# NOTE: manual creation of PermutedDimsArray because permutedims collects +@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1)) # NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape` @test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2)) -@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) isa Base.ReshapedArray{<:Any,<:Any,<:CustomArray} using LinearAlgebra @test adapt(CustomArray, val.arr') == val' -@test adapt(CustomArray, val.arr') isa Adjoint{<:Any,<:CustomArray} + +@test adapt(CustomArray, transpose(val.arr)) == transpose(val) + +@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val) +@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val) +@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val) +@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val) + +@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val) +@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)