Skip to content

Commit

Permalink
Add source code and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 6, 2024
1 parent ad98b39 commit b1d3723
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 7 deletions.
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@ uuid = "b8770bf0-c4ae-4888-b9b0-956061873092"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Aqua = "0.8.9"
ArrayLayouts = "1.11.0"
BroadcastMapConversion = "0.1.0"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Suppressor = "0.2"
Test = "1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "Suppressor", "SafeTestsets"]
43 changes: 41 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,49 @@ julia> Pkg.add("SparseArraysBase")
## Examples

````julia
using SparseArraysBase: SparseArraysBase
using SparseArraysBase:
SparseArrayDOK,
eachstoredindex,
getstoredindex,
getunstoredindex,
isstored,
setstoredindex!,
setunstoredindex!,
storedlength,
storedpairs,
storedvalues
using Test: @test, @test_throws

a = SparseArrayDOK{Float64}(2, 2)
````

AbstractArray interface:

````julia
a[1, 2] = 12
@test a[1, 1] == 0
@test a[2, 1] == 0
@test a[1, 2] == 12
@test a[2, 2] == 0
````

Examples go here.
SparseArraysBase interface:

````julia
@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
@test getstoredindex(a, 1, 2) == 12
@test_throws KeyError getstoredindex(a, 1, 1)
@test getunstoredindex(a, 1, 1) == 0
@test getunstoredindex(a, 1, 2) == 0
@test !isstored(a, 1, 1)
@test isstored(a, 1, 2)
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test storedlength(a) == 1
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
@test issetequal(storedvalues(a), [12])
````

---

Expand Down
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[deps]
SparseArraysBase = "b8770bf0-c4ae-4888-b9b0-956061873092"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
40 changes: 38 additions & 2 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,41 @@ julia> Pkg.add("SparseArraysBase")

# ## Examples

using SparseArraysBase: SparseArraysBase
# Examples go here.
using SparseArraysBase:
SparseArrayDOK,
eachstoredindex,
getstoredindex,
getunstoredindex,
isstored,
setstoredindex!,
setunstoredindex!,
storedlength,
storedpairs,
storedvalues
using Test: @test, @test_throws

a = SparseArrayDOK{Float64}(2, 2)

# AbstractArray interface:

a[1, 2] = 12
@test a[1, 1] == 0
@test a[2, 1] == 0
@test a[1, 2] == 12
@test a[2, 2] == 0

# SparseArraysBase interface:

@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
@test getstoredindex(a, 1, 2) == 12
@test_throws KeyError getstoredindex(a, 1, 1)
@test getunstoredindex(a, 1, 1) == 0
@test getunstoredindex(a, 1, 2) == 0
@test !isstored(a, 1, 1)
@test isstored(a, 1, 2)
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test storedlength(a) == 1
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
@test issetequal(storedvalues(a), [12])
4 changes: 3 additions & 1 deletion src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module SparseArraysBase

# Write your package code here.
include("sparsearrayinterface.jl")
include("wrappers.jl")
include("sparsearraydok.jl")

end
51 changes: 51 additions & 0 deletions src/sparsearraydok.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# TODO: Define `AbstractSparseArray`, make this a subtype.
struct SparseArrayDOK{T,N} <: AbstractArray{T,N}
storage::Dict{CartesianIndex{N},T}
size::NTuple{N,Int}
end

function SparseArrayDOK{T}(size::Int...) where {T}
N = length(size)
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size)
end

using Derive: @wrappedtype
# Define `WrappedSparseArrayDOK` and `AnySparseArrayDOK`.
@wrappedtype SparseArrayDOK

using Derive: Derive
function Derive.interface(::Type{<:SparseArrayDOK})
return SparseArrayInterface()
end

using Derive: @derive
@derive AnySparseArrayDOK AbstractArrayOps

storage(a::SparseArrayDOK) = a.storage
Base.size(a::SparseArrayDOK) = a.size

storedvalues(a::SparseArrayDOK) = values(storage(a))
function isstored(a::SparseArrayDOK, I::Int...)
return CartesianIndex(I) in keys(storage(a))
end
function eachstoredindex(a::SparseArrayDOK)
return keys(storage(a))
end
function getstoredindex(a::SparseArrayDOK, I::Int...)
return storage(a)[CartesianIndex(I)]
end
function getunstoredindex(a::SparseArrayDOK, I::Int...)
return zero(eltype(a))
end
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
isstored(a, I...) || throw(KeyError(CartesianIndex(I)))
storage(a)[CartesianIndex(I)] = value
return a
end
function setunstoredindex!(a::SparseArrayDOK, value, I::Int...)
storage(a)[CartesianIndex(I)] = value
return a
end

# Optional, but faster than the default.
storedpairs(a::SparseArrayDOK) = storage(a)
161 changes: 161 additions & 0 deletions src/sparsearrayinterface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Minimal interface for `SparseArrayInterface`.
# TODO: Define default definitions for these based
# on the dense case.
storedvalues(a) = error()
isstored(a, I::Int...) = error()
eachstoredindex(a) = error()
getstoredindex(a, I::Int...) = error()
getunstoredindex(a, I::Int...) = error()
setstoredindex!(a, value, I::Int...) = error()
setunstoredindex!(a, value, I::Int...) = error()

# Derived interface.
storedlength(a) = length(storedvalues(a))
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))

function eachstoredindex(a1, a2, a_rest...)
# TODO: Make this more customizable, say with a function
# `combine/promote_storedindices(a1, a2)`.
return union(eachstoredindex.((a1, a2, a_rest...))...)
end

