Skip to content

Commit

Permalink
[BlockSparseArrays] Permute and merge blocks (ITensor#1514)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Permute and merge blocks

* [NDTensors] Bump to v0.3.39
  • Loading branch information
mtfishman authored Jul 1, 2024
1 parent 2985e9b commit d734e64
Show file tree
Hide file tree
Showing 14 changed files with 673 additions and 96 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.38"
version = "0.3.39"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ function TensorAlgebra.splitdims(
return length(axis) length(axes(a, i))
end
blockperms = invblockperm.(blocksortperm.(axes_prod))
a_blockpermed = a[blockperms...]
# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
a_unblocked = a[axes_prod...]
a_blockpermed = a_unblocked[blockperms...]
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
m = fusedims(a, (1, 2), (3, 4))
# TODO: Once block merging is implemented, this should
# be the real test.
for ax in axes(m)
@test ax isa GradedOneTo
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
end
for I in CartesianIndices(m)
if I CartesianIndex.([(1, 1), (4, 4)])
Expand All @@ -105,10 +100,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test blocksize(m) == (4, 4)
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
@testset "dual axes" begin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ using BlockArrays:
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockIndex,
BlockIndexRange,
BlockRange,
BlockSlice,
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
BlockedVector,
block,
blockaxes,
blockedrange,
Expand All @@ -17,8 +20,30 @@ using BlockArrays:
findblockindex
using Compat: allequal
using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices
using ..GradedAxes: blockedunitrange_getindices, to_blockindices
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices

# A return type for `blocks(array)` when `array` isn't blocked.
# Represents a vector with just that single block.
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
blocks_maybe_single(a) = blocks(a)
blocks_maybe_single(a::Array) = SingleBlockView(a)
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
@assert all(isone, index)
return a.array
end

# A wrapper around a potentially blocked array that is not blocked.
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
Expand All @@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end

# This is used in slicing like:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
function Base.getindex(
S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}}
)
# TODO: Check for conistency of indices.
# Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices
# as a single block, since the result shouldn't be blocked.
return NonBlockedVector(BlockIndices(S.blocks[Block(i)], S.indices[Block(i)]))
end
function Base.getindex(
S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}}
)
return i
end

# Used in indexing such as:
# ```julia
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# b = @view a[I, I]
# @view b[Block(1, 1)[1:2, 2:2]]
# ```
# This is similar to the definition:
# blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
function Base.getindex(
a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer}
)
ax = only(axes(a.array.indices))
brs = to_blockindices(ax, I)
inds = blockedunitrange_getindices(ax, I)
return NonBlockedVector(a.array[BlockSlice(brs, inds)])
end

function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
Expand All @@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
return BlockIndices(subblocks, subindices)
end

# Used when performing slices like:
# @views a[[Block(2), Block(1)]][2:4, 2:4]
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}})
subblocks = mortar(
map(blocks(i.block)) do br
return S.blocks[Int(Block(br))][only(br.indices)]
end,
)
subindices = mortar(
map(blocks(i.block)) do br
S.indices[br]
end,
)
return BlockIndices(subblocks, subindices)
end

# Similar to the definition of `BlockArrays.BlockSlices`:
# ```julia
# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
# ```
# but includes `BlockIndices`, where the blocks aren't contiguous.
const BlockSliceCollection = Union{
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
}
const SubBlockSliceCollection = BlockIndices{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
Expand Down Expand Up @@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
return findblock(axis, first(r)):findblock(axis, last(r))
end

# Occurs when slicing with `a[2:4, 2:4]`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedUnitRange{<:Integer})
# TODO: Check the blocks are commensurate.
return findblock(axis, first(r)):findblock(axis, last(r))
end

function blockrange(axis::AbstractUnitRange, r::Int)
## return findblock(axis, r)
return error("Slicing with integer values isn't supported.")
Expand All @@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
return only(blockaxes(r))
end

# This handles changing the blocking, for example:
# This handles block merging:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector(Block.(1:4), [2, 2])
# I = BlockVector(Block.(1:4), [2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
for b in r
@assert b blockaxes(axis, 1)
end
return only(blockaxes(r))
end

Expand Down Expand Up @@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
return only(blockaxes(axis))
end

function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
return Block(1):Block(1)
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down Expand Up @@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
return a
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
function SparseArrayInterface.nstored(a::BlockView)
# TODO: Store whether or not the block is stored already as
# a Bool in `BlockView`.
I = CartesianIndex(Int.(a.block))
# TODO: Use `block_stored_indices`.
if I stored_indices(blocks(a.array))
return nstored(blocks(a.array)[I])
end
return 0
end

function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module BlockSparseArrays
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -15,7 +17,6 @@ include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysExtensions/BlockArraysExtensions.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,57 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

reblock(a) = a

# If the blocking of the slice doesn't match the blocking of the
# parent array, reblock according to the blocking of the parent array.
function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
)
# TODO: This relies on the behavior that slicing a block sparse
# array with a UnitRange inherits the blocking of the underlying
# block sparse array, we might change that default behavior
# so this might become something like `@blocked parent(a)[...]`.
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
end

function reblock(
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
)
return @view parent(a)[map(I -> I.array, parentindices(a))...]
end

function reblock(
a::SubArray{
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
},
)
# Remove the blocking.
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
end

# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
# `reblock` is a partial solution to that, but a bit ad-hoc.
# TODO: Move to `blocksparsearrayinterface/map.jl`.
function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
Expand Down
Loading

0 comments on commit d734e64

Please sign in to comment.