From 6e692d95ea488e9d955c638f0e42e3ecf0d0a16e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 21 Sep 2023 18:04:48 -0400 Subject: [PATCH] [NDTensors] Add SmallVectors submodule --- NDTensors/src/NDTensors.jl | 2 + NDTensors/src/SmallVectors/README.md | 18 ++ .../src/SmallVectors/src/SmallVectors.jl | 23 ++ .../abstractsmallvector.jl | 24 ++ .../src/abstractsmallvector/deque.jl | 280 ++++++++++++++++++ .../src/msmallvector/msmallvector.jl | 70 +++++ .../src/smallvector/smallvector.jl | 64 ++++ .../src/subsmallvector/subsmallvector.jl | 69 +++++ NDTensors/src/SmallVectors/test/runtests.jl | 78 +++++ NDTensors/test/SmallVectors.jl | 4 + NDTensors/test/runtests.jl | 1 + 11 files changed, 633 insertions(+) create mode 100644 NDTensors/src/SmallVectors/README.md create mode 100644 NDTensors/src/SmallVectors/src/SmallVectors.jl create mode 100644 NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl create mode 100644 NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl create mode 100644 NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl create mode 100644 NDTensors/src/SmallVectors/src/smallvector/smallvector.jl create mode 100644 NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl create mode 100644 NDTensors/src/SmallVectors/test/runtests.jl create mode 100644 NDTensors/test/SmallVectors.jl diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index ae138d11fc..80f4a7fd8b 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -19,6 +19,8 @@ using TupleTools include("SetParameters/src/SetParameters.jl") using .SetParameters +include("SmallVectors/src/SmallVectors.jl") +using .SmallVectors using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo diff --git a/NDTensors/src/SmallVectors/README.md b/NDTensors/src/SmallVectors/README.md new file mode 100644 index 0000000000..b771e20e7b --- /dev/null +++ b/NDTensors/src/SmallVectors/README.md @@ -0,0 +1,18 @@ +# SmallVectors + +A module that defines small (mutable and immutable) vectors with a maximum length. Externally the have a dynamic (or in the case of immuatable vectors, runtime) length, but internally they are backed by a statically sized vector. This makes it so that operations can be performed faster because they can remain on the stack, but it provides some more convenience compared to StaticArrays.jl where the length is encoded in the type. + +For example: +```julia +using NDTensors.SmallVectors +v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10 +v = push(v, 4) +v = setindex(v, 4, 4) +v = sort(v; rev=true) + +mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10 +push!(mv, 4) +mv[2] = 12 +sort!(mv; rev=true) +``` +This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length. diff --git a/NDTensors/src/SmallVectors/src/SmallVectors.jl b/NDTensors/src/SmallVectors/src/SmallVectors.jl new file mode 100644 index 0000000000..908c498724 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/SmallVectors.jl @@ -0,0 +1,23 @@ +module SmallVectors + using StaticArrays + + include("StaticArraysExt.jl") + using .StaticArraysExt + + export SmallVector, MSmallVector, SubSmallVector + + struct NotImplemented <: Exception + msg::String + end + NotImplemented() = NotImplemented("Not implemented.") + + struct BufferDimensionMismatch <: Exception + msg::String + end + + include("abstractsmallvector/abstractsmallvector.jl") + include("abstractsmallvector/deque.jl") + include("msmallvector/msmallvector.jl") + include("smallvector/smallvector.jl") + include("subsmallvector/subsmallvector.jl") +end diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl new file mode 100644 index 0000000000..b60f8a7cf0 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl @@ -0,0 +1,24 @@ +""" +A vector with a fixed maximum length, backed by a fixed size buffer. +""" +abstract type AbstractSmallVector{T} <: AbstractVector{T} end + +# Required buffer interface +buffer(vec::AbstractSmallVector) = throw(NotImplemented()) + +similar_type(vec::AbstractSmallVector) = typeof(vec) + +# Required buffer interface +maxlength(vec::AbstractSmallVector) = length(buffer(vec)) + +# Required AbstractArray interface +Base.size(vec::AbstractSmallVector) = throw(NotImplemented()) + +# Derived AbstractArray interface +function Base.getindex(vec::AbstractSmallVector, index::Integer) + return throw(NotImplemented()) +end +function Base.setindex!(vec::AbstractSmallVector, item, index::Integer) + return throw(NotImplemented()) +end +Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear() diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl new file mode 100644 index 0000000000..55e3d8352b --- /dev/null +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl @@ -0,0 +1,280 @@ +# TODO: Set operations +# union, ∪, union! +# intersect, ∩, intersect! +# setdiff, setdiff! +# symdiff, symdiff! +# unique, unique! + +Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented()) + +@inline function resize(vec::AbstractSmallVector, len) + mvec = Base.copymutable(vec) + resize!(mvec, len) + return convert(similar_type(vec), mvec) +end + +@inline function Base.empty!(vec::AbstractSmallVector) + resize!(vec, 0) + return vec +end + +@inline function empty(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + empty!(mvec) + return convert(similar_type(vec), mvec) +end + +@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer) + @boundscheck checkbounds(vec, index) + mvec = Base.copymutable(vec) + @inbounds mvec[index] = item + return convert(similar_type(vec), mvec) +end + +@inline function Base.push!(vec::AbstractSmallVector, item) + resize!(vec, length(vec) + 1) + @inbounds vec[length(vec)] = item + return vec +end + +@inline function StaticArrays.push(vec::AbstractSmallVector, item) + mvec = Base.copymutable(vec) + push!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.pop!(vec::AbstractSmallVector) + resize!(vec, length(vec) - 1) + return vec +end + +@inline function StaticArrays.pop(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + pop!(mvec) + return convert(similar_type(vec), mvec) +end + +@inline function Base.pushfirst!(vec::AbstractSmallVector, item) + insert!(vec, firstindex(vec), item) + return vec +end + +# Don't `@inline`, makes it slower. +function StaticArrays.pushfirst(vec::AbstractSmallVector, item) + mvec = Base.copymutable(vec) + pushfirst!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.popfirst!(vec::AbstractSmallVector) + circshift!(vec, -1) + resize!(vec, length(vec) - 1) + return vec +end + +# Don't `@inline`, makes it slower. +function StaticArrays.popfirst(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + popfirst!(mvec) + return convert(similar_type(vec), mvec) +end + +@inline function Base.reverse!(vec::AbstractSmallVector) + start, stop = firstindex(vec), lastindex(vec) + r = stop + @inbounds for i in start:Base.midpoint(start, stop-1) + vec[i], vec[r] = vec[r], vec[i] + r -= 1 + end + return vec +end + +@inline function Base.reverse!(vec::AbstractSmallVector, start::Integer, stop::Integer=lastindex(v)) + reverse!(smallview(vec, start, stop)) + return vec +end + +@inline function Base.circshift!(vec::AbstractSmallVector, shift::Integer) + start, stop = firstindex(vec), lastindex(vec) + n = length(vec) + n == 0 && return vec + shift = mod(shift, n) + shift == 0 && return vec + reverse!(smallview(vec, start, stop - shift)) + reverse!(smallview(vec, stop - shift + 1, stop)) + reverse!(smallview(vec, start, stop)) + return vec +end + +@inline function Base.insert!(vec::AbstractSmallVector, index::Integer, item) + resize!(vec, length(vec) + 1) + circshift!(smallview(vec, index, lastindex(vec)), 1) + @inbounds vec[index] = item + return vec +end + +# Don't @inline, makes it slower. +function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item) + mvec = Base.copymutable(vec) + insert!(mvec, index, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.deleteat!(vec::AbstractSmallVector, index::Integer) + circshift!(smallview(vec, index, lastindex(vec)), -1) + resize!(vec, length(vec) - 1) + return vec +end + +# Don't @inline, makes it slower. +function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer) + mvec = Base.copymutable(vec) + deleteat!(mvec, index) + return convert(similar_type(vec), mvec) +end + +# InsertionSortAlg +# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736 +# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort. +# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting. +@inline function Base.sort!(vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false) + lo, hi = firstindex(vec), lastindex(vec) + lo_plus_1 = (lo + 1) + @inbounds for i in lo_plus_1:hi + j = i + x = vec[i] + jmax = j + for _ in jmax:-1:lo_plus_1 + y = vec[j - 1] + if !(lt(by(x), by(y)) != rev) + break + end + vec[j] = y + j -= 1 + end + vec[j] = x + end + return vec +end + +# Don't @inline, makes it slower. +function Base.sort(vec::AbstractSmallVector; kwargs...) + mvec = Base.copymutable(vec) + sort!(mvec; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...) + insert!(vec, searchsortedfirst(vec, item; kwargs...), item) + return vec +end + +function insertsorted(vec::AbstractSmallVector, item; kwargs...) + mvec = Base.copymutable(vec) + insertsorted!(mvec, item; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function insertsortedunique!(vec::AbstractSmallVector, item; kwargs...) + r = searchsorted(vec, item; kwargs...) + if length(r) == 0 + insert!(vec, first(r), item) + end + return vec +end + +# Code repeated since inlining doesn't work. +function insertsortedunique(vec::AbstractSmallVector, item; kwargs...) + r = searchsorted(vec, item; kwargs...) + if length(r) == 0 + vec = insert(vec, first(r), item) + end + return vec +end + +@inline function mergesorted!(vec::AbstractSmallVector, item::AbstractVector; kwargs...) + for x in item + insertsorted!(vec, x; kwargs...) + end + return vec +end + +function mergesorted(vec::AbstractSmallVector, item; kwargs...) + mvec = Base.copymutable(vec) + mergesorted!(mvec, item; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function mergesortedunique!(vec::AbstractSmallVector, item::AbstractVector; kwargs...) + for x in item + insertsortedunique!(vec, x; kwargs...) + end + return vec +end + +# Code repeated since inlining doesn't work. +function mergesortedunique(vec::AbstractSmallVector, item; kwargs...) + for x in item + vec = insertsortedunique(vec, x; kwargs...) + end + return vec +end + +Base.@propagate_inbounds function Base.copyto!(vec::AbstractSmallVector, item::AbstractVector) + for i in eachindex(item) + vec[i] = item[i] + end + return vec +end + +# Don't @inline, makes it slower. +function Base.circshift(vec::AbstractSmallVector, shift::Integer) + mvec = Base.copymutable(vec) + circshift!(mvec, shift) + return convert(similar_type(vec), mvec) +end + +@inline function Base.append!(vec::AbstractSmallVector, item::AbstractVector) + l = length(vec) + r = length(item) + resize!(vec, l + r) + @inbounds copyto!(smallview(vec, l + 1, l + r + 1), item) + return vec +end + +# Missing from `StaticArrays.jl`. +# Don't @inline, makes it slower. +function append(vec::AbstractSmallVector, item::AbstractVector) + mvec = Base.copymutable(vec) + append!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.prepend!(vec::AbstractSmallVector, item::AbstractVector) + l = length(vec) + r = length(item) + resize!(vec, l + r) + circshift!(vec, length(item)) + @inbounds copyto!(vec, item) + return vec +end + +# Missing from `StaticArrays.jl`. +# Don't @inline, makes it slower. +function prepend(vec::AbstractSmallVector, item::AbstractVector) + mvec = Base.copymutable(vec) + prepend!(mvec, item) + return convert(similar_type(vec), mvec) +end + +# Don't @inline, makes it slower. +function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector) + mvec1 = Base.copymutable(vec1) + append!(mvec1, vec2) + return convert(similar_type(vec1), mvec1) +end + +# TODO: inline when defined. +function Base.splice!(a::AbstractSmallVector, args...) + return throw(NotImplemented()) +end diff --git a/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl new file mode 100644 index 0000000000..e31046435d --- /dev/null +++ b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl @@ -0,0 +1,70 @@ +""" +MSmallVector +""" +mutable struct MSmallVector{S,T} <: AbstractSmallVector{T} + const buffer::MVector{S,T} + length::Int +end + +# Constructors +MSmallVector{S}(buffer::AbstractVector, len::Int) where {S} = MSmallVector{S,eltype(buffer)}(buffer, len) +MSmallVector(buffer::AbstractVector, len::Int) = MSmallVector{length(buffer),eltype(buffer)}(buffer, len) + +""" +`MSmallVector` constructor, uses `MVector` as a buffer. +```julia +MSmallVector{10}([1, 2, 3]) +MSmallVector{10}(SA[1, 2, 3]) +``` +""" +function MSmallVector{S,T}(vec::AbstractVector) where {S,T} + buffer = zeros(MVector{S,T}) + copyto!(buffer, vec) + return MSmallVector(buffer, length(vec)) +end + +# Derive the buffer length. +MSmallVector(vec::AbstractSmallVector) = MSmallVector{length(buffer(vec))}(vec) + +Base.convert(::Type{T}, a::AbstractArray) where {T<:MSmallVector} = a isa T ? a : T(a)::T + +function MSmallVector{S}(vec::AbstractVector) where {S} + return MSmallVector{S,eltype(vec)}(vec) +end + +function MSmallVector{S,T}(::UndefInitializer, dims::Tuple{Integer}) where {S,T} + return MSmallVector{S,T}(undef, prod(dims)) +end +function MSmallVector{S,T}(::UndefInitializer, length::Integer) where {S,T} + return MSmallVector{S,T}(MVector{S,T}(undef), length) +end + +# Buffer interface +buffer(vec::MSmallVector) = vec.buffer + +# Accessors +Base.size(vec::MSmallVector) = (vec.length,) + +# Required Base overloads +@inline function Base.getindex(vec::MSmallVector, index::Integer) + @boundscheck checkbounds(vec, index) + return @inbounds buffer(vec)[index] +end + +@inline function Base.setindex!(vec::MSmallVector, item, index::Integer) + @boundscheck checkbounds(vec, index) + @inbounds buffer(vec)[index] = item + return vec +end + +@inline function Base.resize!(vec::MSmallVector, len::Integer) + len < 0 && throw(ArgumentError("New length must be ≥ 0.")) + len > maxlength(vec) && throw(ArgumentError("New length $len must be ≤ the maximum length $(maxlength(vec)).")) + vec.length = len + return vec +end + +# `similar` creates a `MSmallVector` by default. +function Base.similar(vec::AbstractSmallVector, elt::Type, dims::Dims) + return MSmallVector{length(buffer(vec)),elt}(undef, dims) +end diff --git a/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl new file mode 100644 index 0000000000..872b80ac30 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl @@ -0,0 +1,64 @@ +""" +SmallVector +""" +struct SmallVector{S,T} <: AbstractSmallVector{T} + buffer::SVector{S,T} + length::Int +end + +# Accessors +# TODO: Use `Accessors.jl`. +@inline setbuffer(vec::SmallVector, buffer) = SmallVector(buffer, vec.length) +@inline setlength(vec::SmallVector, length) = SmallVector(vec.buffer, length) + +# Constructors +SmallVector{S}(buffer::AbstractVector, len::Int) where {S} = SmallVector{S,eltype(buffer)}(buffer, len) +SmallVector(buffer::AbstractVector, len::Int) = SmallVector{length(buffer),eltype(buffer)}(buffer, len) + +""" +`SmallVector` constructor, uses `SVector` as buffer storage. +```julia +SmallVector{10}([1, 2, 3]) +SmallVector{10}(SA[1, 2, 3]) +``` +""" +function SmallVector{S,T}(vec::AbstractVector) where {S,T} + mvec = MSmallVector{S,T}(vec) + return SmallVector{S,T}(buffer(mvec), length(mvec)) +end +# Special optimization codepath for `MSmallVector` +# to avoid a copy. +function SmallVector{S,T}(vec::MSmallVector) where {S,T} + return SmallVector{S,T}(buffer(vec), length(vec)) +end + +function SmallVector{S}(vec::AbstractVector) where {S} + return SmallVector{S,eltype(vec)}(vec) +end + +# Specialized constructor +function MSmallVector{S,T}(vec::SmallVector) where {S,T} + return MSmallVector{S,T}(buffer(vec), length(vec)) +end + +# Derive the buffer length. +SmallVector(vec::AbstractSmallVector) = SmallVector{length(buffer(vec))}(vec) + +Base.convert(::Type{T}, a::AbstractArray) where {T<:SmallVector} = a isa T ? a : T(a)::T + +# Buffer interface +buffer(vec::SmallVector) = vec.buffer + +# AbstractArray interface +Base.size(vec::SmallVector) = (vec.length,) + +# Base overloads +@inline function Base.getindex(vec::SmallVector, index::Integer) + @boundscheck checkbounds(vec, index) + return @inbounds buffer(vec)[index] +end + +Base.copy(vec::SmallVector) = vec + +# Optimization, default uses `similar`. +Base.copymutable(vec::SmallVector) = MSmallVector(vec) diff --git a/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl b/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl new file mode 100644 index 0000000000..8156e4bb74 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl @@ -0,0 +1,69 @@ +abstract type AbstractSubSmallVector{T} <: AbstractSmallVector{T} end + +""" +SubSmallVector +""" +struct SubSmallVector{T,P} <: AbstractSubSmallVector{T} + parent::P + start::Int + stop::Int +end + +mutable struct SubMSmallVector{T,P<:AbstractVector{T}} <: AbstractSubSmallVector{T} + const parent::P + start::Int + stop::Int +end + +# TODO: Use Accessors.jl +Base.parent(vec::SubSmallVector) = vec.parent +Base.parent(vec::SubMSmallVector) = vec.parent + +# buffer interface +buffer(vec::AbstractSubSmallVector) = buffer(parent(vec)) + +smallview(vec::SmallVector, start::Integer, stop::Integer) = SubSmallVector(vec, start, stop) +smallview(vec::MSmallVector, start::Integer, stop::Integer) = SubMSmallVector(vec, start, stop) + +function smallview(vec::SubSmallVector, start::Integer, stop::Integer) + return SubSmallVector(parent(vec), vec.start + start - 1, vec.start + stop - 1) +end +function smallview(vec::SubMSmallVector, start::Integer, stop::Integer) + return SubMSmallVector(parent(vec), vec.start + start - 1, vec.start + stop - 1) +end + +# Constructors +SubSmallVector(vec::AbstractVector, start::Integer, stop::Integer) = SubSmallVector{eltype(vec),typeof(vec)}(vec, start, stop) +SubMSmallVector(vec::AbstractVector, start::Integer, stop::Integer) = SubMSmallVector{eltype(vec),typeof(vec)}(vec, start, stop) + +# Accessors +Base.size(vec::AbstractSubSmallVector) = (vec.stop - vec.start + 1,) + +Base.@propagate_inbounds function Base.getindex(vec::AbstractSubSmallVector, index::Integer) + return parent(vec)[index + vec.start - 1] +end + +Base.@propagate_inbounds function Base.setindex!(vec::AbstractSubSmallVector, item, index::Integer) + buffer(vec)[index + vec.start - 1] = item + return vec +end + +function SubSmallVector{T,P}(vec::SubMSmallVector) where {T,P} + return SubSmallVector{T,P}(P(parent(vec)), vec.start, vec.stop) +end + +function Base.convert(smalltype::Type{<:SubSmallVector}, vec::SubMSmallVector) + return smalltype(vec) +end + +@inline function Base.resize!(vec::SubMSmallVector, len::Integer) + len < 0 && throw(ArgumentError("New length must be ≥ 0.")) + len > maxlength(vec) - vec.start + 1 && throw(ArgumentError("New length $len must be ≤ the maximum length $(maxlength(vec)).")) + vec.stop = vec.start + len - 1 + return vec +end + +# Optimization, default uses `similar`. +function Base.copymutable(vec::SubSmallVector) + return SubMSmallVector(Base.copymutable(parent(vec)), vec.start, vec.stop) +end diff --git a/NDTensors/src/SmallVectors/test/runtests.jl b/NDTensors/src/SmallVectors/test/runtests.jl new file mode 100644 index 0000000000..48a08bf167 --- /dev/null +++ b/NDTensors/src/SmallVectors/test/runtests.jl @@ -0,0 +1,78 @@ +module TestSmallVectors + include("SmallVectors/src/SmallVectors.jl") + using .SmallVectors + using StaticArrays + using Test + + function test() + @testset "SmallVectors" begin + x = SmallVector{10}([1, 3, 5]) + mx = MSmallVector(x) + + @test x isa SmallVector{10,Int} + @test mx isa MSmallVector{10,Int} + @test eltype(x) === Int + @test eltype(mx) === Int + + # TODO: Test construction has zero allocations. + # TODO: Extend construction to arbitrary collections, like tuple. + + # conversion + @test @inferred(SmallVector(x)) == x + @test @allocated(SmallVector(x)) == 0 + @test @inferred(SmallVector(mx)) == x + @test @allocated(SmallVector(mx)) == 0 + + # length + @test @inferred(length(x)) == 3 + @test @allocated(length(x)) == 0 + @test @inferred(length(SmallVectors.buffer(x))) == 10 + @test @allocated(length(SmallVectors.buffer(x))) == 0 + + item = 115 + no_broken = (false, false, false, false) + for (f!, f, ans, args, f!_impl_broken, f!_noalloc_broken, f_impl_broken, f_noalloc_broken) in [ + (:push!, :push, [1, 3, 5, item], (item,), no_broken...), + (:append!, :(SmallVectors.append), [1, 3, 5, item], ([item],), no_broken...), + (:prepend!, :(SmallVectors.prepend), [item, 1, 3, 5], ([item],), no_broken...), + # (:splice!, :(SmallVectors.splice), [1, item, 3], (2, [item],), true, true, true, true), # Not implemented + (:pushfirst!, :pushfirst, [item, 1, 3, 5], (item,), no_broken...), + (:setindex!, :setindex, [1, item, 5], (item, 2), no_broken...), + (:pop!, :pop, [1, 3], (), no_broken...), + (:popfirst!, :popfirst, [3, 5], (), no_broken...), + (:insert!, :insert, [1, item, 3, 5], (2, item), no_broken...), + (:deleteat!, :deleteat, [1, 5], (2,), no_broken...), + (:circshift!, :circshift, [5, 1, 3], (1,), no_broken...), + (:sort!, :sort, [1, 3, 5], (), no_broken...), + (:(SmallVectors.insertsorted!), :(SmallVectors.insertsorted), [1, 2, 3, 5], (2,), no_broken...), + (:(SmallVectors.insertsorted!), :(SmallVectors.insertsorted), [1, 3, 3, 5], (3,), no_broken...), + (:(SmallVectors.insertsortedunique!), :(SmallVectors.insertsortedunique), [1, 2, 3, 5], (2,), no_broken...), + (:(SmallVectors.insertsortedunique!), :(SmallVectors.insertsortedunique), [1, 3, 5], (3,), no_broken...), + (:(SmallVectors.mergesorted!), :(SmallVectors.mergesorted), [1, 2, 3, 3, 5], ([2, 3],), no_broken...), + (:(SmallVectors.mergesortedunique!), :(SmallVectors.mergesortedunique), [1, 2, 3, 5], ([2, 3],), no_broken...), + ] + mx_tmp = copy(mx) + @eval begin + @test @inferred($f!(copy($mx), $args...)) == $ans broken=$f!_impl_broken + @test @allocated($f!($mx_tmp, $args...)) == 0 broken=$f!_noalloc_broken + @test @inferred($f($x, $args...)) == $ans broken=$f_impl_broken + @test @allocated($f($x, $args...)) == 0 broken=$f_noalloc_broken + end + end + + # Separated out since for some reason it breaks the `@inferred` + # check when `kwargs` are interpolated into `@eval`. + ans, kwargs = [5, 3, 1], (; rev=true) + mx_tmp = copy(mx) + @test @inferred(sort!(copy(mx); kwargs...)) == ans + @test @allocated(sort!(mx_tmp; kwargs...)) == 0 + @test @inferred(sort(x; kwargs...)) == ans + @test @allocated(sort(x; kwargs...)) == 0 + + ans, args = [1, 3, 5, item], ([item],) + @test @inferred(vcat(x, args...)) == ans + @test @allocated(vcat(x, args...)) == 0 + end + end +end +TestSmallVectors.test() diff --git a/NDTensors/test/SmallVectors.jl b/NDTensors/test/SmallVectors.jl new file mode 100644 index 0000000000..62b552dc72 --- /dev/null +++ b/NDTensors/test/SmallVectors.jl @@ -0,0 +1,4 @@ +using Test +using NDTensors + +include(joinpath(pkgdir(NDTensors), "src", "SmallVectors", "test", "runtests.jl")) diff --git a/NDTensors/test/runtests.jl b/NDTensors/test/runtests.jl index 90aeeea118..274e2303ed 100644 --- a/NDTensors/test/runtests.jl +++ b/NDTensors/test/runtests.jl @@ -20,6 +20,7 @@ end @safetestset "NDTensors" begin @testset "$filename" for filename in [ "SetParameters.jl", + "SmallVectors.jl", "linearalgebra.jl", "dense.jl", "blocksparse.jl",