Skip to content

Commit

Permalink
Update to Derive and SparseArraysBase v0.2 (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Dec 20, 2024
1 parent d44f51c commit 43fc9d6
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 65 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[compat]
ArrayLayouts = "1.10.4"
BroadcastMapConversion = "0.1"
NestedPermutedDimsArrays = "0.1"
SparseArraysBase = "0.1"
TypeParameterAccessors = "0.1"
Derive = "0.3.6"
SparseArraysBase = "0.2.1"
TypeParameterAccessors = "0.2"
julia = "1.10"
4 changes: 1 addition & 3 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
[deps]
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
100 changes: 86 additions & 14 deletions src/abstractdiagonalarray/diagonalarraydiaginterface.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,95 @@
using SparseArraysBase: SparseArraysBase, StorageIndex, StorageIndices
# TODO: Turn these into `@interface ::AbstractDiagonalArrayInterface` functions.

SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a)}))

function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
return a[StorageIndex(i)]
end
using Derive: Derive, @interface
using SparseArraysBase:
SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle ## , StorageIndex, StorageIndices

function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
a[StorageIndex(i)] = value
return a
end
abstract type AbstractDiagonalArrayInterface <: AbstractSparseArrayInterface end

struct DiagonalArrayInterface <: AbstractDiagonalArrayInterface end

Derive.arraytype(::AbstractDiagonalArrayInterface, elt::Type) = DiagonalArray{elt}
Derive.interface(::Type{<:AbstractDiagonalArray}) = DiagonalArrayInterface()

abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end

SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))
Derive.interface(::Type{<:AbstractDiagonalArrayStyle}) = DiagonalArrayInterface()

function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
return a[StorageIndices(i)]
struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end

DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}()

@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
return DiagonalArrayStyle{ndims(type)}()
end

function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
a[StorageIndices(i)] = value
function SparseArraysBase.isstored(
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
return allequal(I)
end
function SparseArraysBase.getstoredindex(
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
# TODO: Make this check optional, define `checkstored` like `checkbounds`
# in SparseArraysBase.jl.
# allequal(I) || error("Not a diagonal index.")
return getdiagindex(a, first(I))
end
function SparseArraysBase.setstoredindex!(
a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
# TODO: Make this check optional, define `checkstored` like `checkbounds`
# in SparseArraysBase.jl.
# allequal(I) || error("Not a diagonal index.")
setdiagindex!(a, value, first(I))
return a
end
function SparseArraysBase.eachstoredindex(a::AbstractDiagonalArray)
return diagindices(a)
end

# Fix ambiguity error with SparseArraysBase.
function Base.getindex(a::AbstractDiagonalArray, I::DiagIndices)
# TODO: Use `@interface` rather than `invoke`.
return invoke(getindex, Tuple{AbstractArray,DiagIndices}, a, I)
end
# Fix ambiguity error with SparseArraysBase.
function Base.getindex(a::AbstractDiagonalArray, I::DiagIndex)
# TODO: Use `@interface` rather than `invoke`.
return invoke(getindex, Tuple{AbstractArray,DiagIndex}, a, I)
end
# Fix ambiguity error with SparseArraysBase.
function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndices)
# TODO: Use `@interface` rather than `invoke`.
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndices}, a, value, I)
end
# Fix ambiguity error with SparseArraysBase.
function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
# TODO: Use `@interface` rather than `invoke`.
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I)
end

## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))

## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
## return a[StorageIndex(i)]
## end

## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
## a[StorageIndex(i)] = value
## return a
## end

## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))

## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
## return a[StorageIndices(i)]
## end

## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
## a[StorageIndices(i)] = value
## return a
## end
24 changes: 11 additions & 13 deletions src/abstractdiagonalarray/sparsearrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
using SparseArraysBase: SparseArraysBase

