Skip to content

Commit

Permalink
[NDTensors] Add SmallVectors submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Sep 21, 2023
1 parent b04a973 commit 6e692d9
Show file tree
Hide file tree
Showing 11 changed files with 633 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("SmallVectors/src/SmallVectors.jl")
using .SmallVectors

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down
18 changes: 18 additions & 0 deletions NDTensors/src/SmallVectors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SmallVectors

A module that defines small (mutable and immutable) vectors with a maximum length. Externally the have a dynamic (or in the case of immuatable vectors, runtime) length, but internally they are backed by a statically sized vector. This makes it so that operations can be performed faster because they can remain on the stack, but it provides some more convenience compared to StaticArrays.jl where the length is encoded in the type.

For example:
```julia
using NDTensors.SmallVectors
v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10
v = push(v, 4)
v = setindex(v, 4, 4)
v = sort(v; rev=true)

mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10
push!(mv, 4)
mv[2] = 12
sort!(mv; rev=true)
```
This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length.
23 changes: 23 additions & 0 deletions NDTensors/src/SmallVectors/src/SmallVectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module SmallVectors
using StaticArrays

include("StaticArraysExt.jl")
using .StaticArraysExt

export SmallVector, MSmallVector, SubSmallVector

struct NotImplemented <: Exception
msg::String
end
NotImplemented() = NotImplemented("Not implemented.")

struct BufferDimensionMismatch <: Exception
msg::String
end

