Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formalize interface for keyword arguments in indexing #83

Merged
merged 9 commits into from
Nov 29, 2020
1 change: 1 addition & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ end

include("static.jl")
include("ranges.jl")
include("dimensions.jl")
include("indexing.jl")
include("stridelayout.jl")

Expand Down
157 changes: 157 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

"""
has_dimnames(x) -> Bool

Returns `true` if `x` has names for each dimension.
"""
@inline has_dimnames(x) = has_dimnames(typeof(x))
function has_dimnames(::Type{T}) where {T}
if parent_type(T) <: T
return false
else
return has_dimnames(parent_type(T))
end
end

"""
dimnames(x) -> Tuple{Vararg{Symbol}}
dimnames(x, d) -> Symbol

Return the names of the dimensions for `x`.
"""
@inline dimnames(x) = dimnames(typeof(x))
@inline dimnames(x, i::Integer) = dimnames(typeof(x), i)
@inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d))
@generated function dimnames(::Type{T}) where {T}
if parent_type(T) <: T
return ntuple(i -> Symbol(:dim_, i), Val(ndims(T)))
Tokazama marked this conversation as resolved.
Show resolved Hide resolved
else
return dimnames(parent_type(T))
end
end
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
return map(i -> dimnames(parent_type(T), i), (2, 1))
end
@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
return map(i -> dimnames(parent_type(T), i), I)
end
function dimnames(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
return _sub_array_dimnames(Val(dimnames(P)), Val(argdims(P, I)))
end
@generated function _sub_array_dimnames(::Val{L}, ::Val{I}) where {L,I}
e = Expr(:tuple)
for i in 1:length(L)
if I[i] > 0
push!(e.args, QuoteNode(L[i]))
end
end
return e
end

"""
named_axes(x) -> NamedTuple{dimnames(x)}(axes(x))

Returns a `NamedTuple` of the axes of `x` with dimension names as keys.
"""
named_axes(x) = NamedTuple{dimnames(x)}(axes(x))

"""
to_dims(x, d)

This returns the dimension(s) of `x` corresponding to `d`.
"""
to_dims(x, d::Integer) = Int(d)
to_dims(x, d::Colon) = d # `:` is the default for most methods that take `dims`
@inline to_dims(x, d::Tuple) = map(i -> to_dims(x, i), d)
@inline function to_dims(x, d::Symbol)::Int
i = _sym_to_dim(dimnames(x), d)
if i === 0
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(dimnames(x)))"))
end
return i
end
Base.@pure function _sym_to_dim(x::Tuple{Vararg{Symbol,N}}, sym::Symbol) where {N}
for i in 1:N
getfield(x, i) === sym && return i
end
return 0
end

"""
tuple_issubset

A version of `issubset` sepecifically for `Tuple`s of `Symbol`s, that is `@pure`.
This helps it get optimised out of existance. It is less of an abuse of `@pure` than
most of the stuff for making `NamedTuples` work.
"""
Base.@pure function tuple_issubset(
lhs::Tuple{Vararg{Symbol,N}}, rhs::Tuple{Vararg{Symbol,M}},
) where {N,M}
N <= M || return false
for a in lhs
found = false
for b in rhs
found |= a === b
end
found || return false
end
return true
end

"""
order_named_inds(Val(names); kw...)
order_named_inds(Val(names), namedtuple)

Returns the tuple of index values for an array with `names`, when indexed by keywords.
Any dimensions not fixed are given as `:`, to make a slice.
An error is thrown if any keywords are used which do not occur in `nda`'s names.
"""
order_named_inds(val::Val{L}; kw...) where {L} = order_named_inds(val, kw.data)
@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K}
if length(K) === 0
return () # if kwargs were empty
else
tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K"))
exs = map(L) do n
if Base.sym_in(n, K)
qn = QuoteNode(n)
:(getfield(ni, $qn))
else
:(Colon())
end
end
return Expr(:tuple, exs...)
end
end

"""
size(A)

Returns the size of `A`. If the size of any axes are known at compile time,
these should be returned as `Static` numbers. For example:
```julia
julia> using StaticArrays, ArrayInterface

julia> A = @SMatrix rand(3,4);

julia> ArrayInterface.size(A)
(StaticInt{3}(), StaticInt{4}())
```
"""
size(A) = Base.size(A)
size(A, d) = size(A)[to_dims(A, d)]

