Skip to content

Commit

Permalink
Add support for linear indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 20, 2024
1 parent c67b83a commit ac4e58c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand All @@ -14,7 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Aqua = "0.8.9"
ArrayLayouts = "1.11.0"
BroadcastMapConversion = "0.1.0"
Derive = "0.3.0"
Derive = "0.3.6"
Dictionaries = "0.4.3"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Expand Down
83 changes: 75 additions & 8 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ end
# type instead so fallback functions can use abstract types.
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end

function Derive.combine_interface_rule(
interface1::AbstractSparseArrayInterface, interface2::AbstractSparseArrayInterface
)
return error("Rule not defined.")
end
function Derive.combine_interface_rule(
interface1::Interface, interface2::Interface
) where {Interface<:AbstractSparseArrayInterface}
return interface1
end
function Derive.combine_interface_rule(
interface1::AbstractSparseArrayInterface, interface2::AbstractArrayInterface
)
return interface1
end
function Derive.combine_interface_rule(
interface1::AbstractArrayInterface, interface2::AbstractSparseArrayInterface
)
return interface2
end

to_vec(x) = vec(collect(x))
to_vec(x::AbstractArray) = vec(x)

Expand Down Expand Up @@ -178,7 +199,46 @@ end
return SparseArrayDOK{T}(size...)
end

@interface ::AbstractSparseArrayInterface function Base.map!(
# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachstoredindex(as...)
a_dest[I] = f(map(a -> a[I], as)...)
end
return a_dest
end

# Only map all values, not just the stored ones.
function map_all! end

@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachindex(as...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
return a_dest
end

using ArrayLayouts: ArrayLayouts, zero!

# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
# and is useful for sparse array logic, since it can be used to empty
# the sparse array storage.
# We use a single function definition to minimize method ambiguities.
@interface interface::AbstractSparseArrayInterface function ArrayLayouts.zero!(
a::AbstractArray
)
# More generally, this codepath could be taking if `zero(eltype(a))`
# is defined and the elements are immutable.
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
return @interface interface map_stored!(f, a, a)
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
Expand All @@ -194,15 +254,22 @@ end
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
if !preserves_unstored
# Doesn't preserve unstored values, loop over all elements.
for I in eachindex(as...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
@interface interface map_all!(f, a_dest, as...)
return a_dest
end
# Define `eachstoredindex` promotion.
for I in eachstoredindex(as...)
a_dest[I] = f(map(a -> a[I], as)...)
end
# First zero out the destination.
# TODO: Make this more nuanced, skip when possible, for
# example if the sparsity of the destination is a subset of
# the sparsity of the sources, i.e.:
# ```julia
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
# zero!(a_dest)
# end
# ```
# This is the safest thing to do in general, for example
# if the destination is dense but the sources are sparse.
@interface interface zero!(a_dest)
@interface interface map_stored!(f, a_dest, as...)
return a_dest
end

Expand Down
11 changes: 11 additions & 0 deletions src/sparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@ using Derive: Derive

struct SparseArrayInterface <: AbstractSparseArrayInterface end

function Derive.combine_interface_rule(
interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface
)
return interface1
end
function Derive.combine_interface_rule(
interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface
)
return interface2
end

# Convenient shorthand to refer to the sparse interface.
# Can turn a function into a sparse function with the syntax `sparse(f)`,
# i.e. `sparse(map)(x -> 2x, randn(2, 2))` while use the sparse
Expand Down
42 changes: 28 additions & 14 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@ parentvalue_to_value(a::AbstractArray, value) = value
value_to_parentvalue(a::AbstractArray, value) = value
eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a))
storedparentvalues(a::AbstractArray) = storedvalues(parent(a))
parentindex_to_index(a::AbstractArray, I::CartesianIndex) = error()
function parentindex_to_index(a::AbstractArray, I::Int...)

function parentindex_to_index(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return throw(MethodError(parentindex_to_index, Tuple{typeof(a),typeof(I)}))
end
function parentindex_to_index(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return Tuple(parentindex_to_index(a, CartesianIndex(I)))
end
index_to_parentindex(a::AbstractArray, I::CartesianIndex) = error()
function index_to_parentindex(a::AbstractArray, I::Int...)
# Handle linear indexing.
function parentindex_to_index(a::AbstractArray, I::Int)
return parentindex_to_index(a, CartesianIndices(parent(a))[I])
end

function index_to_parentindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return throw(MethodError(index_to_parentindex, Tuple{typeof(a),typeof(I)}))
end
function index_to_parentindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return Tuple(index_to_parentindex(a, CartesianIndex(I)))
end
# Handle linear indexing.
function index_to_parentindex(a::AbstractArray, I::Int)
return index_to_parentindex(a, CartesianIndices(a)[I])
end

function cartesianindex_reverse(I::CartesianIndex)
return CartesianIndex(reverse(Tuple(I)))
Expand All @@ -21,10 +35,10 @@ tuple_oneto(n) = ntuple(identity, n)
genperm(v, perm) = map(j -> v[j], perm)

using LinearAlgebra: Adjoint
function parentindex_to_index(a::Adjoint, I::CartesianIndex)
function parentindex_to_index(a::Adjoint, I::CartesianIndex{2})
return cartesianindex_reverse(I)
end
function index_to_parentindex(a::Adjoint, I::CartesianIndex)
function index_to_parentindex(a::Adjoint, I::CartesianIndex{2})
return cartesianindex_reverse(I)
end
function parentvalue_to_value(a::Adjoint, value)
Expand All @@ -36,18 +50,18 @@ end

perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
function index_to_parentindex(a::PermutedDimsArray, I::CartesianIndex)
function index_to_parentindex(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return CartesianIndex(genperm(I, iperm(a)))
end
function parentindex_to_index(a::PermutedDimsArray, I::CartesianIndex)
function parentindex_to_index(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return CartesianIndex(genperm(I, perm(a)))
end

using Base: ReshapedArray
function parentindex_to_index(a::ReshapedArray, I::CartesianIndex)
function parentindex_to_index(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return CartesianIndices(size(a))[LinearIndices(parent(a))[I]]
end
function index_to_parentindex(a::ReshapedArray, I::CartesianIndex)
function index_to_parentindex(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return CartesianIndices(parent(a))[LinearIndices(size(a))[I]]
end

Expand All @@ -56,10 +70,10 @@ function eachstoredparentindex(a::SubArray)
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
end
end
function index_to_parentindex(a::SubArray, I::CartesianIndex)
function index_to_parentindex(a::SubArray{<:Any,N}, I::CartesianIndex{N}) where {N}
return CartesianIndex(Base.reindex(parentindices(a), Tuple(I)))
end
function parentindex_to_index(a::SubArray, I::CartesianIndex)
function parentindex_to_index(a::SubArray{<:Any,N}, I::CartesianIndex{N}) where {N}
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
return !(parentindices(a)[d] isa Real)
end
Expand All @@ -81,10 +95,10 @@ function storedparentvalues(a::SubArray)
end

using LinearAlgebra: Transpose
function parentindex_to_index(a::Transpose, I::CartesianIndex)
function parentindex_to_index(a::Transpose, I::CartesianIndex{2})
return cartesianindex_reverse(I)
end
function index_to_parentindex(a::Transpose, I::CartesianIndex)
function index_to_parentindex(a::Transpose, I::CartesianIndex{2})
return cartesianindex_reverse(I)
end
function parentvalue_to_value(a::Transpose, value)
Expand Down

0 comments on commit ac4e58c

Please sign in to comment.