From a98359122ee663f59c2c87cf75656855c898706e Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 9 Nov 2020 08:34:17 -0500 Subject: [PATCH 1/7] official management of kwargs for indexing --- src/ArrayInterface.jl | 1 + src/dimensions.jl | 117 ++++++++++++++++++++++++++++++++++++++++++ src/indexing.jl | 6 +++ src/ranges.jl | 28 +++++++--- 4 files changed, 145 insertions(+), 7 deletions(-) create mode 100644 src/dimensions.jl diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 9e901537d..810e8bd18 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -889,6 +889,7 @@ end include("static.jl") include("ranges.jl") +include("dimensions.jl") include("indexing.jl") include("stridelayout.jl") diff --git a/src/dimensions.jl b/src/dimensions.jl new file mode 100644 index 000000000..c5a3a2a57 --- /dev/null +++ b/src/dimensions.jl @@ -0,0 +1,117 @@ + +""" + has_dimnames(x) -> Bool + +Returns `true` if `x` has names for each dimension. +""" +@inline has_dimnames(x) = has_dimnames(typeof(x)) +function has_dimnames(::Type{T}) where {T} + if parent_type(T) <: T + return false + else + return has_dimnames(parent_type(T)) + end +end + +@generated default_dimnames(::Val{N}) where {N} = :($(ntuple(i -> Symbol(:dim_, i), N))) + +""" + dimnames(x) -> Tuple + +Return the names of the dimensions for `x`. +""" +@inline dimnames(x) = dimnames(typeof(x)) +function dimnames(::Type{T}) where {T} + if parent_type(T) <: T + return default_dimnames(Val(ndims(T))) + else + return dimnames(parent_type(T)) + end +end +dimnames(::Type{T}) where {T<:Transpose} = reverse(dimnames(parent_type(T))) +dimnames(::Type{T}) where {T<:Adjoint} = reverse(dimnames(parent_type(T))) +@inline function dimnames(::Type{T}) where {I1,A,T<:PermutedDimsArray{<:Any,<:Any,I1,<:Any,A}} + ns = dimnames(A) + return map(i -> getfield(ns, i), I1) +end +@generated function dimnames(::Type{T}) where {N,A,I,T<:SubArray{<:Any,N,A,I}} + e = Expr(:tuple) + d = dimnames(A) + for i in 1:N + if argdims(A, I.parameters[i]) > 0 + push!(e.args, d[i]) + end + end + return e +end + +""" + to_dims(x, d) + +This returns the dimension(s) of `x` corresponding to `d`. +""" +to_dims(x, d::Integer) = Int(d) +to_dims(x, d::StaticInt) = d +to_dims(x, d::Colon) = d # `:` is the default for most methods that take `dims` +@inline to_dims(x, d::Tuple) = map(i -> to_dims(x, i), d) +@inline function to_dims(x, d::Symbol)::Int + i = _sym_to_dim(dimnames(x), d) + if i === 0 + throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(dimnames(x)))")) + end + return i +end +Base.@pure function _sym_to_dim(x::Tuple{Vararg{Symbol,N}}, sym::Symbol) where {N} + for i in 1:N + getfield(x, i) === sym && return i + end + return 0 +end + +""" + tuple_issubset + +A version of `issubset` sepecifically for `Tuple`s of `Symbol`s, that is `@pure`. +This helps it get optimised out of existance. It is less of an abuse of `@pure` than +most of the stuff for making `NamedTuples` work. +""" +Base.@pure function tuple_issubset( + lhs::Tuple{Vararg{Symbol,N}}, rhs::Tuple{Vararg{Symbol,M}}, +) where {N,M} + N <= M || return false + for a in lhs + found = false + for b in rhs + found |= a === b + end + found || return false + end + return true +end + +""" + order_named_inds(Val(names); kw...) + order_named_inds(Val(names), namedtuple) + +Returns the tuple of index values for an array with `names`, when indexed by keywords. +Any dimensions not fixed are given as `:`, to make a slice. +An error is thrown if any keywords are used which do not occur in `nda`'s names. +""" +order_named_inds(val::Val{L}; kw...) where {L} = order_named_inds(val, kw.data) +@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K} + if length(K) === 0 + return () # if kwargs were empty + else + tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K")) + exs = map(L) do n + if Base.sym_in(n, K) + qn = QuoteNode(n) + :(getfield(ni, $qn)) + else + :(Colon()) + end + end + return Expr(:tuple, exs...) + end +end + diff --git a/src/indexing.jl b/src/indexing.jl index 3e31c82b4..28439e281 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -390,6 +390,7 @@ Changing indexing based on a given argument from `args` should be done through [`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref). """ @propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)) +@propagate_inbounds getindex(A; kwargs...) = A[order_named_inds(Val(dimnames(A)); kwargs...)...] """ unsafe_getindex(A, inds) @@ -490,6 +491,10 @@ Store the given values at the given key or index within a collection. "elements after construction.") end end +@propagate_inbounds function setindex!(A, val; kwargs...) + A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val +end + """ unsafe_setindex!(A, val, inds::Tuple) @@ -530,3 +535,4 @@ Sets `inds` of `A` to `val`. `inds` is assumed to have been bounds-checked. @inline function unsafe_set_collection!(A, val, inds) return Base._unsafe_setindex!(IndexStyle(A), A, val, inds...) end + diff --git a/src/ranges.jl b/src/ranges.jl index 973958e92..4c41de8df 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -9,9 +9,14 @@ Otherwise, return `nothing`. @test isone(known_first(typeof(Base.OneTo(4)))) """ known_first(x) = known_first(typeof(x)) -known_first(::Type{T}) where {T} = nothing +function known_first(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return known_first(parent_type(T)) + end +end known_first(::Type{Base.OneTo{T}}) where {T} = one(T) -known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T)) """ known_last(::Type{T}) @@ -24,8 +29,13 @@ using StaticArrays @test known_last(typeof(SOneTo(4))) == 4 """ known_last(x) = known_last(typeof(x)) -known_last(::Type{T}) where {T} = nothing -known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T)) +function known_last(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return known_last(parent_type(T)) + end +end """ known_step(::Type{T}) @@ -37,11 +47,15 @@ Otherwise, return `nothing`. @test isone(known_step(typeof(1:4))) """ known_step(x) = known_step(typeof(x)) -known_step(::Type{T}) where {T} = nothing +function known_step(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return known_step(parent_type(T)) + end +end known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T) -# add methods to support ArrayInterface - """ OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int} From 521b3aea78674471f653cdfd97c9e0d8aa53fce6 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 9 Nov 2020 11:03:18 -0500 Subject: [PATCH 2/7] Selectively propagate kwargs --- src/dimensions.jl | 44 ++++++++++++------- src/indexing.jl | 108 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 118 insertions(+), 34 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index c5a3a2a57..8f69b7e82 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,4 +1,14 @@ +Base.@pure function _permute_dimnames(x::Tuple{Vararg{Symbol,D}}, perm::NTuple{N,Int}) where {D,N} + ntuple(Val(N)) do i + if D < i + return :_ + else + return getfield(x, getfield(perm, i)) + end + end +end + """ has_dimnames(x) -> Bool @@ -13,33 +23,35 @@ function has_dimnames(::Type{T}) where {T} end end -@generated default_dimnames(::Val{N}) where {N} = :($(ntuple(i -> Symbol(:dim_, i), N))) - """ - dimnames(x) -> Tuple + dimnames(x) -> Tuple{Vararg{Symbol}} + dimnames(x, d) -> Symbol Return the names of the dimensions for `x`. """ @inline dimnames(x) = dimnames(typeof(x)) -function dimnames(::Type{T}) where {T} - if parent_type(T) <: T - return default_dimnames(Val(ndims(T))) +@inline dimnames(x, i::Int) = dimnames(typeof(x), i) +@inline function dimnames(::Type{T}, d::Int) where {T} + if has_dimnames(T) + return getfield(dimnames(T), d) else - return dimnames(parent_type(T)) + return nothing end end -dimnames(::Type{T}) where {T<:Transpose} = reverse(dimnames(parent_type(T))) -dimnames(::Type{T}) where {T<:Adjoint} = reverse(dimnames(parent_type(T))) -@inline function dimnames(::Type{T}) where {I1,A,T<:PermutedDimsArray{<:Any,<:Any,I1,<:Any,A}} - ns = dimnames(A) - return map(i -> getfield(ns, i), I1) +@inline function dimnames(::Type{T}) where {T} + return ntuple(i -> dimnames(parent_type(T), i), Val(ndims(T))) +end +@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}} + return map(d -> dimnames(dimnames(parent_type(T))), (2, 1)) end -@generated function dimnames(::Type{T}) where {N,A,I,T<:SubArray{<:Any,N,A,I}} +@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} + return map(d -> dimnames(dimnames(parent_type(T))), I) +end +@generated function dimnames(::Type{T}) where {I,T<:SubArray{<:Any,<:Any,<:Any,I}} e = Expr(:tuple) - d = dimnames(A) - for i in 1:N + for i in 1:ndims(T) if argdims(A, I.parameters[i]) > 0 - push!(e.args, d[i]) + push!(e.args, dimenames(parent_type(P), i)) end end return e diff --git a/src/indexing.jl b/src/indexing.jl index 28439e281..63dc5c2dd 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -390,7 +390,13 @@ Changing indexing based on a given argument from `args` should be done through [`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref). """ @propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)) -@propagate_inbounds getindex(A; kwargs...) = A[order_named_inds(Val(dimnames(A)); kwargs...)...] +@propagate_inbounds function getindex(A; kwargs...) + if has_dimnames(A) + return A[order_named_inds(Val(dimnames(A)); kwargs...)...] + else + return unsafe_getindex(A, to_indices(A, ()); kwargs...) + end +end """ unsafe_getindex(A, inds) @@ -398,9 +404,15 @@ Changing indexing based on a given argument from `args` should be done through Indexes into `A` given `inds`. This method assumes that `inds` have already been bounds-checked. """ -unsafe_getindex(A, inds) = unsafe_getindex(UnsafeIndex(A, inds), A, inds) -unsafe_getindex(::UnsafeGetElement, A, inds) = unsafe_get_element(A, inds) -unsafe_getindex(::UnsafeGetCollection, A, inds) = unsafe_get_collection(A, inds) +function unsafe_getindex(A, inds; kwargs...) + return unsafe_getindex(UnsafeIndex(A, inds), A, inds; kwargs...) +end +function unsafe_getindex(::UnsafeGetElement, A, inds; kwargs...) + return unsafe_get_element(A, inds; kwargs...) +end +function unsafe_getindex(::UnsafeGetCollection, A, inds; kwargs...) + return unsafe_get_collection(A, inds; kwargs...) +end """ unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T @@ -409,9 +421,7 @@ Returns an element of `A` at the indices `inds`. This method assumes all `inds` have been checked for being in bounds. Any new array type using `ArrayInterface.getindex` must define `unsafe_get_element(::NewArrayType, inds)`. """ -function unsafe_get_element(A, inds) - throw(MethodError(unsafe_getindex, (A, inds))) -end +unsafe_get_element(A, inds; kwargs...) = throw(MethodError(unsafe_getindex, (A, inds))) function unsafe_get_element(A::Array, inds) if length(inds) === 0 return Base.arrayref(false, A, 1) @@ -434,11 +444,11 @@ end Returns a collection of `A` given `inds`. `inds` is assumed to have been bounds-checked. """ -function unsafe_get_collection(A, inds) +function unsafe_get_collection(A, inds; kwargs...) axs = to_axes(A, inds) dest = similar(A, axs) if map(Base.unsafe_length, axes(dest)) == map(Base.unsafe_length, axs) - Base._unsafe_getindex!(dest, A, inds...) # usually a generated function, don't allow it to impact inference result + _unsafe_getindex!(dest, A, inds...; kwargs...) # usually a generated function, don't allow it to impact inference result else Base.throw_checksize_error(dest, axs) end @@ -492,19 +502,28 @@ Store the given values at the given key or index within a collection. end end @propagate_inbounds function setindex!(A, val; kwargs...) - A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val + if has_dimnames(A) + A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val + else + return unsafe_setindex!(A, val, to_indices(A, ()); kwargs...) + end end - """ - unsafe_setindex!(A, val, inds::Tuple) + unsafe_setindex!(A, val, inds::Tuple; kwargs...) Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been bounds-checked. This step of the processing pipeline can be customized by: """ -unsafe_setindex!(A, val, inds::Tuple) = unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds) -unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple) = unsafe_set_element!(A, val, inds) -unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple) = unsafe_set_collection!(A, val, inds) +function unsafe_setindex!(A, val, inds::Tuple; kwargs...) + return unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds; kwargs...) +end +function unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple; kwargs...) + return unsafe_set_element!(A, val, inds; kwargs...) +end +function unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple; kwargs...) + return unsafe_set_collection!(A, val, inds; kwargs...) +end """ unsafe_set_element!(A, val, inds::Tuple) @@ -513,7 +532,7 @@ Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!` must define `unsafe_set_element!(::NewArrayType, val, inds)`. """ -function unsafe_set_element!(A, val, inds) +function unsafe_set_element!(A, val, inds; kwargs...) throw(MethodError(unsafe_set_element!, (A, val, inds))) end function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T} @@ -532,7 +551,60 @@ end Sets `inds` of `A` to `val`. `inds` is assumed to have been bounds-checked. """ -@inline function unsafe_set_collection!(A, val, inds) - return Base._unsafe_setindex!(IndexStyle(A), A, val, inds...) +@inline function unsafe_set_collection!(A, val, inds; kwargs...) + return _unsafe_setindex!(IndexStyle(A), A, val, inds...; kwargs...) +end + + +# these let us use `@ncall` on getindex/setindex! that have kwargs +function _setindex_kwargs!(x, val, kwargs, args...) + @inbounds setindex!(x, val, args...; kwargs...) +end +function _getindex_kwargs(x, kwargs, args...) + @inbounds getindex(x, args...; kwargs...) +end + +function _generate_unsafe_getindex!_body(N::Int) + quote + Base.@_inline_meta + D = eachindex(dest) + Dy = iterate(D) + @inbounds Base.Cartesian.@nloops $N j d->I[d] begin + # This condition is never hit, but at the moment + # the optimizer is not clever enough to split the union without it + Dy === nothing && return dest + (idx, state) = Dy + dest[idx] = Base.Cartesian.@ncall $N _getindex_kwargs src kwargs j + Dy = iterate(D, state) + end + return dest + end +end + +function _generate_unsafe_setindex!_body(N::Int) + quote + x′ = Base.unalias(A, x) + @nexprs $N d->(I_d = Base.unalias(A, I[d])) + idxlens = Base.Cartesian.@ncall $N Base.index_lengths I + @ncall $N Base.setindex_shape_check x′ (d->idxlens[d]) + Xy = iterate(x′) + @inbounds Base.Cartesian.@nloops $N i d->I_d begin + # This is never reached, but serves as an assumption for + # the optimizer that it does not need to emit error paths + Xy === nothing && break + (val, state) = Xy + Base.Cartesian.@ncall $N _setindex_kwargs! A val kwargs i + Xy = iterate(x′, state) + end + A + end +end + +@generated function _unsafe_getindex!(dest::AbstractArray, src::AbstractArray, I::Vararg{Union{Real, AbstractArray}, N}; kwargs...) where N + _generate_unsafe_getindex!_body(N) +end + +@generated function _unsafe_setindex!(::IndexStyle, A::AbstractArray, x, I::Vararg{Union{Real,AbstractArray}, N}; kwargs...) where N + _generate_unsafe_setindex!_body(N) end From 209cd58a6bb01f24b1c36d8f74b8b3c9436a3bdb Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 10 Nov 2020 12:10:59 -0500 Subject: [PATCH 3/7] Dummy tests for dimnames --- src/dimensions.jl | 40 ++++++++++++------------------- src/indexing.jl | 6 ++--- test/dimensions.jl | 60 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 79 insertions(+), 28 deletions(-) create mode 100644 test/dimensions.jl diff --git a/src/dimensions.jl b/src/dimensions.jl index 8f69b7e82..fdd8e3f70 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,14 +1,4 @@ -Base.@pure function _permute_dimnames(x::Tuple{Vararg{Symbol,D}}, perm::NTuple{N,Int}) where {D,N} - ntuple(Val(N)) do i - if D < i - return :_ - else - return getfield(x, getfield(perm, i)) - end - end -end - """ has_dimnames(x) -> Bool @@ -30,28 +20,29 @@ end Return the names of the dimensions for `x`. """ @inline dimnames(x) = dimnames(typeof(x)) -@inline dimnames(x, i::Int) = dimnames(typeof(x), i) -@inline function dimnames(::Type{T}, d::Int) where {T} - if has_dimnames(T) - return getfield(dimnames(T), d) +@inline dimnames(x, i::Integer) = dimnames(typeof(x), i) +@inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d)) +@generated function dimnames(::Type{T}) where {T} + if parent_type(T) <: T + return ntuple(i -> Symbol(:dim_, i), Val(ndims(T))) else - return nothing + return dimnames(parent_type(T)) end end -@inline function dimnames(::Type{T}) where {T} - return ntuple(i -> dimnames(parent_type(T), i), Val(ndims(T))) -end @inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}} - return map(d -> dimnames(dimnames(parent_type(T))), (2, 1)) + return map(i -> dimnames(parent_type(T), i), (2, 1)) end @inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} - return map(d -> dimnames(dimnames(parent_type(T))), I) + return map(i -> dimnames(parent_type(T), i), I) +end +function dimnames(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} + return _sub_array_dimnames(Val(dimnames(P)), Val(argdims(P, I))) end -@generated function dimnames(::Type{T}) where {I,T<:SubArray{<:Any,<:Any,<:Any,I}} +@generated function _sub_array_dimnames(::Val{L}, ::Val{I}) where {L,I} e = Expr(:tuple) - for i in 1:ndims(T) - if argdims(A, I.parameters[i]) > 0 - push!(e.args, dimenames(parent_type(P), i)) + for i in 1:length(L) + if I[i] > 0 + push!(e.args, QuoteNode(L[i])) end end return e @@ -63,7 +54,6 @@ end This returns the dimension(s) of `x` corresponding to `d`. """ to_dims(x, d::Integer) = Int(d) -to_dims(x, d::StaticInt) = d to_dims(x, d::Colon) = d # `:` is the default for most methods that take `dims` @inline to_dims(x, d::Tuple) = map(i -> to_dims(x, i), d) @inline function to_dims(x, d::Symbol)::Int diff --git a/src/indexing.jl b/src/indexing.jl index 63dc5c2dd..ab48b9f6a 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -385,7 +385,7 @@ end ArrayInterface.getindex(A, args...) Retrieve the value(s) stored at the given key or index within a collection. Creating -another instance of `ArrayInterface.getindex` should only be done by overloading `A`. +anothe instance of `ArrayInterface.getindex` should only be done by overloading `A`. Changing indexing based on a given argument from `args` should be done through [`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref). """ @@ -584,9 +584,9 @@ end function _generate_unsafe_setindex!_body(N::Int) quote x′ = Base.unalias(A, x) - @nexprs $N d->(I_d = Base.unalias(A, I[d])) + Base.Cartesian.@nexprs $N d->(I_d = Base.unalias(A, I[d])) idxlens = Base.Cartesian.@ncall $N Base.index_lengths I - @ncall $N Base.setindex_shape_check x′ (d->idxlens[d]) + Base.Cartesian.@ncall $N Base.setindex_shape_check x′ (d->idxlens[d]) Xy = iterate(x′) @inbounds Base.Cartesian.@nloops $N i d->I_d begin # This is never reached, but serves as an assumption for diff --git a/test/dimensions.jl b/test/dimensions.jl new file mode 100644 index 000000000..3b0deb3a9 --- /dev/null +++ b/test/dimensions.jl @@ -0,0 +1,60 @@ + +@testset "dimensions" begin + +struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} + parent::P + NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p) +end +ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,P}} = P +ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true +ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = L +Base.parent(x::NamedDimsWrapper) = x.parent +Base.size(x::NamedDimsWrapper) = size(parent(x)) +Base.axes(x::NamedDimsWrapper) = axes(parent(x)) + +Base.getindex(x::NamedDimsWrapper; kwargs...) = ArrayInterface.getindex(x; kwargs...) +Base.getindex(x::NamedDimsWrapper, args...) = ArrayInterface.getindex(x, args...) +Base.setindex!(x::NamedDimsWrapper, val; kwargs...) = ArrayInterface.setindex!(x, val; kwargs...) +Base.setindex!(x::NamedDimsWrapper, val, args...) = ArrayInterface.setindex!(x, val, args...) +function ArrayInterface.unsafe_get_element(x::NamedDimsWrapper, inds; kwargs...) + return @inbounds(parent(x)[inds...]) +end +function ArrayInterface.unsafe_set_element!(x::NamedDimsWrapper, val, inds; kwargs...) + return @inbounds(parent(x)[inds...] = val) +end + +val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x)) +val_dimnames(x) = Val(ArrayInterface.dimnames(x)) +val_dimnames(x, d) = Val(ArrayInterface.dimnames(x, d)) + +d = (:x, :y) +x = NamedDimsWrapper{d}(ones(2,2)) +dnums = ntuple(+, length(d)) +@test @inferred(val_has_dimnames(x)) === Val(true) +@test @inferred(val_has_dimnames(typeof(x))) === Val(true) +@test @inferred(val_dimnames(x)) === Val(d) +@test @inferred(val_dimnames(x')) === Val(reverse(d)) +@test @inferred(val_dimnames(PermutedDimsArray(x, (2, 1)))) ===Val(reverse(d)) +@test @inferred(val_dimnames(view(x, :, 1))) === Val((:x,)) +@test @inferred(val_dimnames(x, ArrayInterface.One())) === Val(:x) +@test @inferred(ArrayInterface.to_dims(x, d)) === dnums +@test @inferred(ArrayInterface.to_dims(x, reverse(d))) === reverse(dnums) +@test_throws ArgumentError ArrayInterface.to_dims(x, :z) + +x[x = 1] = [2, 3] +@test @inferred(getindex(x, x = 1)) == [2, 3] + +@testset "order_named_inds" begin + @test ArrayInterface.order_named_inds(Val((:x,)); x=2) == (2,) + @test ArrayInterface.order_named_inds(Val((:x, :y)); x=2) == (2, :) + @test ArrayInterface.order_named_inds(Val((:x, :y)); y=2, ) == (:, 2) + @test ArrayInterface.order_named_inds(Val((:x, :y)); y=20, x=30) == (30, 20) + @test ArrayInterface.order_named_inds(Val((:x, :y)); x=30, y=20) == (30, 20) +end + +@testset "tuple_issubset" begin + @test ArrayInterface.tuple_issubset((:a, :c), (:a, :b, :c)) == true + @test ArrayInterface.tuple_issubset((:a, :b, :c), (:a, :c)) == false +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 046fb386a..ab32a49cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -514,4 +514,5 @@ end end include("indexing.jl") +include("dimensions.jl") From 150ff884d3289cc64e00a4957602d9afa052819f Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 11 Nov 2020 04:27:30 -0500 Subject: [PATCH 4/7] Add ArrayInterface.to_dims support to axes/size/strides --- src/dimensions.jl | 38 ++++++++++++++++++++++++++++++++++++++ src/stridelayout.jl | 19 +++---------------- test/dimensions.jl | 8 +++++++- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index fdd8e3f70..99d84c6d4 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -48,6 +48,13 @@ end return e end +""" + named_axes(x) -> NamedTuple{dimnames(x)}(axes(x)) + +Returns a `NamedTuple` of the axes of `x` with dimension names as keys. +""" +named_axes(x) = NamedTuple{dimnames(x)}(axes(x)) + """ to_dims(x, d) @@ -117,3 +124,34 @@ order_named_inds(val::Val{L}; kw...) where {L} = order_named_inds(val, kw.data) end end +""" + size(A) + +Returns the size of `A`. If the size of any axes are known at compile time, +these should be returned as `Static` numbers. For example: +```julia +julia> using StaticArrays, ArrayInterface + +julia> A = @SMatrix rand(3,4); + +julia> ArrayInterface.size(A) +(StaticInt{3}(), StaticInt{4}()) +``` +""" +size(A) = Base.size(A) +size(A, d) = size(A)[to_dims(A, d)] + +""" + axes(A, d) + +Return a valid range that maps to each index along dimension `d` of `A`. +""" +axes(A, d) = axes(A)[to_dims(A, d)] + +""" + axes(A) + +Return a tuple of ranges where each range maps to each element along a dimension of `A`. +""" +axes(A) = Base.axes(A) + diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 159aeec53..530613523 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -225,21 +225,6 @@ permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}( Expr(:block, Expr(:meta, :inline), t) end -""" - size(A) - -Returns the size of `A`. If the size of any axes are known at compile time, -these should be returned as `Static` numbers. For example: -```julia -julia> using StaticArrays, ArrayInterface - -julia> A = @SMatrix rand(3,4); - -julia> ArrayInterface.size(A) -(StaticInt{3}(), StaticInt{4}()) -``` -""" -size(A) = Base.size(A) """ strides(A) @@ -253,6 +238,8 @@ julia> ArrayInterface.strides(A) ``` """ strides(A) = Base.strides(A) +strides(A, d) = strides(A)[to_dims(A, d)] + """ offsets(A) @@ -279,7 +266,6 @@ end end end - @inline size(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = permute(size(parent(B)), Val{(2,1)}()) @inline size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(size(parent(B)), Val{I1}()) @inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] @@ -314,3 +300,4 @@ end end Expr(:block, Expr(:meta, :inline), t) end + diff --git a/test/dimensions.jl b/test/dimensions.jl index 3b0deb3a9..d8431cc90 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -11,6 +11,7 @@ ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = L Base.parent(x::NamedDimsWrapper) = x.parent Base.size(x::NamedDimsWrapper) = size(parent(x)) Base.axes(x::NamedDimsWrapper) = axes(parent(x)) +Base.strides(x::NamedDimsWrapper) = Base.strides(parent(x)) Base.getindex(x::NamedDimsWrapper; kwargs...) = ArrayInterface.getindex(x; kwargs...) Base.getindex(x::NamedDimsWrapper, args...) = ArrayInterface.getindex(x, args...) @@ -41,6 +42,11 @@ dnums = ntuple(+, length(d)) @test @inferred(ArrayInterface.to_dims(x, reverse(d))) === reverse(dnums) @test_throws ArgumentError ArrayInterface.to_dims(x, :z) +@test @inferred(ArrayInterface.size(x, :x)) == size(parent(x), 1) +@test @inferred(ArrayInterface.axes(x, :x)) == axes(parent(x), 1) +@test @inferred(ArrayInterface.strides(x, :x)) == strides(parent(x))[1] + + x[x = 1] = [2, 3] @test @inferred(getindex(x, x = 1)) == [2, 3] @@ -57,4 +63,4 @@ end @test ArrayInterface.tuple_issubset((:a, :b, :c), (:a, :c)) == false end -end \ No newline at end of file +end From 40113cca431bd4306e103018b74025db2dae4533 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 11 Nov 2020 22:04:45 -0500 Subject: [PATCH 5/7] Replace default dimnames with `:_` --- src/dimensions.jl | 25 +++++++++++++------------ src/indexing.jl | 2 +- test/dimensions.jl | 5 ++++- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index 99d84c6d4..6b2b0a6f8 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -22,16 +22,19 @@ Return the names of the dimensions for `x`. @inline dimnames(x) = dimnames(typeof(x)) @inline dimnames(x, i::Integer) = dimnames(typeof(x), i) @inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d)) -@generated function dimnames(::Type{T}) where {T} +@inline function dimnames(::Type{T}) where {T} if parent_type(T) <: T - return ntuple(i -> Symbol(:dim_, i), Val(ndims(T))) + return ntuple(i -> :_, Val(ndims(T))) else return dimnames(parent_type(T)) end end @inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}} - return map(i -> dimnames(parent_type(T), i), (2, 1)) + return _transpose_dimnames(dimnames(parent_type(T))) end +_transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x)) +_transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x)) + @inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} return map(i -> dimnames(parent_type(T), i), I) end @@ -40,21 +43,19 @@ function dimnames(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} end @generated function _sub_array_dimnames(::Val{L}, ::Val{I}) where {L,I} e = Expr(:tuple) - for i in 1:length(L) + nl = length(L) + for i in 1:length(I) if I[i] > 0 - push!(e.args, QuoteNode(L[i])) + if nl < i + push!(e.args, QuoteNode(:_)) + else + push!(e.args, QuoteNode(L[i])) + end end end return e end -""" - named_axes(x) -> NamedTuple{dimnames(x)}(axes(x)) - -Returns a `NamedTuple` of the axes of `x` with dimension names as keys. -""" -named_axes(x) = NamedTuple{dimnames(x)}(axes(x)) - """ to_dims(x, d) diff --git a/src/indexing.jl b/src/indexing.jl index ab48b9f6a..cd83bfa0d 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -385,7 +385,7 @@ end ArrayInterface.getindex(A, args...) Retrieve the value(s) stored at the given key or index within a collection. Creating -anothe instance of `ArrayInterface.getindex` should only be done by overloading `A`. +another instance of `ArrayInterface.getindex` should only be done by overloading `A`. Changing indexing based on a given argument from `args` should be done through [`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref). """ diff --git a/test/dimensions.jl b/test/dimensions.jl index d8431cc90..f3e78453f 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -30,13 +30,17 @@ val_dimnames(x, d) = Val(ArrayInterface.dimnames(x, d)) d = (:x, :y) x = NamedDimsWrapper{d}(ones(2,2)) +y = NamedDimsWrapper{(:x,)}(ones(2)) dnums = ntuple(+, length(d)) @test @inferred(val_has_dimnames(x)) === Val(true) @test @inferred(val_has_dimnames(typeof(x))) === Val(true) @test @inferred(val_dimnames(x)) === Val(d) @test @inferred(val_dimnames(x')) === Val(reverse(d)) +@test @inferred(val_dimnames(y')) === Val((:_, :x)) @test @inferred(val_dimnames(PermutedDimsArray(x, (2, 1)))) ===Val(reverse(d)) @test @inferred(val_dimnames(view(x, :, 1))) === Val((:x,)) +@test @inferred(val_dimnames(view(x, :, :, :))) === Val((:x, :y, :_)) +@test @inferred(val_dimnames(view(x, :, 1, :))) === Val((:x, :_)) @test @inferred(val_dimnames(x, ArrayInterface.One())) === Val(:x) @test @inferred(ArrayInterface.to_dims(x, d)) === dnums @test @inferred(ArrayInterface.to_dims(x, reverse(d))) === reverse(dnums) @@ -46,7 +50,6 @@ dnums = ntuple(+, length(d)) @test @inferred(ArrayInterface.axes(x, :x)) == axes(parent(x), 1) @test @inferred(ArrayInterface.strides(x, :x)) == strides(parent(x))[1] - x[x = 1] = [2, 3] @test @inferred(getindex(x, x = 1)) == [2, 3] From 960df415c3d29525962fc0fc10f9723877e2b292 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Fri, 13 Nov 2020 06:36:14 -0500 Subject: [PATCH 6/7] Use `Type{T}` in documentation where applicable. Users should only define these methods for new types like this so it should probably be represented in the documentation similarly. --- src/ArrayInterface.jl | 2 +- src/dimensions.jl | 42 ++++++++++++++++++++++-------------------- src/ranges.jl | 1 + 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 810e8bd18..96579efbf 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -11,7 +11,7 @@ parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) """ - parent_type(x) + parent_type(::Type{T}) Returns the parent array that `x` wraps. """ diff --git a/src/dimensions.jl b/src/dimensions.jl index 6b2b0a6f8..1f775041c 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,6 +1,6 @@ """ - has_dimnames(x) -> Bool + has_dimnames(::Type{T}) -> Bool Returns `true` if `x` has names for each dimension. """ @@ -14,8 +14,8 @@ function has_dimnames(::Type{T}) where {T} end """ - dimnames(x) -> Tuple{Vararg{Symbol}} - dimnames(x, d) -> Symbol + dimnames(::Type{T}) -> Tuple{Vararg{Symbol}} + dimnames(::Type{T}, d) -> Symbol Return the names of the dimensions for `x`. """ @@ -57,7 +57,7 @@ end end """ - to_dims(x, d) + to_dims(x[, d]) This returns the dimension(s) of `x` corresponding to `d`. """ @@ -100,29 +100,31 @@ Base.@pure function tuple_issubset( end """ - order_named_inds(Val(names); kw...) + order_named_inds(Val(names); kwargs...) order_named_inds(Val(names), namedtuple) Returns the tuple of index values for an array with `names`, when indexed by keywords. Any dimensions not fixed are given as `:`, to make a slice. An error is thrown if any keywords are used which do not occur in `nda`'s names. """ -order_named_inds(val::Val{L}; kw...) where {L} = order_named_inds(val, kw.data) -@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K} - if length(K) === 0 - return () # if kwargs were empty +@inline function order_named_inds(val::Val{L}; kwargs...) where {L} + if isempty(kwargs) + return () else - tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K")) - exs = map(L) do n - if Base.sym_in(n, K) - qn = QuoteNode(n) - :(getfield(ni, $qn)) - else - :(Colon()) - end + return order_named_inds(val, kwargs.data) + end +end +@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K} + tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K")) + exs = map(L) do n + if Base.sym_in(n, K) + qn = QuoteNode(n) + :(getfield(ni, $qn)) + else + :(Colon()) end - return Expr(:tuple, exs...) end + return Expr(:tuple, exs...) end """ @@ -140,14 +142,14 @@ julia> ArrayInterface.size(A) ``` """ size(A) = Base.size(A) -size(A, d) = size(A)[to_dims(A, d)] +size(A, d) = Base.size(A, to_dims(A, d)) """ axes(A, d) Return a valid range that maps to each index along dimension `d` of `A`. """ -axes(A, d) = axes(A)[to_dims(A, d)] +axes(A, d) = Base.axes(A, to_dims(A, d)) """ axes(A) diff --git a/src/ranges.jl b/src/ranges.jl index 4c41de8df..74f08faec 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -433,3 +433,4 @@ end lst = _try_static(static_last(x), static_last(y)) return Base.Slice(OptionallyStaticUnitRange(fst, lst)) end + From 65ce8507d11d20e76b1baff394c8e3869f46ef24 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sat, 28 Nov 2020 21:40:23 -0500 Subject: [PATCH 7/7] get tests working with new strides --- src/stridelayout.jl | 8 +++++++- test/dimensions.jl | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index f4d1d108d..5da3c96d9 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -9,7 +9,13 @@ If no axis is contiguous, it returns `Contiguous{-1}`. If unknown, it returns `nothing`. """ contiguous_axis(x) = contiguous_axis(typeof(x)) -contiguous_axis(::Type) = nothing +function contiguous_axis(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return contiguous_axis(parent_type(T)) + end +end contiguous_axis(::Type{<:Array}) = Contiguous{1}() contiguous_axis(::Type{<:Tuple}) = Contiguous{1}() function contiguous_axis(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}) where {T,A<:AbstractVector{T}} diff --git a/test/dimensions.jl b/test/dimensions.jl index f3e78453f..e3b5a9d5d 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -5,7 +5,7 @@ struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} parent::P NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p) end -ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,P}} = P +ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = L Base.parent(x::NamedDimsWrapper) = x.parent @@ -48,7 +48,7 @@ dnums = ntuple(+, length(d)) @test @inferred(ArrayInterface.size(x, :x)) == size(parent(x), 1) @test @inferred(ArrayInterface.axes(x, :x)) == axes(parent(x), 1) -@test @inferred(ArrayInterface.strides(x, :x)) == strides(parent(x))[1] +@test ArrayInterface.strides(x, :x) == ArrayInterface.strides(parent(x))[1] x[x = 1] = [2, 3] @test @inferred(getindex(x, x = 1)) == [2, 3]