"""
axes(A, d)

Return a valid range that maps to each index along dimension `d` of `A`.
"""
axes(A, d) = axes(A)[to_dims(A, d)]

"""
axes(A)

Return a tuple of ranges where each range maps to each element along a dimension of `A`.
"""
axes(A) = Base.axes(A)

110 changes: 94 additions & 16 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,21 +385,34 @@ end
ArrayInterface.getindex(A, args...)

Retrieve the value(s) stored at the given key or index within a collection. Creating
another instance of `ArrayInterface.getindex` should only be done by overloading `A`.
anothe instance of `ArrayInterface.getindex` should only be done by overloading `A`.
Tokazama marked this conversation as resolved.
Show resolved Hide resolved
Changing indexing based on a given argument from `args` should be done through
[`flatten_args`](@ref), [`to_index`](@ref), or [`to_axis`](@ref).
"""
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args))
@propagate_inbounds function getindex(A; kwargs...)
if has_dimnames(A)
return A[order_named_inds(Val(dimnames(A)); kwargs...)...]
else
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
end
end

"""
unsafe_getindex(A, inds)

Indexes into `A` given `inds`. This method assumes that `inds` have already been
bounds-checked.
"""
unsafe_getindex(A, inds) = unsafe_getindex(UnsafeIndex(A, inds), A, inds)
unsafe_getindex(::UnsafeGetElement, A, inds) = unsafe_get_element(A, inds)
unsafe_getindex(::UnsafeGetCollection, A, inds) = unsafe_get_collection(A, inds)
function unsafe_getindex(A, inds; kwargs...)
return unsafe_getindex(UnsafeIndex(A, inds), A, inds; kwargs...)
end
function unsafe_getindex(::UnsafeGetElement, A, inds; kwargs...)
return unsafe_get_element(A, inds; kwargs...)
end
function unsafe_getindex(::UnsafeGetCollection, A, inds; kwargs...)
return unsafe_get_collection(A, inds; kwargs...)
end

"""
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
Expand All @@ -408,9 +421,7 @@ Returns an element of `A` at the indices `inds`. This method assumes all `inds`
have been checked for being in bounds. Any new array type using `ArrayInterface.getindex`
must define `unsafe_get_element(::NewArrayType, inds)`.
"""
function unsafe_get_element(A, inds)
throw(MethodError(unsafe_getindex, (A, inds)))
end
unsafe_get_element(A, inds; kwargs...) = throw(MethodError(unsafe_getindex, (A, inds)))
function unsafe_get_element(A::Array, inds)
if length(inds) === 0
return Base.arrayref(false, A, 1)
Expand All @@ -433,11 +444,11 @@ end

Returns a collection of `A` given `inds`. `inds` is assumed to have been bounds-checked.
"""
function unsafe_get_collection(A, inds)
function unsafe_get_collection(A, inds; kwargs...)
axs = to_axes(A, inds)
dest = similar(A, axs)
if map(Base.unsafe_length, axes(dest)) == map(Base.unsafe_length, axs)
Base._unsafe_getindex!(dest, A, inds...) # usually a generated function, don't allow it to impact inference result
_unsafe_getindex!(dest, A, inds...; kwargs...) # usually a generated function, don't allow it to impact inference result
else
Base.throw_checksize_error(dest, axs)
end
Expand Down Expand Up @@ -490,16 +501,29 @@ Store the given values at the given key or index within a collection.
"elements after construction.")
end
end
@propagate_inbounds function setindex!(A, val; kwargs...)
if has_dimnames(A)
A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val
else
return unsafe_setindex!(A, val, to_indices(A, ()); kwargs...)
end
end

"""
unsafe_setindex!(A, val, inds::Tuple)
unsafe_setindex!(A, val, inds::Tuple; kwargs...)

Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
bounds-checked. This step of the processing pipeline can be customized by:
"""
unsafe_setindex!(A, val, inds::Tuple) = unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds)
unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple) = unsafe_set_element!(A, val, inds)
unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple) = unsafe_set_collection!(A, val, inds)
function unsafe_setindex!(A, val, inds::Tuple; kwargs...)
return unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds; kwargs...)
end
function unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple; kwargs...)
return unsafe_set_element!(A, val, inds; kwargs...)
end
function unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple; kwargs...)
return unsafe_set_collection!(A, val, inds; kwargs...)
end

