Skip to content

Commit

Permalink
Reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 11, 2024
1 parent 9985386 commit 33b3d6c
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 4 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ a = SparseArrayDOK{Float64}(2, 2)
AbstractArray interface:

````julia
@test iszero(a)
@test iszero(sum(a))
@test iszero(storedlength(a))

a[1, 2] = 12
@test a == [0 12; 0 0]
@test a[1, 1] == 0
Expand All @@ -78,6 +82,10 @@ using Dictionaries: IndexError
@test storedlength(a) == 1
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
@test issetequal(storedvalues(a), [12])
@test sum(a) == 12
@test isreal(a)
@test !iszero(a)
@test mapreduce(x -> 2x, +, a) == 24
````

AbstractArray functionality:
Expand All @@ -87,6 +95,10 @@ b = a .+ 2 .* a'
@test b isa SparseMatrixDOK{Float64}
@test b == [0 12; 24 0]
@test storedlength(b) == 2
@test sum(b) == 36
@test isreal(b)
@test !iszero(b)
@test mapreduce(x -> 2x, +, b) == 72

b = permutedims(a, (2, 1))
@test b isa SparseMatrixDOK{Float64}
Expand Down
12 changes: 12 additions & 0 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ a = SparseArrayDOK{Float64}(2, 2)

# AbstractArray interface:

@test iszero(a)
@test iszero(sum(a))
@test iszero(storedlength(a))

a[1, 2] = 12
@test a == [0 12; 0 0]
@test a[1, 1] == 0
Expand All @@ -79,13 +83,21 @@ using Dictionaries: IndexError
@test storedlength(a) == 1
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
@test issetequal(storedvalues(a), [12])
@test sum(a) == 12
@test isreal(a)
@test !iszero(a)
@test mapreduce(x -> 2x, +, a) == 24

# AbstractArray functionality:

b = a .+ 2 .* a'
@test b isa SparseMatrixDOK{Float64}
@test b == [0 12; 24 0]
@test storedlength(b) == 2
@test sum(b) == 36
@test isreal(b)
@test !iszero(b)
@test mapreduce(x -> 2x, +, b) == 72

b = permutedims(a, (2, 1))
@test b isa SparseMatrixDOK{Float64}
Expand Down
50 changes: 46 additions & 4 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,27 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
# Derived interface.
storedlength(a::AbstractArray) = length(storedvalues(a))
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))

Check warning on line 34 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L34

Added line #L34 was not covered by tests
function storedvalues(a::AbstractArray)
return @view a[collect(eachstoredindex(a))]
end

# A view of the stored values of an array.
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
# with that is it returns a `SubArray` wrapping a sparse array, which
# is then interpreted as a sparse array so it can lead to recursion.
# Also, that involves extra logic for determining if the indices are
# stored or not, but we know the indices are stored so we can use
# `getstoredindex` and `setstoredindex!`.
# Most sparse arrays should overload `storedvalues` directly
# and avoid this wrapper since it adds extra indirection to
# access stored values.
struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
array::A
storedindices::I
end
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
Base.size(a::StoredValues) = size(a.storedindices)
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
Base.setindex!(a::StoredValues, value, I::Int) = setstoredindex!(a.array, value, a.storedindices[I])

Check warning on line 53 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L50-L53

Added lines #L50 - L53 were not covered by tests

storedvalues(a::AbstractArray) = StoredValues(a)

Check warning on line 55 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L55

Added line #L55 was not covered by tests

function eachstoredindex(a1, a2, a_rest...)
# TODO: Make this more customizable, say with a function
Expand Down Expand Up @@ -64,8 +82,8 @@ end
@interface ::AbstractSparseArrayInterface function Base.setindex!(
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
iszero(value) && return a
if !isstored(a, I...)
iszero(value) && return a
setunstoredindex!(a, value, I...)
return a
end
Expand Down Expand Up @@ -94,6 +112,30 @@ end
return dest
end

# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
function reduce_init(f, op, as...)

Check warning on line 116 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L116

Added line #L116 was not covered by tests
# TODO: Generalize this.
@assert isone(length(as))
a = only(as)

Check warning on line 119 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
## TODO: Make this more efficient for block sparse
## arrays, in that case it allocates a block. Maybe
## it can use `FillArrays.Zeros`.
return f(getunstoredindex(a, first(eachindex(a))))

Check warning on line 123 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L123

Added line #L123 was not covered by tests
end

@interface ::AbstractSparseArrayInterface function Base.mapreduce(

Check warning on line 126 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L126

Added line #L126 was not covered by tests
f, op, as::AbstractArray...; init=reduce_init(f, op, as...), kwargs...
)
# TODO: Generalize this.
@assert isone(length(as))
a = only(as)
output = mapreduce(f, op, storedvalues(a); init, kwargs...)

Check warning on line 132 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L130-L132

Added lines #L130 - L132 were not covered by tests
## TODO: Bring this check back, or make the function more general.
## f_notstored = apply_notstored(f, a)
## @assert isequal(op(output, eltype(output)(f_notstored)), output)
return output

Check warning on line 136 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L136

Added line #L136 was not covered by tests
end

abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end

@derive AbstractSparseArrayStyle AbstractArrayStyleOps
Expand Down
3 changes: 3 additions & 0 deletions src/sparsearraydok.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ function getunstoredindex(a::SparseArrayDOK, I::Int...)
return a.getunstoredindex(a, I...)
end
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
# TODO: Have a way to disable this check, analogous to `checkbounds`,
# since this is already checked in `setindex!`.
isstored(a, I...) || throw(IndexError("key $(CartesianIndex(I)) not found"))
# TODO: If `iszero(value)`, unstore the index.
storage(a)[CartesianIndex(I)] = value
return a
end
Expand Down

0 comments on commit 33b3d6c

Please sign in to comment.