Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jul 12, 2019
1 parent 23c2351 commit 334a3fd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using LinearAlgebra
const wrappers = (
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
Expand All @@ -39,7 +40,6 @@ const wrappers = (
:(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
)

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm
Expand Down
18 changes: 15 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,26 @@ Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr))
@test adapt(CustomArray, (a=val.arr,)) == (a=val,)

@test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:)
@test adapt(CustomArray, view(val.arr,:,:)) isa SubArray{<:Any,<:Any,<:CustomArray}
const inds = CustomArray{Int,1}([1,2])
@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:)

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1))

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2))
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) isa Base.ReshapedArray{<:Any,<:Any,<:CustomArray}


using LinearAlgebra

@test adapt(CustomArray, val.arr') == val'
@test adapt(CustomArray, val.arr') isa Adjoint{<:Any,<:CustomArray}

@test adapt(CustomArray, transpose(val.arr)) == transpose(val)

@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val)
@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val)
@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val)
@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val)

@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val)
@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)

0 comments on commit 334a3fd

Please sign in to comment.