diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/sparsearrayinterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/sparsearrayinterface.jl index 98095225f7..e4929b5e81 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/sparsearrayinterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/sparsearrayinterface.jl @@ -32,7 +32,7 @@ end # TODO: Make this into a generic definition of all `AbstractArray`? function SparseArrayInterface.stored_indices( - a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray} + a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray} ) return Iterators.map( I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a)) @@ -41,7 +41,7 @@ end # TODO: Make this into a generic definition of all `AbstractArray`? function SparseArrayInterface.sparse_storage( - a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray} + a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray} ) return sparse_storage(parent(a)) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl index b4364b80ed..a4fce3bb0c 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl @@ -1,19 +1,26 @@ +using ..NestedPermutedDimsArrays: NestedPermutedDimsArray + ## PermutedDimsArray -perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P -iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,IP}) where {IP} = IP +const AnyPermutedDimsArray{T,N,perm,iperm,P} = Union{ + PermutedDimsArray{T,N,perm,iperm,P},NestedPermutedDimsArray{T,N,perm,iperm,P} +} + +# TODO: Use `TypeParameterAccessors`. +perm(::AnyPermutedDimsArray{<:Any,<:Any,Perm}) where {Perm} = Perm +iperm(::AnyPermutedDimsArray{<:Any,<:Any,<:Any,IPerm}) where {IPerm} = IPerm # TODO: Use `Base.PermutedDimsArrays.genperm` or # https://github.com/jipolanco/StaticPermutations.jl? genperm(v, perm) = map(j -> v[j], perm) genperm(v::CartesianIndex, perm) = CartesianIndex(map(j -> Tuple(v)[j], perm)) -function storage_index_to_index(a::PermutedDimsArray, I) +function storage_index_to_index(a::AnyPermutedDimsArray, I) return genperm(storage_index_to_index(parent(a), I), perm(a)) end function index_to_storage_index( - a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N} + a::AnyPermutedDimsArray{<:Any,N}, I::CartesianIndex{N} ) where {N} return index_to_storage_index(parent(a), genperm(I, perm(a))) end diff --git a/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/AbstractSparseArrays.jl b/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/AbstractSparseArrays.jl index 1455b65c15..1a81334397 100644 --- a/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/AbstractSparseArrays.jl +++ b/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/AbstractSparseArrays.jl @@ -1,24 +1,29 @@ module AbstractSparseArrays using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout, MulAdd -using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray +using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray, Zero -struct SparseArray{T,N} <: AbstractSparseArray{T,N} +struct SparseArray{T,N,Zero} <: AbstractSparseArray{T,N} data::Vector{T} dims::Tuple{Vararg{Int,N}} index_to_dataindex::Dict{CartesianIndex{N},Int} dataindex_to_index::Vector{CartesianIndex{N}} + zero::Zero end -function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N} - return SparseArray{T,N}( - T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}() +function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N} + return SparseArray{T,N,typeof(zero)}( + T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero ) end -SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims) -SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims) -function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T} - return SparseArray{T}(dims) +function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N} + return SparseArray{T,N}(dims; kwargs...) end -SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims) +function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T} + return SparseArray{T,length(dims)}(dims; kwargs...) +end +function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T} + return SparseArray{T}(dims; kwargs...) +end +SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...) # ArrayLayouts interface struct SparseLayout <: MemoryLayout end @@ -41,6 +46,7 @@ function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}}) end # Minimal interface +SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero SparseArrayInterface.sparse_storage(a::SparseArray) = a.data function SparseArrayInterface.index_to_storage_index( a::SparseArray{<:Any,N}, I::CartesianIndex{N} diff --git a/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl b/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl index e1ca4f0661..f74846e2c1 100644 --- a/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl +++ b/NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl @@ -1,24 +1,29 @@ module SparseArrays using LinearAlgebra: LinearAlgebra -using NDTensors.SparseArrayInterface: SparseArrayInterface +using NDTensors.SparseArrayInterface: SparseArrayInterface, Zero -struct SparseArray{T,N} <: AbstractArray{T,N} +struct SparseArray{T,N,Zero} <: AbstractArray{T,N} data::Vector{T} dims::Tuple{Vararg{Int,N}} index_to_dataindex::Dict{CartesianIndex{N},Int} dataindex_to_index::Vector{CartesianIndex{N}} + zero::Zero end -function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N} - return SparseArray{T,N}( - T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}() +function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N} + return SparseArray{T,N,typeof(zero)}( + T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero ) end -SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims) -SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims) -function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T} - return SparseArray{T}(dims) +function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N} + return SparseArray{T,N}(dims; kwargs...) end -SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims) +function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T} + return SparseArray{T,length(dims)}(dims; kwargs...) +end +function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T} + return SparseArray{T}(dims; kwargs...) +end +SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...) # LinearAlgebra interface function LinearAlgebra.mul!( @@ -53,6 +58,7 @@ function Base.fill!(a::SparseArray, value) end # Minimal interface +SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero SparseArrayInterface.sparse_storage(a::SparseArray) = a.data function SparseArrayInterface.index_to_storage_index( a::SparseArray{<:Any,N}, I::CartesianIndex{N} @@ -79,6 +85,33 @@ function SparseArrayInterface.stored_indices( ) end +# TODO: Make this into a generic definition of all `AbstractArray`? +using NDTensors.SparseArrayInterface: sparse_storage +function SparseArrayInterface.sparse_storage( + a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray} +) + return sparse_storage(parent(a)) +end + +# TODO: Make this into a generic definition of all `AbstractArray`? +using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray +function SparseArrayInterface.stored_indices( + a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray} +) + return Iterators.map( + I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a)) + ) +end + +# TODO: Make this into a generic definition of all `AbstractArray`? +using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray +using NDTensors.SparseArrayInterface: sparse_storage +function SparseArrayInterface.sparse_storage( + a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray} +) + return sparse_storage(parent(a)) +end + # Empty the storage, helps with efficiency in `map!` to drop # zeros. function SparseArrayInterface.dropall!(a::SparseArray) diff --git a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl index 8ec5174463..0c44c25f6d 100644 --- a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl +++ b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl @@ -1,6 +1,7 @@ @eval module $(gensym()) using LinearAlgebra: dot, mul!, norm using NDTensors.SparseArrayInterface: SparseArrayInterface +using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray include("SparseArrayInterfaceTestUtils/SparseArrayInterfaceTestUtils.jl") using .SparseArrayInterfaceTestUtils.AbstractSparseArrays: AbstractSparseArrays using .SparseArrayInterfaceTestUtils.SparseArrays: SparseArrays @@ -224,6 +225,44 @@ using Test: @test, @testset end end + a = SparseArray{elt}(2, 3) + a[1, 2] = 12 + b = PermutedDimsArray(a, (2, 1)) + @test size(b) == (3, 2) + @test axes(b) == (1:3, 1:2) + @test SparseArrayInterface.sparse_storage(b) == elt[12] + @test SparseArrayInterface.stored_length(b) == 1 + @test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)] + @test !iszero(b) + @test !iszero(norm(b)) + for I in eachindex(b) + if I == CartesianIndex(2, 1) + @test b[I] == 12 + else + @test iszero(b[I]) + end + end + + a = SparseArray{Matrix{elt}}( + 2, 3; zero=(a, I) -> (z = similar(eltype(a), 2, 3); fill!(z, false); z) + ) + a[1, 2] = randn(elt, 2, 3) + b = NestedPermutedDimsArray(a, (2, 1)) + @test size(b) == (3, 2) + @test axes(b) == (1:3, 1:2) + @test SparseArrayInterface.sparse_storage(b) == [a[1, 2]] + @test SparseArrayInterface.stored_length(b) == 1 + @test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)] + @test !iszero(b) + @test !iszero(norm(b)) + for I in eachindex(b) + if I == CartesianIndex(2, 1) + @test b[I] == permutedims(a[1, 2], (2, 1)) + else + @test iszero(b[I]) + end + end + a = SparseArray{elt}(2, 3) a[1, 2] = 12 b = randn(elt, 2, 3)