diff --git a/src/base.jl b/src/base.jl index dd2c544..c512f82 100644 --- a/src/base.jl +++ b/src/base.jl @@ -6,14 +6,32 @@ adapt_structure(to, xs::Tuple) = Tuple(adapt(to, x) for x in xs) @generated adapt_structure(to, x::NamedTuple) = Expr(:tuple, (:($f=adapt(to, x.$f)) for f in fieldnames(x))...) -adapt(to, x::SubArray) = SubArray(adapt(to, parent(x)), parentindices(x)) - -## LinearAlgebra - -import LinearAlgebra: Adjoint, Transpose -adapt_structure(to, x::Adjoint) = Adjoint(adapt(to, parent(x))) -adapt_structure(to, x::Transpose) = Transpose(adapt(to, parent(x))) +## Array wrappers + +using LinearAlgebra + +# database of array wrappers, for use throughout the package +# +# LHS entries are a symbolic type with AT for the array type +# +# RHS entries consist of a closure to reconstruct the wrapper, with as arguments +# a wrapper instance and mutator function to apply to the inner array +const wrappers = ( + :(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), parentindices(A)), + :(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))), + :(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))), + :(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))), + :(LinearAlgebra.UnitLowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))), + :(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))), + :(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))), + :(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))) +) + +for (W, ctor) in wrappers + mut = :(A -> adapt(to, A)) + @eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut) +end ## Broadcast