# `SparseArraysBase` interface
function SparseArraysBase.index_to_storage_index(
a::AbstractDiagonalArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
!allequal(Tuple(I)) && return nothing
return first(Tuple(I))
end

function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I)
return CartesianIndex(ntuple(Returns(I), ndims(a)))
end
## # `SparseArraysBase` interface
## function SparseArraysBase.index_to_storage_index(
## a::AbstractDiagonalArray{<:Any,N}, I::CartesianIndex{N}
## ) where {N}
## !allequal(Tuple(I)) && return nothing
## return first(Tuple(I))
## end
##
## function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I)
## return CartesianIndex(ntuple(Returns(I), ndims(a)))
## end

## # 1-dimensional case can be `AbstractDiagonalArray`.
## function SparseArraysBase.sparse_similar(
Expand Down
35 changes: 35 additions & 0 deletions src/diaginterface/diaginterface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: Turn these into `@interface ::AbstractDiagonalArrayInterface` functions.

diaglength(a::AbstractArray{<:Any,0}) = 1

function diaglength(a::AbstractArray)
Expand All @@ -19,10 +21,43 @@ function diagstride(a::AbstractArray)
return s
end

# Iterator over the diagonal cartesian indices.
# For an AbstractArray `a`, `DiagCartesianIndices(a)` is equivalent
# to `@view CartesianIndices(a)[diagindices(a)]` but should be
# faster because it avoids conversions from linear to cartesian indices.
struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}}
diaglength::Int
end
function DiagCartesianIndices(axes::Tuple{Vararg{AbstractUnitRange}})
# Check the ranges are one-based.
@assert all(isone, first.(axes))
return DiagCartesianIndices{length(axes)}(minimum(length.(axes)))
end
function DiagCartesianIndices(dims::Tuple{Vararg{Int}})
return DiagCartesianIndices(Base.OneTo.(dims))
end
function DiagCartesianIndices(a::AbstractArray)
return DiagCartesianIndices(axes(a))
end
Base.size(I::DiagCartesianIndices) = (I.diaglength,)
function Base.getindex(I::DiagCartesianIndices{N}, i::Int) where {N}
return CartesianIndex(ntuple(Returns(i), N))
end

function diagindices(a::AbstractArray)
return diagindices(IndexStyle(a), a)
end
function diagindices(::IndexLinear, a::AbstractArray)
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
return 1:diagstride(a):maxdiag
end
function diagindices(::IndexCartesian, a::AbstractArray)
return DiagCartesianIndices(a)
# TODO: Define a special iterator for this, i.e. `DiagCartesianIndices`?
return Iterators.map(
i -> CartesianIndex(ntuple(Returns(i), ndims(a))), Base.OneTo(diaglength(a))
)
end

function diagindices(a::AbstractArray{<:Any,0})
return Base.OneTo(1)
Expand Down
59 changes: 33 additions & 26 deletions src/diagonalarray/diagonalarray.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
using SparseArraysBase: SparseArraysBase, SparseArrayDOK, Zero, getindex_zero_function
using SparseArraysBase: SparseArraysBase, SparseArrayDOK, default_getunstoredindex ## , Zero, getindex_zero_function

struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractDiagonalArray{T,N}
struct DiagonalArray{T,N,Diag<:AbstractVector{T},F} <: AbstractDiagonalArray{T,N}
diag::Diag
dims::NTuple{N,Int}
zero::Zero
getunstoredindex::F
end

function DiagonalArray{T,N}(
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=Zero()
diag::AbstractVector{T},
d::Tuple{Vararg{Int,N}},
getunstoredindex=default_getunstoredindex,
) where {T,N}
return DiagonalArray{T,N,typeof(diag),typeof(zero)}(diag, d, zero)
return DiagonalArray{T,N,typeof(diag),typeof(getunstoredindex)}(diag, d, getunstoredindex)
end

function DiagonalArray{T,N}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=Zero()
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
) where {T,N}
return DiagonalArray{T,N}(T.(diag), d, zero)
return DiagonalArray{T,N}(T.(diag), d, getunstoredindex)
end

function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray{T}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=Zero()
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
) where {T,N}
return DiagonalArray{T,N}(diag, d, zero)
return DiagonalArray{T,N}(diag, d, getunstoredindex)
end

function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
Expand All @@ -51,27 +53,29 @@ end