include("abstractsmallvector/abstractsmallvector.jl")
include("abstractsmallvector/deque.jl")
include("msmallvector/msmallvector.jl")
include("smallvector/smallvector.jl")
include("subsmallvector/subsmallvector.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
A vector with a fixed maximum length, backed by a fixed size buffer.
"""
abstract type AbstractSmallVector{T} <: AbstractVector{T} end

# Required buffer interface
buffer(vec::AbstractSmallVector) = throw(NotImplemented())

similar_type(vec::AbstractSmallVector) = typeof(vec)

# Required buffer interface
maxlength(vec::AbstractSmallVector) = length(buffer(vec))

# Required AbstractArray interface
Base.size(vec::AbstractSmallVector) = throw(NotImplemented())

# Derived AbstractArray interface
function Base.getindex(vec::AbstractSmallVector, index::Integer)
return throw(NotImplemented())
end
function Base.setindex!(vec::AbstractSmallVector, item, index::Integer)
return throw(NotImplemented())
end
Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear()
280 changes: 280 additions & 0 deletions NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# TODO: Set operations
# union, ∪, union!
# intersect, ∩, intersect!
# setdiff, setdiff!
# symdiff, symdiff!
# unique, unique!

Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented())

@inline function resize(vec::AbstractSmallVector, len)
mvec = Base.copymutable(vec)
resize!(mvec, len)
return convert(similar_type(vec), mvec)
end

@inline function Base.empty!(vec::AbstractSmallVector)
resize!(vec, 0)
return vec
end

@inline function empty(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
empty!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer)
@boundscheck checkbounds(vec, index)
mvec = Base.copymutable(vec)
@inbounds mvec[index] = item
return convert(similar_type(vec), mvec)
end

@inline function Base.push!(vec::AbstractSmallVector, item)
resize!(vec, length(vec) + 1)
@inbounds vec[length(vec)] = item
return vec
end

@inline function StaticArrays.push(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
push!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.pop!(vec::AbstractSmallVector)
resize!(vec, length(vec) - 1)
return vec
end

@inline function StaticArrays.pop(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
pop!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function Base.pushfirst!(vec::AbstractSmallVector, item)
insert!(vec, firstindex(vec), item)
return vec
end

# Don't `@inline`, makes it slower.
function StaticArrays.pushfirst(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
pushfirst!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.popfirst!(vec::AbstractSmallVector)
circshift!(vec, -1)
resize!(vec, length(vec) - 1)
return vec
end

# Don't `@inline`, makes it slower.
function StaticArrays.popfirst(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
popfirst!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function Base.reverse!(vec::AbstractSmallVector)
start, stop = firstindex(vec), lastindex(vec)
r = stop
@inbounds for i in start:Base.midpoint(start, stop-1)
vec[i], vec[r] = vec[r], vec[i]
r -= 1
end
return vec
end

@inline function Base.reverse!(vec::AbstractSmallVector, start::Integer, stop::Integer=lastindex(v))
reverse!(smallview(vec, start, stop))
return vec
end

@inline function Base.circshift!(vec::AbstractSmallVector, shift::Integer)
start, stop = firstindex(vec), lastindex(vec)
n = length(vec)
n == 0 && return vec
shift = mod(shift, n)
shift == 0 && return vec
reverse!(smallview(vec, start, stop - shift))
reverse!(smallview(vec, stop - shift + 1, stop))
reverse!(smallview(vec, start, stop))
return vec
end

@inline function Base.insert!(vec::AbstractSmallVector, index::Integer, item)
resize!(vec, length(vec) + 1)
circshift!(smallview(vec, index, lastindex(vec)), 1)
@inbounds vec[index] = item
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item)
mvec = Base.copymutable(vec)
insert!(mvec, index, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.deleteat!(vec::AbstractSmallVector, index::Integer)
circshift!(smallview(vec, index, lastindex(vec)), -1)
resize!(vec, length(vec) - 1)
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer)
mvec = Base.copymutable(vec)
deleteat!(mvec, index)
return convert(similar_type(vec), mvec)
end

# InsertionSortAlg
# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736
# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort.
# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting.
@inline function Base.sort!(vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false)
lo, hi = firstindex(vec), lastindex(vec)
lo_plus_1 = (lo + 1)
@inbounds for i in lo_plus_1:hi
j = i
x = vec[i]
jmax = j
for _ in jmax:-1:lo_plus_1
y = vec[j - 1]
if !(lt(by(x), by(y)) != rev)
break
end
vec[j] = y
j -= 1
end
vec[j] = x
end
return vec
end

# Don't @inline, makes it slower.
function Base.sort(vec::AbstractSmallVector; kwargs...)
mvec = Base.copymutable(vec)
sort!(mvec; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...)
insert!(vec, searchsortedfirst(vec, item; kwargs...), item)
return vec
end

function insertsorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
insertsorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function insertsortedunique!(vec::AbstractSmallVector, item; kwargs...)
r = searchsorted(vec, item; kwargs...)
if length(r) == 0
insert!(vec, first(r), item)
end
return vec
end

# Code repeated since inlining doesn't work.
function insertsortedunique(vec::AbstractSmallVector, item; kwargs...)
r = searchsorted(vec, item; kwargs...)
if length(r) == 0
vec = insert(vec, first(r), item)
end
return vec
end

@inline function mergesorted!(vec::AbstractSmallVector, item::AbstractVector; kwargs...)
for x in item
insertsorted!(vec, x; kwargs...)
end
return vec
end

function mergesorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
mergesorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function mergesortedunique!(vec::AbstractSmallVector, item::AbstractVector; kwargs...)
for x in item
insertsortedunique!(vec, x; kwargs...)
end
return vec
end

# Code repeated since inlining doesn't work.
function mergesortedunique(vec::AbstractSmallVector, item; kwargs...)
for x in item
vec = insertsortedunique(vec, x; kwargs...)
end
return vec
end

Base.@propagate_inbounds function Base.copyto!(vec::AbstractSmallVector, item::AbstractVector)
for i in eachindex(item)
vec[i] = item[i]
end
return vec
end

# Don't @inline, makes it slower.
function Base.circshift(vec::AbstractSmallVector, shift::Integer)
mvec = Base.copymutable(vec)
circshift!(mvec, shift)
return convert(similar_type(vec), mvec)
end

@inline function Base.append!(vec::AbstractSmallVector, item::AbstractVector)
l = length(vec)
r = length(item)
resize!(vec, l + r)
@inbounds copyto!(smallview(vec, l + 1, l + r + 1), item)
return vec
end

# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function append(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
append!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.prepend!(vec::AbstractSmallVector, item::AbstractVector)
l = length(vec)
r = length(item)
resize!(vec, l + r)
circshift!(vec, length(item))
@inbounds copyto!(vec, item)
return vec
end

# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function prepend(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
prepend!(mvec, item)
return convert(similar_type(vec), mvec)
end

# Don't @inline, makes it slower.
function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector)
mvec1 = Base.copymutable(vec1)
append!(mvec1, vec2)
return convert(similar_type(vec1), mvec1)
end

# TODO: inline when defined.
function Base.splice!(a::AbstractSmallVector, args...)
return throw(NotImplemented())
end
Loading

0 comments on commit 6e692d9

Please sign in to comment.