Skip to content

Commit

Permalink
Merge pull request #27 from JuliaGPU/tb/logicalindex
Browse files Browse the repository at this point in the history
Support for Base.LogicalIndex.
  • Loading branch information
maleadt authored Sep 16, 2020
2 parents 4b96fa1 + 4d78d91 commit 5efd774
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export WrappedArray
# database of array wrappers
const _wrappers = (
:(SubArray{T,N,<:Src}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(Base.LogicalIndex{T,<:Src}) => (A,mut)->Base.LogicalIndex(mut(A.mask)),
:(PermutedDimsArray{T,N,<:Any,<:Any,<:Src}) => (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{T,N,<:Src}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(Base.ReinterpretArray{T,N,<:Src}) => (A,mut)->Base.reinterpret(eltype(A), mut(parent(A))),
Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ Adapt.adapt_storage(::Type{<:CustomArray}, xs::Array) = CustomArray(xs)

Base.size(x::CustomArray, y...) = size(x.arr, y...)
Base.getindex(x::CustomArray, y...) = getindex(x.arr, y...)

Base.count(x::CustomArray) = count(x.arr)

const mat = CustomArray{Float64,2}(rand(2,2))
const vec = CustomArray{Float64,1}(rand(2))

const mat_bools = CustomArray{Bool,2}(rand(Bool,2,2))

macro test_adapt(to, src, dst)
quote
@test adapt($to, $src) == $dst
Expand Down Expand Up @@ -61,6 +63,8 @@ const inds = CustomArray{Int,1}([1,2])
# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2))

@test_adapt CustomArray Base.LogicalIndex(mat_bools.arr) Base.LogicalIndex(mat_bools)


using LinearAlgebra

Expand Down

0 comments on commit 5efd774

Please sign in to comment.