Skip to content

Commit

Permalink
Update for Derive.jl v0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 8, 2024
1 parent 9a432bf commit 0514aa0
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 188 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ version = "0.1.0"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Aqua = "0.8.9"
ArrayLayouts = "1.11.0"
BroadcastMapConversion = "0.1.0"
Derive = "0.3.0"
Dictionaries = "0.4.3"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Suppressor = "0.2"
Expand Down
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Updates for latest Derive.
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12 changes: 10 additions & 2 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,16 @@ a[1, 2] = 12

# SparseArraysBase interface:

using Dictionaries: IndexError
@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
@test getstoredindex(a, 1, 2) == 12
@test_throws KeyError getstoredindex(a, 1, 1)
@test_throws IndexError getstoredindex(a, 1, 1)
@test getunstoredindex(a, 1, 1) == 0
@test getunstoredindex(a, 1, 2) == 0
@test !isstored(a, 1, 1)
@test isstored(a, 1, 2)
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
@test_throws IndexError setstoredindex!(copy(a), 21, 2, 1)
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
@test storedlength(a) == 1
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
Expand All @@ -80,8 +81,15 @@ a[1, 2] = 12
# AbstractArray functionality:

b = a .+ 2 .* a'
@test b isa SparseArrayDOK{Float64}
@test b == [0 12; 24 0]
@test storedlength(b) == 2

b = permutedims(a, (2, 1))
@test b isa SparseArrayDOK{Float64}
@test b[1, 1] == a[1, 1]
@test b[2, 1] == a[1, 2]
@test b[1, 2] == a[2, 1]
@test b[2, 2] == a[2, 2]

a * a'
1 change: 1 addition & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module SparseArraysBase

include("abstractsparsearrayinterface.jl")
include("sparsearrayinterface.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
Expand Down
20 changes: 11 additions & 9 deletions src/abstractsparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ function Derive.interface(::Type{<:AbstractSparseArray})
end

using Derive: @derive
# Derive `Base.getindex`, `Base.setindex!`, etc.
@derive AnyAbstractSparseArray AbstractArrayOps

# TODO: These need to be loaded since `AbstractArrayOps`
# includes overloads of functions from these modules.
# Ideally that wouldn't be needed and can be circumvented
# with `GlobalRef`.
using ArrayLayouts: ArrayLayouts
using LinearAlgebra: LinearAlgebra
@derive (T=AnyAbstractSparseVecOrMat,) begin
LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number)
end

using ArrayLayouts: ArrayLayouts
@derive (T=AnyAbstractSparseArray,) begin
ArrayLayouts.MemoryLayout(::Type{<:T})
end
# Derive `Base.getindex`, `Base.setindex!`, etc.
# TODO: Define `AbstractMatrixOps` and overload for
# `AnyAbstractSparseMatrix` and `AnyAbstractSparseVector`,
# which is where matrix multiplication and factorizations
# shoudl go.
@derive AnyAbstractSparseArray AbstractArrayOps
141 changes: 141 additions & 0 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Minimal interface for `SparseArrayInterface`.
# TODO: Define default definitions for these based
# on the dense case.
storedvalues(a) = error()
isstored(a, I::Int...) = error()
eachstoredindex(a) = error()
getstoredindex(a, I::Int...) = error()
setstoredindex!(a, value, I::Int...) = error()
setunstoredindex!(a, value, I::Int...) = error()

# Interface defaults.
# TODO: Have a fallback that handles element types
# that don't define `zero(::Type)`.
getunstoredindex(a, I::Int...) = zero(eltype(a))

# Derived interface.
storedlength(a) = length(storedvalues(a))
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))

function eachstoredindex(a1, a2, a_rest...)
# TODO: Make this more customizable, say with a function
# `combine/promote_storedindices(a1, a2)`.
return union(eachstoredindex.((a1, a2, a_rest...))...)
end

using Derive: Derive, @interface, AbstractArrayInterface

# TODO: Add `ndims` type parameter.
# TODO: This isn't used to define interface functions right now.
# Currently, `@interface` expects an instance, probably it should take a
# type instead so fallback functions can use abstract types.
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end

# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
# to handle slicing (implemented by copying SubArray).
@interface AbstractSparseArrayInterface function Base.getindex(a, I::Int...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end

@interface AbstractSparseArrayInterface function Base.setindex!(a, value, I::Int...)
iszero(value) && return a
if !isstored(a, I...)
setunstoredindex!(a, value, I...)
return a
end
setstoredindex!(a, value, I...)
return a
end

# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
@interface AbstractSparseArrayInterface function Base.similar(
a, T::Type, size::Tuple{Vararg{Int}}
)
# TODO: Define `default_similartype` or something like that?
return SparseArrayDOK{T}(size...)
end

@interface AbstractSparseArrayInterface function Base.map!(f, dest, as...)
# Check `f` preserves zeros.
# Define as `map_stored!`.
# Define `eachstoredindex` promotion.
for I in eachstoredindex(as...)
dest[I] = f(map(a -> a[I], as)...)
end
return dest
end

# TODO: Make this a subtype of `Derive.AbstractArrayStyle{N}` instead.
using Derive: Derive
abstract type AbstractSparseArrayStyle{N} <: Derive.AbstractArrayStyle{N} end

struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end

SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

@interface AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
return SparseArrayStyle{ndims(type)}()
end

using ArrayLayouts: ArrayLayouts, MatMulMatAdd

abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end

struct SparseLayout <: AbstractSparseLayout end

@interface AbstractSparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
return SparseLayout()
end

function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
if I1[2] I2[1]
return nothing
end
return CartesianIndex(I1[1], I2[2])
end

function default_mul!!(
a_dest::AbstractMatrix,
a1::AbstractMatrix,
a2::AbstractMatrix,
α::Number=true,
β::Number=false,
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

function default_mul!!(
a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false
)
return a1 * a2 * α + a_dest * β
end

# a1 * a2 * α + a_dest * β
function sparse_mul!(
a_dest::AbstractArray,
a1::AbstractArray,
a2::AbstractArray,
α::Number=true,
β::Number=false;
(mul!!)=(default_mul!!),
)
for I1 in eachstoredindex(a1)
for I2 in eachstoredindex(a2)
I_dest = mul_indices(I1, I2)
if !isnothing(I_dest)
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
end
end
end
return a_dest
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
)
sparse_mul!(m.C, m.A, m.B, m.α, m.β)
return m.C
end
24 changes: 16 additions & 8 deletions src/sparsearraydok.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
# TODO: Rewrite to use `Dictionary`.
struct SparseArrayDOK{T,N} <: AbstractSparseArray{T,N}
storage::Dict{CartesianIndex{N},T}
using Dictionaries: Dictionary, IndexError, set!

function default_getunstoredindex(a::AbstractArray, I::Int...)
return zero(eltype(a))
end

struct SparseArrayDOK{T,N,F} <: AbstractSparseArray{T,N}
storage::Dictionary{CartesianIndex{N},T}
size::NTuple{N,Int}
getunstoredindex::F
end

function SparseArrayDOK{T,N}(size::Vararg{Int,N}) where {T,N}
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size)
getunstoredindex = default_getunstoredindex
F = typeof(getunstoredindex)
return SparseArrayDOK{T,N,F}(Dictionary{CartesianIndex{N},T}(), size, getunstoredindex)
end

function SparseArrayDOK{T}(size::Int...) where {T}
Expand All @@ -30,17 +38,17 @@ function getstoredindex(a::SparseArrayDOK, I::Int...)
return storage(a)[CartesianIndex(I)]
end
function getunstoredindex(a::SparseArrayDOK, I::Int...)
return zero(eltype(a))
return a.getunstoredindex(a, I...)
end
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
isstored(a, I...) || throw(KeyError(CartesianIndex(I)))
isstored(a, I...) || throw(IndexError("key $(CartesianIndex(I)) not found"))
storage(a)[CartesianIndex(I)] = value
return a
end
function setunstoredindex!(a::SparseArrayDOK, value, I::Int...)
storage(a)[CartesianIndex(I)] = value
set!(storage(a), CartesianIndex(I), value)
return a
end

# Optional, but faster than the default.
storedpairs(a::SparseArrayDOK) = storage(a)
storedpairs(a::SparseArrayDOK) = pairs(storage(a))
Loading

0 comments on commit 0514aa0

Please sign in to comment.