diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 94626cdddb..80f68c17f0 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.67" +version = "0.3.68" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index 8c5d60b65f..24de544dc1 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -16,8 +16,15 @@ struct BlockSparseArray{ axes::Axes end -const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes} -const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes} +# TODO: Can this definition be shortened? +const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange,AbstractUnitRange}} = BlockSparseArray{ + T,2,A,Blocks,Axes +} + +# TODO: Can this definition be shortened? +const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange}} = BlockSparseArray{ + T,1,A,Blocks,Axes +} function BlockSparseArray( block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}}, @@ -68,10 +75,38 @@ function BlockSparseArray{T,N,A}( return BlockSparseArray{T,N,A}(blocks, axes) end +function BlockSparseArray{T,N,A}( + axes::Vararg{AbstractUnitRange,N} +) where {T,N,A<:AbstractArray{T,N}} + return BlockSparseArray{T,N,A}(axes) +end + +function BlockSparseArray{T,N,A}( + dims::Tuple{Vararg{Vector{Int},N}} +) where {T,N,A<:AbstractArray{T,N}} + return BlockSparseArray{T,N,A}(blockedrange.(dims)) +end + +# Fix ambiguity error. +function BlockSparseArray{T,0,A}(axes::Tuple{}) where {T,A<:AbstractArray{T,0}} + blocks = default_blocks(A, axes) + return BlockSparseArray{T,0,A}(blocks, axes) +end + +function BlockSparseArray{T,N,A}( + dims::Vararg{Vector{Int},N} +) where {T,N,A<:AbstractArray{T,N}} + return BlockSparseArray{T,N,A}(dims) +end + function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes) end +function BlockSparseArray{T,N}(axes::Vararg{AbstractUnitRange,N}) where {T,N} + return BlockSparseArray{T,N}(axes) +end + function BlockSparseArray{T,0}(axes::Tuple{}) where {T} return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes) end @@ -80,6 +115,10 @@ function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N} return BlockSparseArray{T,N}(blockedrange.(dims)) end +function BlockSparseArray{T,N}(dims::Vararg{Vector{Int},N}) where {T,N} + return BlockSparseArray{T,N}(dims) +end + function BlockSparseArray{T}(dims::Tuple{Vararg{Vector{Int}}}) where {T} return BlockSparseArray{T,length(dims)}(dims) end @@ -104,37 +143,25 @@ function BlockSparseArray{T}() where {T} return BlockSparseArray{T}(()) end -function BlockSparseArray{T,N,A}( - ::UndefInitializer, dims::Tuple -) where {T,N,A<:AbstractArray{T,N}} - return BlockSparseArray{T,N,A}(dims) -end - # undef -function BlockSparseArray{T,N}( - ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}} -) where {T,N} - return BlockSparseArray{T,N}(axes) -end - -function BlockSparseArray{T,N}( - ::UndefInitializer, dims::Tuple{Vararg{Vector{Int},N}} -) where {T,N} - return BlockSparseArray{T,N}(dims) +function BlockSparseArray{T,N,A,Blocks}( + ::UndefInitializer, args... +) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}} + return BlockSparseArray{T,N,A,Blocks}(args...) end -function BlockSparseArray{T}( - ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange}} -) where {T} - return BlockSparseArray{T}(axes) +function BlockSparseArray{T,N,A}( + ::UndefInitializer, args... +) where {T,N,A<:AbstractArray{T,N}} + return BlockSparseArray{T,N,A}(args...) end -function BlockSparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Vector{Int}}}) where {T} - return BlockSparseArray{T}(dims) +function BlockSparseArray{T,N}(::UndefInitializer, args...) where {T,N} + return BlockSparseArray{T,N}(args...) end -function BlockSparseArray{T}(::UndefInitializer, dims::Vararg{Vector{Int}}) where {T} - return BlockSparseArray{T}(dims...) +function BlockSparseArray{T}(::UndefInitializer, args...) where {T} + return BlockSparseArray{T}(args...) end # Base `AbstractArray` interface diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index dd80ec8048..98694efe94 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -17,6 +17,8 @@ using ..SparseArrayInterface: perm, iperm, stored_length, sparse_zero! blocksparse_blocks(a::AbstractArray) = error("Not implemented") +blockstype(a::AbstractArray) = blockstype(typeof(a)) + function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) return a[findblockindex.(axes(a), I)...] diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 41df3e3ea0..e41f736783 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -20,13 +20,18 @@ using LinearAlgebra: Adjoint, dot, mul!, norm using NDTensors.BlockSparseArrays: @view!, BlockSparseArray, + BlockSparseMatrix, + BlockSparseVector, BlockView, block_stored_length, block_reshape, block_stored_indices, + blockstype, + blocktype, view! using NDTensors.GPUArraysCoreExtensions: cpu using NDTensors.SparseArrayInterface: stored_length +using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK using NDTensors.TensorAlgebra: contract using Test: @test, @test_broken, @test_throws, @testset include("TestBlockSparseArraysUtils.jl") @@ -72,6 +77,71 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype ah = adjoint(a) @test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector end + @testset "Constructors" begin + # BlockSparseMatrix + bs = ([2, 3], [3, 4]) + for T in ( + BlockSparseArray{elt}, + BlockSparseArray{elt,2}, + BlockSparseMatrix{elt}, + BlockSparseArray{elt,2,Matrix{elt}}, + BlockSparseMatrix{elt,Matrix{elt}}, + ## BlockSparseArray{elt,2,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO + ## BlockSparseMatrix{elt,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO + ) + for args in ( + bs, + (bs,), + blockedrange.(bs), + (blockedrange.(bs),), + (undef, bs), + (undef, bs...), + (undef, blockedrange.(bs)), + (undef, blockedrange.(bs)...), + ) + a = T(args...) + @test eltype(a) == elt + @test blocktype(a) == Matrix{elt} + @test blockstype(a) <: SparseMatrixDOK{Matrix{elt}} + @test blocklengths.(axes(a)) == ([2, 3], [3, 4]) + @test iszero(a) + @test iszero(block_stored_length(a)) + @test iszero(stored_length(a)) + end + end + + # BlockSparseVector + bs = ([2, 3],) + for T in ( + BlockSparseArray{elt}, + BlockSparseArray{elt,1}, + BlockSparseVector{elt}, + BlockSparseArray{elt,1,Vector{elt}}, + BlockSparseVector{elt,Vector{elt}}, + ## BlockSparseArray{elt,1,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO + ## BlockSparseVector{elt,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO + ) + for args in ( + bs, + (bs,), + blockedrange.(bs), + (blockedrange.(bs),), + (undef, bs), + (undef, bs...), + (undef, blockedrange.(bs)), + (undef, blockedrange.(bs)...), + ) + a = T(args...) + @test eltype(a) == elt + @test blocktype(a) == Vector{elt} + @test blockstype(a) <: SparseVectorDOK{Vector{elt}} + @test blocklengths.(axes(a)) == ([2, 3],) + @test iszero(a) + @test iszero(block_stored_length(a)) + @test iszero(stored_length(a)) + end + end + end @testset "Basics" begin a = dev(BlockSparseArray{elt}([2, 3], [2, 3])) @allowscalar @test a == dev(