Skip to content

Commit

Permalink
[SparseArrayInterface] NestedPermutedDimsArray support (#1590)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 15, 2024
1 parent 3594216 commit d40ca1a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

# TODO: Make this into a generic definition of all `AbstractArray`?
function SparseArrayInterface.stored_indices(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
)
return Iterators.map(
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
Expand All @@ -41,7 +41,7 @@ end

# TODO: Make this into a generic definition of all `AbstractArray`?
function SparseArrayInterface.sparse_storage(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
)
return sparse_storage(parent(a))
end
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
using ..NestedPermutedDimsArrays: NestedPermutedDimsArray

## PermutedDimsArray

perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,IP}) where {IP} = IP
const AnyPermutedDimsArray{T,N,perm,iperm,P} = Union{
PermutedDimsArray{T,N,perm,iperm,P},NestedPermutedDimsArray{T,N,perm,iperm,P}
}

# TODO: Use `TypeParameterAccessors`.
perm(::AnyPermutedDimsArray{<:Any,<:Any,Perm}) where {Perm} = Perm
iperm(::AnyPermutedDimsArray{<:Any,<:Any,<:Any,IPerm}) where {IPerm} = IPerm

# TODO: Use `Base.PermutedDimsArrays.genperm` or
# https://github.com/jipolanco/StaticPermutations.jl?
genperm(v, perm) = map(j -> v[j], perm)
genperm(v::CartesianIndex, perm) = CartesianIndex(map(j -> Tuple(v)[j], perm))

function storage_index_to_index(a::PermutedDimsArray, I)
function storage_index_to_index(a::AnyPermutedDimsArray, I)
return genperm(storage_index_to_index(parent(a), I), perm(a))
end

function index_to_storage_index(
a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
a::AnyPermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return index_to_storage_index(parent(a), genperm(I, perm(a)))
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
module AbstractSparseArrays
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout, MulAdd
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray, Zero

struct SparseArray{T,N} <: AbstractSparseArray{T,N}
struct SparseArray{T,N,Zero} <: AbstractSparseArray{T,N}
data::Vector{T}
dims::Tuple{Vararg{Int,N}}
index_to_dataindex::Dict{CartesianIndex{N},Int}
dataindex_to_index::Vector{CartesianIndex{N}}
zero::Zero
end
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
return SparseArray{T,N}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
return SparseArray{T,N,typeof(zero)}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
)
end
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
return SparseArray{T}(dims)
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
return SparseArray{T,N}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T,length(dims)}(dims; kwargs...)
end
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)

# ArrayLayouts interface
struct SparseLayout <: MemoryLayout end
Expand All @@ -41,6 +46,7 @@ function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}})
end

# Minimal interface
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
function SparseArrayInterface.index_to_storage_index(
a::SparseArray{<:Any,N}, I::CartesianIndex{N}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
module SparseArrays
using LinearAlgebra: LinearAlgebra
using NDTensors.SparseArrayInterface: SparseArrayInterface
using NDTensors.SparseArrayInterface: SparseArrayInterface, Zero

struct SparseArray{T,N} <: AbstractArray{T,N}
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
data::Vector{T}
dims::Tuple{Vararg{Int,N}}
index_to_dataindex::Dict{CartesianIndex{N},Int}
dataindex_to_index::Vector{CartesianIndex{N}}
zero::Zero
end
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
return SparseArray{T,N}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
return SparseArray{T,N,typeof(zero)}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
)
end
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
return SparseArray{T}(dims)
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
return SparseArray{T,N}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T,length(dims)}(dims; kwargs...)
end
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)

# LinearAlgebra interface
function LinearAlgebra.mul!(
Expand Down Expand Up @@ -53,6 +58,7 @@ function Base.fill!(a::SparseArray, value)
end

# Minimal interface
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
function SparseArrayInterface.index_to_storage_index(
a::SparseArray{<:Any,N}, I::CartesianIndex{N}
Expand All @@ -79,6 +85,33 @@ function SparseArrayInterface.stored_indices(
)
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.SparseArrayInterface: sparse_storage
function SparseArrayInterface.sparse_storage(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return sparse_storage(parent(a))
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
function SparseArrayInterface.stored_indices(
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return Iterators.map(
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
)
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
using NDTensors.SparseArrayInterface: sparse_storage
function SparseArrayInterface.sparse_storage(
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return sparse_storage(parent(a))
end

# Empty the storage, helps with efficiency in `map!` to drop
# zeros.
function SparseArrayInterface.dropall!(a::SparseArray)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using LinearAlgebra: dot, mul!, norm
using NDTensors.SparseArrayInterface: SparseArrayInterface
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
include("SparseArrayInterfaceTestUtils/SparseArrayInterfaceTestUtils.jl")
using .SparseArrayInterfaceTestUtils.AbstractSparseArrays: AbstractSparseArrays
using .SparseArrayInterfaceTestUtils.SparseArrays: SparseArrays
Expand Down Expand Up @@ -224,6 +225,44 @@ using Test: @test, @testset
end
end

a = SparseArray{elt}(2, 3)
a[1, 2] = 12
b = PermutedDimsArray(a, (2, 1))
@test size(b) == (3, 2)
@test axes(b) == (1:3, 1:2)
@test SparseArrayInterface.sparse_storage(b) == elt[12]
@test SparseArrayInterface.stored_length(b) == 1
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
@test !iszero(b)
@test !iszero(norm(b))
for I in eachindex(b)
if I == CartesianIndex(2, 1)
@test b[I] == 12
else
@test iszero(b[I])
end
end

a = SparseArray{Matrix{elt}}(
2, 3; zero=(a, I) -> (z = similar(eltype(a), 2, 3); fill!(z, false); z)
)
a[1, 2] = randn(elt, 2, 3)
b = NestedPermutedDimsArray(a, (2, 1))
@test size(b) == (3, 2)
@test axes(b) == (1:3, 1:2)
@test SparseArrayInterface.sparse_storage(b) == [a[1, 2]]
@test SparseArrayInterface.stored_length(b) == 1
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
@test !iszero(b)
@test !iszero(norm(b))
for I in eachindex(b)
if I == CartesianIndex(2, 1)
@test b[I] == permutedims(a[1, 2], (2, 1))
else
@test iszero(b[I])
end
end

a = SparseArray{elt}(2, 3)
a[1, 2] = 12
b = randn(elt, 2, 3)
Expand Down

0 comments on commit d40ca1a

Please sign in to comment.