From 4d78d91086dbbbdad9461295b088d2d780960274 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 16 Sep 2020 11:15:16 +0200 Subject: [PATCH] Support for Base.LogicalIndex. --- src/wrappers.jl | 1 + test/runtests.jl | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index 6f5dc06..d5a8361 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -10,6 +10,7 @@ export WrappedArray # database of array wrappers const _wrappers = ( :(SubArray{T,N,<:Src}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))), + :(Base.LogicalIndex{T,<:Src}) => (A,mut)->Base.LogicalIndex(mut(A.mask)), :(PermutedDimsArray{T,N,<:Any,<:Any,<:Src}) => (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)), :(Base.ReshapedArray{T,N,<:Src}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)), :(Base.ReinterpretArray{T,N,<:Src}) => (A,mut)->Base.reinterpret(eltype(A), mut(parent(A))), diff --git a/test/runtests.jl b/test/runtests.jl index ece44da..2495556 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,11 +13,13 @@ Adapt.adapt_storage(::Type{<:CustomArray}, xs::Array) = CustomArray(xs) Base.size(x::CustomArray, y...) = size(x.arr, y...) Base.getindex(x::CustomArray, y...) = getindex(x.arr, y...) - +Base.count(x::CustomArray) = count(x.arr) const mat = CustomArray{Float64,2}(rand(2,2)) const vec = CustomArray{Float64,1}(rand(2)) +const mat_bools = CustomArray{Bool,2}(rand(Bool,2,2)) + macro test_adapt(to, src, dst) quote @test adapt($to, $src) == $dst @@ -61,6 +63,8 @@ const inds = CustomArray{Int,1}([1,2]) # NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape` @test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2)) +@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools) + using LinearAlgebra