"""
unsafe_set_element!(A, val, inds::Tuple)
Expand All @@ -508,7 +532,7 @@ Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds
have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!`
must define `unsafe_set_element!(::NewArrayType, val, inds)`.
"""
function unsafe_set_element!(A, val, inds)
function unsafe_set_element!(A, val, inds; kwargs...)
throw(MethodError(unsafe_set_element!, (A, val, inds)))
end
function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T}
Expand All @@ -527,6 +551,60 @@ end

Sets `inds` of `A` to `val`. `inds` is assumed to have been bounds-checked.
"""
@inline function unsafe_set_collection!(A, val, inds)
return Base._unsafe_setindex!(IndexStyle(A), A, val, inds...)
@inline function unsafe_set_collection!(A, val, inds; kwargs...)
return _unsafe_setindex!(IndexStyle(A), A, val, inds...; kwargs...)
end


# these let us use `@ncall` on getindex/setindex! that have kwargs
function _setindex_kwargs!(x, val, kwargs, args...)
@inbounds setindex!(x, val, args...; kwargs...)
end
function _getindex_kwargs(x, kwargs, args...)
@inbounds getindex(x, args...; kwargs...)
end

function _generate_unsafe_getindex!_body(N::Int)
quote
Base.@_inline_meta
D = eachindex(dest)
Dy = iterate(D)
@inbounds Base.Cartesian.@nloops $N j d->I[d] begin
# This condition is never hit, but at the moment
# the optimizer is not clever enough to split the union without it
Dy === nothing && return dest
(idx, state) = Dy
dest[idx] = Base.Cartesian.@ncall $N _getindex_kwargs src kwargs j
Dy = iterate(D, state)
end
return dest
end
end

function _generate_unsafe_setindex!_body(N::Int)
quote
x′ = Base.unalias(A, x)
Base.Cartesian.@nexprs $N d->(I_d = Base.unalias(A, I[d]))
idxlens = Base.Cartesian.@ncall $N Base.index_lengths I
Base.Cartesian.@ncall $N Base.setindex_shape_check x′ (d->idxlens[d])
Xy = iterate(x′)
@inbounds Base.Cartesian.@nloops $N i d->I_d begin
# This is never reached, but serves as an assumption for
# the optimizer that it does not need to emit error paths
Xy === nothing && break
(val, state) = Xy
Base.Cartesian.@ncall $N _setindex_kwargs! A val kwargs i
Xy = iterate(x′, state)
end
A
end
end

@generated function _unsafe_getindex!(dest::AbstractArray, src::AbstractArray, I::Vararg{Union{Real, AbstractArray}, N}; kwargs...) where N
_generate_unsafe_getindex!_body(N)
end

@generated function _unsafe_setindex!(::IndexStyle, A::AbstractArray, x, I::Vararg{Union{Real,AbstractArray}, N}; kwargs...) where N
_generate_unsafe_setindex!_body(N)
end

28 changes: 21 additions & 7 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ Otherwise, return `nothing`.
@test isone(known_first(typeof(Base.OneTo(4))))
"""
known_first(x) = known_first(typeof(x))
known_first(::Type{T}) where {T} = nothing
function known_first(::Type{T}) where {T}
if parent_type(T) <: T
return nothing
else
return known_first(parent_type(T))
end
end
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))

"""
known_last(::Type{T})
Expand All @@ -24,8 +29,13 @@ using StaticArrays
@test known_last(typeof(SOneTo(4))) == 4
"""
known_last(x) = known_last(typeof(x))
known_last(::Type{T}) where {T} = nothing
known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T))
function known_last(::Type{T}) where {T}
if parent_type(T) <: T
return nothing
else
return known_last(parent_type(T))
end
end

"""
known_step(::Type{T})
Expand All @@ -37,11 +47,15 @@ Otherwise, return `nothing`.
@test isone(known_step(typeof(1:4)))
"""
known_step(x) = known_step(typeof(x))
known_step(::Type{T}) where {T} = nothing
function known_step(::Type{T}) where {T}
if parent_type(T) <: T
return nothing
else
return known_step(parent_type(T))
end
end
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

# add methods to support ArrayInterface

"""
OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int}

Expand Down
Loading