Skip to content

Commit

Permalink
Merge pull request #18 from JuliaGPU/tb/more
Browse files Browse the repository at this point in the history
More wrappers and tests
  • Loading branch information
maleadt authored Jul 12, 2019
2 parents 4cfb6e9 + 334a3fd commit 3ef381c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
8 changes: 6 additions & 2 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ using LinearAlgebra
# RHS entries consist of a closure to reconstruct the wrapper, with as arguments
# a wrapper instance and mutator function to apply to the inner array
const wrappers = (
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), parentindices(A)),
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A))
:(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
)

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm

for (W, ctor) in wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut)
Expand Down
18 changes: 15 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,26 @@ Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr))
@test adapt(CustomArray, (a=val.arr,)) == (a=val,)

@test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:)
@test adapt(CustomArray, view(val.arr,:,:)) isa SubArray{<:Any,<:Any,<:CustomArray}
const inds = CustomArray{Int,1}([1,2])
@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:)

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1))

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2))
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) isa Base.ReshapedArray{<:Any,<:Any,<:CustomArray}


using LinearAlgebra

@test adapt(CustomArray, val.arr') == val'
@test adapt(CustomArray, val.arr') isa Adjoint{<:Any,<:CustomArray}

@test adapt(CustomArray, transpose(val.arr)) == transpose(val)

@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val)
@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val)
@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val)
@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val)

@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val)
@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)

0 comments on commit 3ef381c

Please sign in to comment.