Skip to content

Commit

Permalink
Define fallback for sparse interface for dense arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 12, 2024
1 parent c99a04a commit 1b9dfe8
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 17 deletions.
36 changes: 22 additions & 14 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
# Minimal interface for `SparseArrayInterface`.
# TODO: Define default definitions for these based
# on the dense case.
# TODO: Define as `MethodError`.
## isstored(a::AbstractArray, I::Int...) = true
isstored(a::AbstractArray, I::Int...) = error("Not implemented.")
## eachstoredindex(a::AbstractArray) = eachindex(a)
eachstoredindex(a::AbstractArray) = error("Not implemented.")
## getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...)
getstoredindex(a::AbstractArray, I::Int...) = error("Not implemented.")
## setstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...)
setstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.")
## setunstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...)
setunstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.")
isstored(a::AbstractArray, I::Int...) = true
eachstoredindex(a::AbstractArray) = eachindex(a)
getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...)
function setstoredindex!(a::AbstractArray, value, I::Int...)
setindex!(a, value, I...)
return a

Check warning on line 7 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L2-L7

Added lines #L2 - L7 were not covered by tests
end
# TODO: Should this error by default if the value at the index
# is stored? It could be disabled with something analogous
# to `checkbounds`, like `checkstored`/`checkunstored`.
function setunstoredindex!(a::AbstractArray, value, I::Int...)
setindex!(a, value, I...)
return a

Check warning on line 14 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L12-L14

Added lines #L12 - L14 were not covered by tests
end

# TODO: Use `Base.to_indices`?
isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...)

Check warning on line 18 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L18

Added line #L18 was not covered by tests
# TODO: Use `Base.to_indices`?
getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...)

Check warning on line 20 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L20

Added line #L20 was not covered by tests
# TODO: Use `Base.to_indices`?
getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...)
# TODO: Use `Base.to_indices`?
function setstoredindex!(a::AbstractArray, value, I::CartesianIndex)
return setstoredindex!(a, value, Tuple(I)...)

Check warning on line 25 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end
# TODO: Use `Base.to_indices`?
function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
return setunstoredindex!(a, value, Tuple(I)...)

Check warning on line 29 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L28-L29

Added lines #L28 - L29 were not covered by tests
end
Expand All @@ -33,6 +38,9 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a))
storedlength(a::AbstractArray) = length(storedvalues(a))
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))

Check warning on line 39 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L39

Added line #L39 was not covered by tests

to_vec(x) = vec(collect(x))
to_vec(x::AbstractArray) = vec(x)

Check warning on line 42 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L41-L42

Added lines #L41 - L42 were not covered by tests

# A view of the stored values of an array.
# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue
# with that is it returns a `SubArray` wrapping a sparse array, which
Expand All @@ -47,7 +55,7 @@ struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T}
array::A
storedindices::I
end
StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a)))
StoredValues(a::AbstractArray) = StoredValues(a, to_vec(eachstoredindex(a)))
Base.size(a::StoredValues) = size(a.storedindices)
Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I])
function Base.setindex!(a::StoredValues, value, I::Int)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Expand Down
63 changes: 60 additions & 3 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
using SparseArraysBase: SparseArraysBase
using Adapt: adapt
using JLArrays: JLArray, @allowscalar
using SparseArraysBase:
SparseArraysBase,
eachstoredindex,
getstoredindex,
getunstoredindex,
isstored,
setstoredindex!,
setunstoredindex!,
storedlength,
storedpairs,
storedvalues
using Test: @test, @testset

@testset "SparseArraysBase" begin
# Tests go here.
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
arrayts = (Array, JLArray)
@testset "SparseArraysBase (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
elt in elts

dev(x) = adapt(arrayt, x)

n = 2
a = dev(randn(elt, n, n))
@test storedlength(a) == length(a)
for indexstyle in (IndexLinear(), IndexCartesian())
for I in eachindex(indexstyle, a)
@test isstored(a, I)
end
end
@test eachstoredindex(a) == eachindex(a)
# TODO: We should be specializing these for dense/strided arrays,
# probably we can have a trait for that. It could be based
# on the `ArrayLayouts.MemoryLayout`.
@allowscalar @test storedvalues(a) == vec(a)
@allowscalar @test storedpairs(a) == collect(pairs(vec(a)))
@allowscalar for I in eachindex(a)
@test getstoredindex(a, I) == a[I]
@test iszero(getunstoredindex(a, I))
end
@allowscalar for I in eachindex(IndexCartesian(), a)
@test getstoredindex(a, I) == a[I]
@test iszero(getunstoredindex(a, I))
end

a = dev(randn(elt, n, n))
for I in ((1, 2), (CartesianIndex(1, 2),))
b = copy(a)
value = randn(elt)
@allowscalar setstoredindex!(b, value, I...)
@allowscalar b[I...] == value
end

# TODO: Should `setunstoredindex!` error by default
# if the value at that index is already stored?
a = dev(randn(elt, n, n))
for I in ((1, 2), (CartesianIndex(1, 2),))
b = copy(a)
value = randn(elt)
@allowscalar setunstoredindex!(b, value, I...)
@allowscalar b[I...] == value
end
end

0 comments on commit 1b9dfe8

Please sign in to comment.