diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index c471ea2..aa64d9f 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -1,25 +1,30 @@ # Minimal interface for `SparseArrayInterface`. -# TODO: Define default definitions for these based -# on the dense case. -# TODO: Define as `MethodError`. -## isstored(a::AbstractArray, I::Int...) = true -isstored(a::AbstractArray, I::Int...) = error("Not implemented.") -## eachstoredindex(a::AbstractArray) = eachindex(a) -eachstoredindex(a::AbstractArray) = error("Not implemented.") -## getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...) -getstoredindex(a::AbstractArray, I::Int...) = error("Not implemented.") -## setstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...) -setstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.") -## setunstoredindex!(a::AbstractArray, value, I::Int...) = setindex!(a, value, I...) -setunstoredindex!(a::AbstractArray, value, I::Int...) = error("Not implemented.") +isstored(a::AbstractArray, I::Int...) = true +eachstoredindex(a::AbstractArray) = eachindex(a) +getstoredindex(a::AbstractArray, I::Int...) = getindex(a, I...) +function setstoredindex!(a::AbstractArray, value, I::Int...) + setindex!(a, value, I...) + return a +end +# TODO: Should this error by default if the value at the index +# is stored? It could be disabled with something analogous +# to `checkbounds`, like `checkstored`/`checkunstored`. +function setunstoredindex!(a::AbstractArray, value, I::Int...) + setindex!(a, value, I...) + return a +end # TODO: Use `Base.to_indices`? isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...) +# TODO: Use `Base.to_indices`? getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...) +# TODO: Use `Base.to_indices`? getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...) +# TODO: Use `Base.to_indices`? function setstoredindex!(a::AbstractArray, value, I::CartesianIndex) return setstoredindex!(a, value, Tuple(I)...) end +# TODO: Use `Base.to_indices`? function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) return setunstoredindex!(a, value, Tuple(I)...) end @@ -33,6 +38,9 @@ getunstoredindex(a::AbstractArray, I::Int...) = zero(eltype(a)) storedlength(a::AbstractArray) = length(storedvalues(a)) storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a)) +to_vec(x) = vec(collect(x)) +to_vec(x::AbstractArray) = vec(x) + # A view of the stored values of an array. # Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue # with that is it returns a `SubArray` wrapping a sparse array, which @@ -47,7 +55,7 @@ struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} array::A storedindices::I end -StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a))) +StoredValues(a::AbstractArray) = StoredValues(a, to_vec(eachstoredindex(a))) Base.size(a::StoredValues) = size(a.storedindices) Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) function Base.setindex!(a::StoredValues, value, I::Int) diff --git a/test/Project.toml b/test/Project.toml index 266d4f0..9a08f31 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,9 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index 0970a6b..c8764c6 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -1,6 +1,63 @@ -using SparseArraysBase: SparseArraysBase +using Adapt: adapt +using JLArrays: JLArray, @allowscalar +using SparseArraysBase: + SparseArraysBase, + eachstoredindex, + getstoredindex, + getunstoredindex, + isstored, + setstoredindex!, + setunstoredindex!, + storedlength, + storedpairs, + storedvalues using Test: @test, @testset -@testset "SparseArraysBase" begin - # Tests go here. +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) +arrayts = (Array, JLArray) +@testset "SparseArraysBase (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts, + elt in elts + + dev(x) = adapt(arrayt, x) + + n = 2 + a = dev(randn(elt, n, n)) + @test storedlength(a) == length(a) + for indexstyle in (IndexLinear(), IndexCartesian()) + for I in eachindex(indexstyle, a) + @test isstored(a, I) + end + end + @test eachstoredindex(a) == eachindex(a) + # TODO: We should be specializing these for dense/strided arrays, + # probably we can have a trait for that. It could be based + # on the `ArrayLayouts.MemoryLayout`. + @allowscalar @test storedvalues(a) == vec(a) + @allowscalar @test storedpairs(a) == collect(pairs(vec(a))) + @allowscalar for I in eachindex(a) + @test getstoredindex(a, I) == a[I] + @test iszero(getunstoredindex(a, I)) + end + @allowscalar for I in eachindex(IndexCartesian(), a) + @test getstoredindex(a, I) == a[I] + @test iszero(getunstoredindex(a, I)) + end + + a = dev(randn(elt, n, n)) + for I in ((1, 2), (CartesianIndex(1, 2),)) + b = copy(a) + value = randn(elt) + @allowscalar setstoredindex!(b, value, I...) + @allowscalar b[I...] == value + end + + # TODO: Should `setunstoredindex!` error by default + # if the value at that index is already stored? + a = dev(randn(elt, n, n)) + for I in ((1, 2), (CartesianIndex(1, 2),)) + b = copy(a) + value = randn(elt) + @allowscalar setunstoredindex!(b, value, I...) + @allowscalar b[I...] == value + end end