# undef
function DiagonalArray{T,N}(
::UndefInitializer, d::Tuple{Vararg{Int,N}}, zero=Zero()
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
) where {T,N}
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, zero)
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d, getunstoredindex)
end

function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

function DiagonalArray{T}(
::UndefInitializer, d::Tuple{Vararg{Int,N}}, zero=Zero()
::UndefInitializer, d::Tuple{Vararg{Int,N}}, getunstoredindex=default_getunstoredindex
) where {T,N}
return DiagonalArray{T,N}(undef, d, zero)
return DiagonalArray{T,N}(undef, d, getunstoredindex)
end

# Axes version
function DiagonalArray{T}(
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}, zero=Zero()
::UndefInitializer,
axes::Tuple{Vararg{AbstractUnitRange,N}},
getunstoredindex=default_getunstoredindex,
) where {T,N}
@assert all(isone, first.(axes))
return DiagonalArray{T,N}(undef, length.(axes), zero)
return DiagonalArray{T,N}(undef, length.(axes), getunstoredindex)
end

function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
Expand All @@ -83,23 +87,26 @@ Base.size(a::DiagonalArray) = a.dims

function Base.similar(a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}})
# TODO: Preserve zero element function.
return DiagonalArray{elt}(undef, dims, getindex_zero_function(a))
return DiagonalArray{elt}(undef, dims, a.getunstoredindex)
end

# DiagonalArrays interface.
diagview(a::DiagonalArray) = a.diag

# Minimal `SparseArraysBase` interface
SparseArraysBase.sparse_storage(a::DiagonalArray) = a.diag
## SparseArraysBase.sparse_storage(a::DiagonalArray) = a.diag

# `SparseArraysBase`
# Defines similar when the output can't be `DiagonalArray`,
# such as in `reshape`.
# TODO: Put into `DiagonalArraysSparseArraysBaseExt`?
# TODO: Special case 2D to output `SparseMatrixCSC`?
function SparseArraysBase.sparse_similar(
a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}
)
return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
end

function SparseArraysBase.getindex_zero_function(a::DiagonalArray)
return a.zero
end
## function SparseArraysBase.sparse_similar(
## a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}
## )
## return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
## end

## function SparseArraysBase.getindex_zero_function(a::DiagonalArray)
## return a.zero
## end
8 changes: 4 additions & 4 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test: @test, @testset, @test_broken
using DiagonalArrays: DiagonalArrays, DiagonalArray, DiagonalMatrix, diaglength
using SparseArraysBase: SparseArrayDOK, stored_length
using SparseArraysBase: SparseArrayDOK, storedlength
@testset "Test DiagonalArrays" begin
@testset "DiagonalArray (eltype=$elt)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
Expand All @@ -23,15 +23,15 @@ using SparseArraysBase: SparseArrayDOK, stored_length
# TODO: Use `densearray` to make generic to GPU.
@test Array(a_dest) Array(a1) * Array(a2)
# TODO: Make this work with `ArrayLayouts`.
@test stored_length(a_dest) == 2
@test storedlength(a_dest) == 2
@test a_dest isa DiagonalMatrix{elt}

# TODO: Make generic to GPU, use `allocate_randn`?
a2 = randn(elt, (3, 4))
a_dest = a1 * a2
# TODO: Use `densearray` to make generic to GPU.
@test Array(a_dest) Array(a1) * Array(a2)
@test stored_length(a_dest) == 8
@test storedlength(a_dest) == 8
@test a_dest isa Matrix{elt}

a2 = SparseArrayDOK{elt}(3, 4)
Expand All @@ -43,7 +43,7 @@ using SparseArraysBase: SparseArrayDOK, stored_length
@test Array(a_dest) Array(a1) * Array(a2)
# TODO: Define `SparseMatrixDOK`.
# TODO: Make this work with `ArrayLayouts`.
@test stored_length(a_dest) == 2
@test storedlength(a_dest) == 2
@test a_dest isa SparseArrayDOK{elt,2}
end
end
Expand Down

0 comments on commit 43fc9d6

Please sign in to comment.