# TODO: Add `ndims` type parameter.
# TODO: Define `AbstractSparseArrayInterface`, make this a subtype.
using Derive: Derive, @interface, AbstractArrayInterface
struct SparseArrayInterface <: AbstractArrayInterface end

# Convenient shorthand to refer to the sparse interface.
const sparse = SparseArrayInterface()

# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
# to handle slicing (implemented by copying SubArray).
@interface sparse function Base.getindex(a, I::Int...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end

@interface sparse function Base.setindex!(a, value, I::Int...)
iszero(value) && return a
if !isstored(a, I...)
setunstoredindex!(a, value, I...)
return a
end
setstoredindex!(a, value, I...)
return a
end

# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
@interface sparse function Base.similar(a, T::Type, size::Tuple{Vararg{Int}})
return SparseArrayDOK{T}(size...)
end

## TODO: Make this more general, handle mixtures of integers and ranges.
## TODO: Make this logic generic to all `similar(::AbstractInterface, ...)`.
## @interface sparse function Base.similar(a, T::Type, dims::Tuple{Vararg{Base.OneTo}})
## return sparse(similar)(interface, a, T, Base.to_shape(dims))
## end

@interface sparse function Base.map(f, as...)
# This is defined in this way so we can rely on the Broadcast logic
# for determining the destination of the operation (element type, shape, etc.).
return f.(as...)
end

@interface sparse function Base.map!(f, dest, as...)
# Check `f` preserves zeros.
# Define as `map_stored!`.
# Define `eachstoredindex` promotion.
for I in eachstoredindex(as...)
dest[I] = f(map(a -> a[I], as)...)
end
return dest
end

# TODO: Define `AbstractSparseArrayStyle`, make this a subtype.
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end

SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

@interface sparse function Broadcast.BroadcastStyle(type::Type)
return SparseArrayStyle{ndims(type)}()
end

function Base.similar(bc::Broadcast.Broadcasted{<:SparseArrayStyle}, T::Type, axes::Tuple)
# TODO: Allow `similar` to accept `axes` directly.
return sparse(similar)(bc, T, Int.(length.(axes)))
end

using BroadcastMapConversion: map_function, map_args
# TODO: Look into `SparseArrays.capturescalars`:
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
function Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{<:SparseArrayStyle})
sparse(map!)(map_function(bc), dest, map_args(bc)...)
return dest
end

using ArrayLayouts: ArrayLayouts, MatMulMatAdd

abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end

struct SparseLayout <: AbstractSparseLayout end

@interface sparse function ArrayLayouts.MemoryLayout(type::Type)
return SparseLayout()
end

using LinearAlgebra: LinearAlgebra
@interface sparse function LinearAlgebra.mul!(a_dest, a1, a2, α, β)
return ArrayLayouts.mul!(a_dest, a1, a2, α, β)
end

function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
if I1[2] I2[1]
return nothing
end
return CartesianIndex(I1[1], I2[2])
end

function default_mul!!(
a_dest::AbstractMatrix,
a1::AbstractMatrix,
a2::AbstractMatrix,
α::Number=true,
β::Number=false,
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

function default_mul!!(
a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false
)
return a1 * a2 * α + a_dest * β
end

# a1 * a2 * α + a_dest * β
function sparse_mul!(
a_dest::AbstractArray,
a1::AbstractArray,
a2::AbstractArray,
α::Number=true,
β::Number=false;
(mul!!)=(default_mul!!),
)
for I1 in eachstoredindex(a1)
for I2 in eachstoredindex(a2)
I_dest = mul_indices(I1, I2)
if !isnothing(I_dest)
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
end
end
end
return a_dest
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
)
sparse_mul!(m.C, m.A, m.B, m.α, m.β)
return m.C
end
47 changes: 47 additions & 0 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using LinearAlgebra: Adjoint
storedvalues(a::Adjoint) = storedvalues(parent(a))
function isstored(a::Adjoint, i::Int, j::Int)
return isstored(parent(a), j, i)
end
function eachstoredindex(a::Adjoint)
# TODO: Make lazy with `Iterators.map`.
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
end
function getstoredindex(a::Adjoint, i::Int, j::Int)
return getstoredindex(parent(a), j, i)'
end
function getunstoredindex(a::Adjoint, i::Int, j::Int)
return getunstoredindex(parent(a), j, i)'
end
function setstoredindex!(a::Adjoint, value, i::Int, j::Int)
setstoredindex!(parent(a), value', j, i)
return a
end
function setunstoredindex!(a::Adjoint, value, i::Int, j::Int)
setunstoredindex!(parent(a), value', j, i)
return a
end

using LinearAlgebra: Transpose
storedvalues(a::Transpose) = storedvalues(parent(a))
function isstored(a::Transpose, i::Int, j::Int)
return isstored(parent(a), j, i)
end
function eachstoredindex(a::Transpose)
# TODO: Make lazy with `Iterators.map`.
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
end
function getstoredindex(a::Transpose, i::Int, j::Int)
return transpose(getstoredindex(parent(a), j, i))
end
function getunstoredindex(a::Transpose, i::Int, j::Int)
return transpose(getunstoredindex(parent(a), j, i))
end
function setstoredindex!(a::Transpose, value, i::Int, j::Int)
setstoredindex!(parent(a), transpose(value), j, i)
return a
end
function setunstoredindex!(a::Transpose, value, i::Int, j::Int)
setunstoredindex!(parent(a), transpose(value), j, i)
return a
end
Loading

0 comments on commit b1d3723

Please sign in to comment.