-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NDTensors] Add SmallVectors submodule
- Loading branch information
Showing
11 changed files
with
633 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SmallVectors | ||
|
||
A module that defines small (mutable and immutable) vectors with a maximum length. Externally the have a dynamic (or in the case of immuatable vectors, runtime) length, but internally they are backed by a statically sized vector. This makes it so that operations can be performed faster because they can remain on the stack, but it provides some more convenience compared to StaticArrays.jl where the length is encoded in the type. | ||
|
||
For example: | ||
```julia | ||
using NDTensors.SmallVectors | ||
v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10 | ||
v = push(v, 4) | ||
v = setindex(v, 4, 4) | ||
v = sort(v; rev=true) | ||
|
||
mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10 | ||
push!(mv, 4) | ||
mv[2] = 12 | ||
sort!(mv; rev=true) | ||
``` | ||
This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
module SmallVectors | ||
using StaticArrays | ||
|
||
include("StaticArraysExt.jl") | ||
using .StaticArraysExt | ||
|
||
export SmallVector, MSmallVector, SubSmallVector | ||
|
||
struct NotImplemented <: Exception | ||
msg::String | ||
end | ||
NotImplemented() = NotImplemented("Not implemented.") | ||
|
||
struct BufferDimensionMismatch <: Exception | ||
msg::String | ||
end | ||
|
||
include("abstractsmallvector/abstractsmallvector.jl") | ||
include("abstractsmallvector/deque.jl") | ||
include("msmallvector/msmallvector.jl") | ||
include("smallvector/smallvector.jl") | ||
include("subsmallvector/subsmallvector.jl") | ||
end |
24 changes: 24 additions & 0 deletions
24
NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
A vector with a fixed maximum length, backed by a fixed size buffer. | ||
""" | ||
abstract type AbstractSmallVector{T} <: AbstractVector{T} end | ||
|
||
# Required buffer interface | ||
buffer(vec::AbstractSmallVector) = throw(NotImplemented()) | ||
|
||
similar_type(vec::AbstractSmallVector) = typeof(vec) | ||
|
||
# Required buffer interface | ||
maxlength(vec::AbstractSmallVector) = length(buffer(vec)) | ||
|
||
# Required AbstractArray interface | ||
Base.size(vec::AbstractSmallVector) = throw(NotImplemented()) | ||
|
||
# Derived AbstractArray interface | ||
function Base.getindex(vec::AbstractSmallVector, index::Integer) | ||
return throw(NotImplemented()) | ||
end | ||
function Base.setindex!(vec::AbstractSmallVector, item, index::Integer) | ||
return throw(NotImplemented()) | ||
end | ||
Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear() |
280 changes: 280 additions & 0 deletions
280
NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
# TODO: Set operations | ||
# union, ∪, union! | ||
# intersect, ∩, intersect! | ||
# setdiff, setdiff! | ||
# symdiff, symdiff! | ||
# unique, unique! | ||
|
||
Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented()) | ||
|
||
@inline function resize(vec::AbstractSmallVector, len) | ||
mvec = Base.copymutable(vec) | ||
resize!(mvec, len) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.empty!(vec::AbstractSmallVector) | ||
resize!(vec, 0) | ||
return vec | ||
end | ||
|
||
@inline function empty(vec::AbstractSmallVector) | ||
mvec = Base.copymutable(vec) | ||
empty!(mvec) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer) | ||
@boundscheck checkbounds(vec, index) | ||
mvec = Base.copymutable(vec) | ||
@inbounds mvec[index] = item | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.push!(vec::AbstractSmallVector, item) | ||
resize!(vec, length(vec) + 1) | ||
@inbounds vec[length(vec)] = item | ||
return vec | ||
end | ||
|
||
@inline function StaticArrays.push(vec::AbstractSmallVector, item) | ||
mvec = Base.copymutable(vec) | ||
push!(mvec, item) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.pop!(vec::AbstractSmallVector) | ||
resize!(vec, length(vec) - 1) | ||
return vec | ||
end | ||
|
||
@inline function StaticArrays.pop(vec::AbstractSmallVector) | ||
mvec = Base.copymutable(vec) | ||
pop!(mvec) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.pushfirst!(vec::AbstractSmallVector, item) | ||
insert!(vec, firstindex(vec), item) | ||
return vec | ||
end | ||
|
||
# Don't `@inline`, makes it slower. | ||
function StaticArrays.pushfirst(vec::AbstractSmallVector, item) | ||
mvec = Base.copymutable(vec) | ||
pushfirst!(mvec, item) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.popfirst!(vec::AbstractSmallVector) | ||
circshift!(vec, -1) | ||
resize!(vec, length(vec) - 1) | ||
return vec | ||
end | ||
|
||
# Don't `@inline`, makes it slower. | ||
function StaticArrays.popfirst(vec::AbstractSmallVector) | ||
mvec = Base.copymutable(vec) | ||
popfirst!(mvec) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.reverse!(vec::AbstractSmallVector) | ||
start, stop = firstindex(vec), lastindex(vec) | ||
r = stop | ||
@inbounds for i in start:Base.midpoint(start, stop-1) | ||
vec[i], vec[r] = vec[r], vec[i] | ||
r -= 1 | ||
end | ||
return vec | ||
end | ||
|
||
@inline function Base.reverse!(vec::AbstractSmallVector, start::Integer, stop::Integer=lastindex(v)) | ||
reverse!(smallview(vec, start, stop)) | ||
return vec | ||
end | ||
|
||
@inline function Base.circshift!(vec::AbstractSmallVector, shift::Integer) | ||
start, stop = firstindex(vec), lastindex(vec) | ||
n = length(vec) | ||
n == 0 && return vec | ||
shift = mod(shift, n) | ||
shift == 0 && return vec | ||
reverse!(smallview(vec, start, stop - shift)) | ||
reverse!(smallview(vec, stop - shift + 1, stop)) | ||
reverse!(smallview(vec, start, stop)) | ||
return vec | ||
end | ||
|
||
@inline function Base.insert!(vec::AbstractSmallVector, index::Integer, item) | ||
resize!(vec, length(vec) + 1) | ||
circshift!(smallview(vec, index, lastindex(vec)), 1) | ||
@inbounds vec[index] = item | ||
return vec | ||
end | ||
|
||
# Don't @inline, makes it slower. | ||
function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item) | ||
mvec = Base.copymutable(vec) | ||
insert!(mvec, index, item) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.deleteat!(vec::AbstractSmallVector, index::Integer) | ||
circshift!(smallview(vec, index, lastindex(vec)), -1) | ||
resize!(vec, length(vec) - 1) | ||
return vec | ||
end | ||
|
||
# Don't @inline, makes it slower. | ||
function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer) | ||
mvec = Base.copymutable(vec) | ||
deleteat!(mvec, index) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
# InsertionSortAlg | ||
# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736 | ||
# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort. | ||
# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting. | ||
@inline function Base.sort!(vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false) | ||
lo, hi = firstindex(vec), lastindex(vec) | ||
lo_plus_1 = (lo + 1) | ||
@inbounds for i in lo_plus_1:hi | ||
j = i | ||
x = vec[i] | ||
jmax = j | ||
for _ in jmax:-1:lo_plus_1 | ||
y = vec[j - 1] | ||
if !(lt(by(x), by(y)) != rev) | ||
break | ||
end | ||
vec[j] = y | ||
j -= 1 | ||
end | ||
vec[j] = x | ||
end | ||
return vec | ||
end | ||
|
||
# Don't @inline, makes it slower. | ||
function Base.sort(vec::AbstractSmallVector; kwargs...) | ||
mvec = Base.copymutable(vec) | ||
sort!(mvec; kwargs...) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...) | ||
insert!(vec, searchsortedfirst(vec, item; kwargs...), item) | ||
return vec | ||
end | ||
|
||
function insertsorted(vec::AbstractSmallVector, item; kwargs...) | ||
mvec = Base.copymutable(vec) | ||
insertsorted!(mvec, item; kwargs...) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function insertsortedunique!(vec::AbstractSmallVector, item; kwargs...) | ||
r = searchsorted(vec, item; kwargs...) | ||
if length(r) == 0 | ||
insert!(vec, first(r), item) | ||
end | ||
return vec | ||
end | ||
|
||
# Code repeated since inlining doesn't work. | ||
function insertsortedunique(vec::AbstractSmallVector, item; kwargs...) | ||
r = searchsorted(vec, item; kwargs...) | ||
if length(r) == 0 | ||
vec = insert(vec, first(r), item) | ||
end | ||
return vec | ||
end | ||
|
||
@inline function mergesorted!(vec::AbstractSmallVector, item::AbstractVector; kwargs...) | ||
for x in item | ||
insertsorted!(vec, x; kwargs...) | ||
end | ||
return vec | ||
end | ||
|
||
function mergesorted(vec::AbstractSmallVector, item; kwargs...) | ||
mvec = Base.copymutable(vec) | ||
mergesorted!(mvec, item; kwargs...) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function mergesortedunique!(vec::AbstractSmallVector, item::AbstractVector; kwargs...) | ||
for x in item | ||
insertsortedunique!(vec, x; kwargs...) | ||
end | ||
return vec | ||
end | ||
|
||
# Code repeated since inlining doesn't work. | ||
function mergesortedunique(vec::AbstractSmallVector, item; kwargs...) | ||
for x in item | ||
vec = insertsortedunique(vec, x; kwargs...) | ||
end | ||
return vec | ||
end | ||
|
||
Base.@propagate_inbounds function Base.copyto!(vec::AbstractSmallVector, item::AbstractVector) | ||
for i in eachindex(item) | ||
vec[i] = item[i] | ||
end | ||
return vec | ||
end | ||
|
||
# Don't @inline, makes it slower. | ||
function Base.circshift(vec::AbstractSmallVector, shift::Integer) | ||
mvec = Base.copymutable(vec) | ||
circshift!(mvec, shift) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.append!(vec::AbstractSmallVector, item::AbstractVector) | ||
l = length(vec) | ||
r = length(item) | ||
resize!(vec, l + r) | ||
@inbounds copyto!(smallview(vec, l + 1, l + r + 1), item) | ||
return vec | ||
end | ||
|
||
# Missing from `StaticArrays.jl`. | ||
# Don't @inline, makes it slower. | ||
function append(vec::AbstractSmallVector, item::AbstractVector) | ||
mvec = Base.copymutable(vec) | ||
append!(mvec, item) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
@inline function Base.prepend!(vec::AbstractSmallVector, item::AbstractVector) | ||
l = length(vec) | ||
r = length(item) | ||
resize!(vec, l + r) | ||
circshift!(vec, length(item)) | ||
@inbounds copyto!(vec, item) | ||
return vec | ||
end | ||
|
||
# Missing from `StaticArrays.jl`. | ||
# Don't @inline, makes it slower. | ||
function prepend(vec::AbstractSmallVector, item::AbstractVector) | ||
mvec = Base.copymutable(vec) | ||
prepend!(mvec, item) | ||
return convert(similar_type(vec), mvec) | ||
end | ||
|
||
# Don't @inline, makes it slower. | ||
function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector) | ||
mvec1 = Base.copymutable(vec1) | ||
append!(mvec1, vec2) | ||
return convert(similar_type(vec1), mvec1) | ||
end | ||
|
||
# TODO: inline when defined. | ||
function Base.splice!(a::AbstractSmallVector, args...) | ||
return throw(NotImplemented()) | ||
end |
Oops, something went wrong.