-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
358 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,18 +3,27 @@ uuid = "b8770bf0-c4ae-4888-b9b0-956061873092" | |
authors = ["ITensor developers <[email protected]> and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" | ||
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" | ||
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
||
[compat] | ||
Aqua = "0.8.9" | ||
ArrayLayouts = "1.11.0" | ||
BroadcastMapConversion = "0.1.0" | ||
LinearAlgebra = "1.10" | ||
SafeTestsets = "0.1" | ||
Suppressor = "0.2" | ||
Test = "1.10" | ||
julia = "1.10" | ||
|
||
[extras] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" | ||
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Aqua", "Test", "Suppressor", "SafeTestsets"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
[deps] | ||
SparseArraysBase = "b8770bf0-c4ae-4888-b9b0-956061873092" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
module SparseArraysBase | ||
|
||
# Write your package code here. | ||
include("sparsearrayinterface.jl") | ||
include("wrappers.jl") | ||
include("sparsearraydok.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# TODO: Define `AbstractSparseArray`, make this a subtype. | ||
struct SparseArrayDOK{T,N} <: AbstractArray{T,N} | ||
storage::Dict{CartesianIndex{N},T} | ||
size::NTuple{N,Int} | ||
end | ||
|
||
function SparseArrayDOK{T}(size::Int...) where {T} | ||
N = length(size) | ||
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) | ||
end | ||
|
||
using Derive: @wrappedtype | ||
# Define `WrappedSparseArrayDOK` and `AnySparseArrayDOK`. | ||
@wrappedtype SparseArrayDOK | ||
|
||
using Derive: Derive | ||
function Derive.interface(::Type{<:SparseArrayDOK}) | ||
return SparseArrayInterface() | ||
end | ||
|
||
using Derive: @derive | ||
@derive AnySparseArrayDOK AbstractArrayOps | ||
|
||
storage(a::SparseArrayDOK) = a.storage | ||
Base.size(a::SparseArrayDOK) = a.size | ||
|
||
storedvalues(a::SparseArrayDOK) = values(storage(a)) | ||
function isstored(a::SparseArrayDOK, I::Int...) | ||
return CartesianIndex(I) in keys(storage(a)) | ||
end | ||
function eachstoredindex(a::SparseArrayDOK) | ||
return keys(storage(a)) | ||
end | ||
function getstoredindex(a::SparseArrayDOK, I::Int...) | ||
return storage(a)[CartesianIndex(I)] | ||
end | ||
function getunstoredindex(a::SparseArrayDOK, I::Int...) | ||
return zero(eltype(a)) | ||
end | ||
function setstoredindex!(a::SparseArrayDOK, value, I::Int...) | ||
isstored(a, I...) || throw(KeyError(CartesianIndex(I))) | ||
storage(a)[CartesianIndex(I)] = value | ||
return a | ||
end | ||
function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) | ||
storage(a)[CartesianIndex(I)] = value | ||
return a | ||
end | ||
|
||
# Optional, but faster than the default. | ||
storedpairs(a::SparseArrayDOK) = storage(a) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Minimal interface for `SparseArrayInterface`. | ||
# TODO: Define default definitions for these based | ||
# on the dense case. | ||
storedvalues(a) = error() | ||
isstored(a, I::Int...) = error() | ||
eachstoredindex(a) = error() | ||
getstoredindex(a, I::Int...) = error() | ||
getunstoredindex(a, I::Int...) = error() | ||
setstoredindex!(a, value, I::Int...) = error() | ||
setunstoredindex!(a, value, I::Int...) = error() | ||
|
||
# Derived interface. | ||
storedlength(a) = length(storedvalues(a)) | ||
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a)) | ||
|
||
function eachstoredindex(a1, a2, a_rest...) | ||
# TODO: Make this more customizable, say with a function | ||
# `combine/promote_storedindices(a1, a2)`. | ||
return union(eachstoredindex.((a1, a2, a_rest...))...) | ||
end | ||
|
||
# TODO: Add `ndims` type parameter. | ||
# TODO: Define `AbstractSparseArrayInterface`, make this a subtype. | ||
using Derive: Derive, @interface, AbstractArrayInterface | ||
struct SparseArrayInterface <: AbstractArrayInterface end | ||
|
||
# Convenient shorthand to refer to the sparse interface. | ||
const sparse = SparseArrayInterface() | ||
|
||
# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize` | ||
# to handle slicing (implemented by copying SubArray). | ||
@interface sparse function Base.getindex(a, I::Int...) | ||
!isstored(a, I...) && return getunstoredindex(a, I...) | ||
return getstoredindex(a, I...) | ||
end | ||
|
||
@interface sparse function Base.setindex!(a, value, I::Int...) | ||
iszero(value) && return a | ||
if !isstored(a, I...) | ||
setunstoredindex!(a, value, I...) | ||
return a | ||
end | ||
setstoredindex!(a, value, I...) | ||
return a | ||
end | ||
|
||
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK` | ||
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`. | ||
@interface sparse function Base.similar(a, T::Type, size::Tuple{Vararg{Int}}) | ||
return SparseArrayDOK{T}(size...) | ||
end | ||
|
||
## TODO: Make this more general, handle mixtures of integers and ranges. | ||
## TODO: Make this logic generic to all `similar(::AbstractInterface, ...)`. | ||
## @interface sparse function Base.similar(a, T::Type, dims::Tuple{Vararg{Base.OneTo}}) | ||
## return sparse(similar)(interface, a, T, Base.to_shape(dims)) | ||
## end | ||
|
||
@interface sparse function Base.map(f, as...) | ||
# This is defined in this way so we can rely on the Broadcast logic | ||
# for determining the destination of the operation (element type, shape, etc.). | ||
return f.(as...) | ||
end | ||
|
||
@interface sparse function Base.map!(f, dest, as...) | ||
# Check `f` preserves zeros. | ||
# Define as `map_stored!`. | ||
# Define `eachstoredindex` promotion. | ||
for I in eachstoredindex(as...) | ||
dest[I] = f(map(a -> a[I], as)...) | ||
end | ||
return dest | ||
end | ||
|
||
# TODO: Define `AbstractSparseArrayStyle`, make this a subtype. | ||
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end | ||
|
||
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() | ||
|
||
@interface sparse function Broadcast.BroadcastStyle(type::Type) | ||
return SparseArrayStyle{ndims(type)}() | ||
end | ||
|
||
function Base.similar(bc::Broadcast.Broadcasted{<:SparseArrayStyle}, T::Type, axes::Tuple) | ||
# TODO: Allow `similar` to accept `axes` directly. | ||
return sparse(similar)(bc, T, Int.(length.(axes))) | ||
end | ||
|
||
using BroadcastMapConversion: map_function, map_args | ||
# TODO: Look into `SparseArrays.capturescalars`: | ||
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 | ||
function Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{<:SparseArrayStyle}) | ||
sparse(map!)(map_function(bc), dest, map_args(bc)...) | ||
return dest | ||
end | ||
|
||
using ArrayLayouts: ArrayLayouts, MatMulMatAdd | ||
|
||
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end | ||
|
||
struct SparseLayout <: AbstractSparseLayout end | ||
|
||
@interface sparse function ArrayLayouts.MemoryLayout(type::Type) | ||
return SparseLayout() | ||
end | ||
|
||
using LinearAlgebra: LinearAlgebra | ||
@interface sparse function LinearAlgebra.mul!(a_dest, a1, a2, α, β) | ||
return ArrayLayouts.mul!(a_dest, a1, a2, α, β) | ||
end | ||
|
||
function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2}) | ||
if I1[2] ≠ I2[1] | ||
return nothing | ||
end | ||
return CartesianIndex(I1[1], I2[2]) | ||
end | ||
|
||
function default_mul!!( | ||
a_dest::AbstractMatrix, | ||
a1::AbstractMatrix, | ||
a2::AbstractMatrix, | ||
α::Number=true, | ||
β::Number=false, | ||
) | ||
mul!(a_dest, a1, a2, α, β) | ||
return a_dest | ||
end | ||
|
||
function default_mul!!( | ||
a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false | ||
) | ||
return a1 * a2 * α + a_dest * β | ||
end | ||
|
||
# a1 * a2 * α + a_dest * β | ||
function sparse_mul!( | ||
a_dest::AbstractArray, | ||
a1::AbstractArray, | ||
a2::AbstractArray, | ||
α::Number=true, | ||
β::Number=false; | ||
(mul!!)=(default_mul!!), | ||
) | ||
for I1 in eachstoredindex(a1) | ||
for I2 in eachstoredindex(a2) | ||
I_dest = mul_indices(I1, I2) | ||
if !isnothing(I_dest) | ||
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β) | ||
end | ||
end | ||
end | ||
return a_dest | ||
end | ||
|
||
function ArrayLayouts.materialize!( | ||
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout} | ||
) | ||
sparse_mul!(m.C, m.A, m.B, m.α, m.β) | ||
return m.C | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
using LinearAlgebra: Adjoint | ||
storedvalues(a::Adjoint) = storedvalues(parent(a)) | ||
function isstored(a::Adjoint, i::Int, j::Int) | ||
return isstored(parent(a), j, i) | ||
end | ||
function eachstoredindex(a::Adjoint) | ||
# TODO: Make lazy with `Iterators.map`. | ||
return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) | ||
end | ||
function getstoredindex(a::Adjoint, i::Int, j::Int) | ||
return getstoredindex(parent(a), j, i)' | ||
end | ||
function getunstoredindex(a::Adjoint, i::Int, j::Int) | ||
return getunstoredindex(parent(a), j, i)' | ||
end | ||
function setstoredindex!(a::Adjoint, value, i::Int, j::Int) | ||
setstoredindex!(parent(a), value', j, i) | ||
return a | ||
end | ||
function setunstoredindex!(a::Adjoint, value, i::Int, j::Int) | ||
setunstoredindex!(parent(a), value', j, i) | ||
return a | ||
end | ||
|
||
using LinearAlgebra: Transpose | ||
storedvalues(a::Transpose) = storedvalues(parent(a)) | ||
function isstored(a::Transpose, i::Int, j::Int) | ||
return isstored(parent(a), j, i) | ||
end | ||
function eachstoredindex(a::Transpose) | ||
# TODO: Make lazy with `Iterators.map`. | ||
return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) | ||
end | ||
function getstoredindex(a::Transpose, i::Int, j::Int) | ||
return transpose(getstoredindex(parent(a), j, i)) | ||
end | ||
function getunstoredindex(a::Transpose, i::Int, j::Int) | ||
return transpose(getunstoredindex(parent(a), j, i)) | ||
end | ||
function setstoredindex!(a::Transpose, value, i::Int, j::Int) | ||
setstoredindex!(parent(a), transpose(value), j, i) | ||
return a | ||
end | ||
function setunstoredindex!(a::Transpose, value, i::Int, j::Int) | ||
setunstoredindex!(parent(a), transpose(value), j, i) | ||
return a | ||
end |
Oops, something went wrong.