Skip to content

Commit

Permalink
[BlockSparseArrays] Towards block merging (ITensor#1512)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Towards block merging

* [NDTensors] Bump to v0.3.37
  • Loading branch information
mtfishman authored Jun 26, 2024
1 parent ab8a59e commit 9c22961
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 3 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.36"
version = "0.3.37"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,26 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
return r
end

# This handles changing the blocking, for example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = blockedrange([4, 4])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
return only(blockaxes(r))
end

# This handles changing the blocking, for example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
return error("Slicing not implemented for range of type `$(typeof(r))`.")
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
return only(blockaxes(r))
end

using BlockArrays: BlockSlice
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Adapt: Adapt, WrappedArray
using BlockArrays:
BlockArrays,
AbstractBlockVector,
AbstractBlockedUnitRange,
BlockIndexRange,
BlockRange,
Expand Down Expand Up @@ -40,8 +41,9 @@ function Base.to_indices(
end

# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
# a[BlockedVector([Block(2), Block(1)], [2]), BlockedVector([Block(2), Block(1)], [2])]
function Base.to_indices(
a::BlockSparseArrayLike, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}}
a::BlockSparseArrayLike, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}}
)
return blocksparse_to_indices(a, inds, I)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using BlockArrays:
BlockIndex,
BlockVector,
BlockedUnitRange,
BlockedVector,
block,
blockcheckbounds,
blocklengths,
Expand Down Expand Up @@ -46,6 +47,12 @@ function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

# TODO: Should this be combined with the version above?
function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}})
I1 = blockedunitrange_getindices(inds[1], I[1])
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

# TODO: Need to implement this!
function block_merge end

Expand Down Expand Up @@ -223,6 +230,9 @@ function Base.size(a::SparseSubArrayBlocks)
return length.(axes(a))
end
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
# TODO: Should this be defined as `@view a.array[Block(I)]` instead?
## return @view a.array[Block(I)]

parent_blocks = @view blocks(parent(a.array))[blockrange(a)...]
parent_block = parent_blocks[I...]
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
Expand Down
19 changes: 19 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using BlockArrays:
BlockVector,
BlockedOneTo,
BlockedUnitRange,
BlockedVector,
blockedrange,
blocklength,
blocklengths,
Expand All @@ -23,6 +24,24 @@ using Test: @test, @test_broken, @test_throws, @testset
include("TestBlockSparseArraysUtils.jl")
@testset "BlockSparseArrays (eltype=$elt)" for elt in
(Float32, Float64, ComplexF32, ComplexF64)
@testset "Broken" begin
a = BlockSparseArray{elt}([2, 2, 2, 2], [2, 2, 2, 2])
@views for I in [Block(1, 1), Block(2, 2), Block(3, 3), Block(4, 4)]
a[I] = randn(elt, size(a[I]))
end

I = blockedrange([4, 4])
b = @view a[I, I]
@test_broken copy(b)

I = BlockedVector(Block.(1:4), [2, 2])
b = @view a[I, I]
@test_broken copy(b)

I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
b = @view a[I, I]
@test_broken copy(b)
end
@testset "Basics" begin
a = BlockSparseArray{elt}([2, 3], [2, 3])
@test a == BlockSparseArray{elt}(blockedrange([2, 3]), blockedrange([2, 3]))
Expand Down
16 changes: 16 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using BlockArrays:
BlockRange,
BlockSlice,
BlockedUnitRange,
BlockedVector,
block,
blockindex,
findblock,
Expand Down Expand Up @@ -70,6 +71,21 @@ function blockedunitrange_getindices(
return blockedunitrange(indices .+ (first(a) - 1), blocklengths)
end

# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1},<:BlockRange{1}}
)
blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
return blockedrange(blocklengths)
end

# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly.
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1}}
)
return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)))
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndexRange)
return a[block(indices)][only(indices.indices)]
Expand Down

0 comments on commit 9c22961

Please sign in to comment.