From 9c229614afddb7c67c5a1fee69644626e6d45947 Mon Sep 17 00:00:00 2001
From: Matt Fishman <mtfishman@users.noreply.github.com>
Date: Wed, 26 Jun 2024 14:18:40 -0400
Subject: [PATCH] [BlockSparseArrays] Towards block merging (#1512)

* [BlockSparseArrays] Towards block merging

* [NDTensors] Bump to v0.3.37
---
 NDTensors/Project.toml                        |  2 +-
 .../BlockArraysExtensions.jl                  | 20 ++++++++++++++++++-
 .../wrappedabstractblocksparsearray.jl        |  4 +++-
 .../blocksparsearrayinterface.jl              | 10 ++++++++++
 .../lib/BlockSparseArrays/test/test_basics.jl | 19 ++++++++++++++++++
 .../lib/GradedAxes/src/blockedunitrange.jl    | 16 +++++++++++++++
 6 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml
index bd56eda80c..82ac74c82b 100644
--- a/NDTensors/Project.toml
+++ b/NDTensors/Project.toml
@@ -1,7 +1,7 @@
 name = "NDTensors"
 uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
 authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
-version = "0.3.36"
+version = "0.3.37"
 
 [deps]
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl
index fb015e7ee4..76ff94eb1b 100644
--- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl
+++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl
@@ -230,8 +230,26 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
   return r
 end
 
+# This handles changing the blocking, for example:
+# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
+# I = blockedrange([4, 4])
+# a[I, I]
+# TODO: Generalize to `AbstractBlockedUnitRange`.
+function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
+  # TODO: Probably this is incorrect and should be something like:
+  # return findblock(axis, first(r)):findblock(axis, last(r))
+  return only(blockaxes(r))
+end
+
+# This handles changing the blocking, for example:
+# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
+# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
+# a[I, I]
+# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
 function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
-  return error("Slicing not implemented for range of type `$(typeof(r))`.")
+  # TODO: Probably this is incorrect and should be something like:
+  # return findblock(axis, first(r)):findblock(axis, last(r))
+  return only(blockaxes(r))
 end
 
 using BlockArrays: BlockSlice
diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
index aec6783308..d25affba75 100644
--- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
+++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
@@ -1,6 +1,7 @@
 using Adapt: Adapt, WrappedArray
 using BlockArrays:
   BlockArrays,
+  AbstractBlockVector,
   AbstractBlockedUnitRange,
   BlockIndexRange,
   BlockRange,
@@ -40,8 +41,9 @@ function Base.to_indices(
 end
 
 # a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
+# a[BlockedVector([Block(2), Block(1)], [2]), BlockedVector([Block(2), Block(1)], [2])]
 function Base.to_indices(
-  a::BlockSparseArrayLike, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}}
+  a::BlockSparseArrayLike, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}}
 )
   return blocksparse_to_indices(a, inds, I)
 end
diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl
index 39e684a55f..ee6790914b 100644
--- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl
+++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl
@@ -4,6 +4,7 @@ using BlockArrays:
   BlockIndex,
   BlockVector,
   BlockedUnitRange,
+  BlockedVector,
   block,
   blockcheckbounds,
   blocklengths,
@@ -46,6 +47,12 @@ function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg
   return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
 end
 
+# TODO: Should this be combined with the version above?
+function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}})
+  I1 = blockedunitrange_getindices(inds[1], I[1])
+  return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
+end
+
 # TODO: Need to implement this!
 function block_merge end
 
@@ -223,6 +230,9 @@ function Base.size(a::SparseSubArrayBlocks)
   return length.(axes(a))
 end
 function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
+  # TODO: Should this be defined as `@view a.array[Block(I)]` instead?
+  ## return @view a.array[Block(I)]
+
   parent_blocks = @view blocks(parent(a.array))[blockrange(a)...]
   parent_block = parent_blocks[I...]
   # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
index 2f3b8d6b92..d11fcf7f9c 100644
--- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
+++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
@@ -7,6 +7,7 @@ using BlockArrays:
   BlockVector,
   BlockedOneTo,
   BlockedUnitRange,
+  BlockedVector,
   blockedrange,
   blocklength,
   blocklengths,
@@ -23,6 +24,24 @@ using Test: @test, @test_broken, @test_throws, @testset
 include("TestBlockSparseArraysUtils.jl")
 @testset "BlockSparseArrays (eltype=$elt)" for elt in
                                                (Float32, Float64, ComplexF32, ComplexF64)
+  @testset "Broken" begin
+    a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2])
+    @views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)]
+      a[I] = randn(elt, size(a[I]))
+    end
+
+    I = blockedrange([4, 4])
+    b = @view a[I, I]
+    @test_broken copy(b)
+
+    I = BlockedVector(Block.(1:4), [2, 2])
+    b = @view a[I, I]
+    @test_broken copy(b)
+
+    I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
+    b = @view a[I, I]
+    @test_broken copy(b)
+  end
   @testset "Basics" begin
     a = BlockSparseArray{elt}([2, 3], [2, 3])
     @test a == BlockSparseArray{elt}(blockedrange([2, 3]), blockedrange([2, 3]))
diff --git a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
index f420fae614..417a15adf9 100644
--- a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
+++ b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
@@ -6,6 +6,7 @@ using BlockArrays:
   BlockRange,
   BlockSlice,
   BlockedUnitRange,
+  BlockedVector,
   block,
   blockindex,
   findblock,
@@ -70,6 +71,21 @@ function blockedunitrange_getindices(
   return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
 end
 
+# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
+function blockedunitrange_getindices(
+  a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1},<:BlockRange{1}}
+)
+  blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
+  return blockedrange(blocklengths)
+end
+
+# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
+function blockedunitrange_getindices(
+  a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1}}
+)
+  return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)))
+end
+
 # TODO: Move this to a `BlockArraysExtensions` library.
 function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndexRange)
   return a[block(indices)][only(indices.indices)]