Skip to content

Commit

Permalink
Merge pull request #10 from JuliaGPU/tb/wrappers
Browse files Browse the repository at this point in the history
Provide database of array wrappers.
  • Loading branch information
maleadt authored Nov 5, 2018
2 parents 731f701 + 608e461 commit 555d56e
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,32 @@ adapt_structure(to, xs::Tuple) = Tuple(adapt(to, x) for x in xs)
@generated adapt_structure(to, x::NamedTuple) =
Expr(:tuple, (:($f=adapt(to, x.$f)) for f in fieldnames(x))...)

adapt(to, x::SubArray) = SubArray(adapt(to, parent(x)), parentindices(x))


## LinearAlgebra

import LinearAlgebra: Adjoint, Transpose
adapt_structure(to, x::Adjoint) = Adjoint(adapt(to, parent(x)))
adapt_structure(to, x::Transpose) = Transpose(adapt(to, parent(x)))
## Array wrappers

using LinearAlgebra

# database of array wrappers, for use throughout the package
#
# LHS entries are a symbolic type with AT for the array type
#
# RHS entries consist of a closure to reconstruct the wrapper, with as arguments
# a wrapper instance and mutator function to apply to the inner array
const wrappers = (
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), parentindices(A)),
:(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A)))
)

for (W, ctor) in wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut)
end


## Broadcast
Expand Down

0 comments on commit 555d56e

Please sign in to comment.