Skip to content

Commit

Permalink
[BlockSparseArrays] Redesign block views again (ITensor#1513)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Redesign block views again

* [NDTensors] Bump to v0.3.38
  • Loading branch information
mtfishman authored Jun 27, 2024
1 parent 9c22961 commit 2985e9b
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 8 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.37"
version = "0.3.38"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,33 @@ function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks)
end
end

# Represents a view of a block of a blocked array.
struct BlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
array::Array
block::Tuple{Vararg{Block{1,Int},N}}
end
function Base.axes(a::BlockView)
# TODO: Try to avoid conversion to `Base.OneTo{Int}`, or just convert
# the element type to `Int` with `Int.(...)`.
# When the axes of `a.array` are `GradedOneTo`, the block is `LabelledUnitRange`,
# which has element type `LabelledInteger`. That causes conversion problems
# in some generic Base Julia code, for example when printing `BlockView`.
return ntuple(ndims(a)) do dim
return Base.OneTo{Int}(only(axes(axes(a.array, dim)[a.block[dim]])))
end
end
function Base.size(a::BlockView)
return length.(axes(a))
end
function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
return blocks(a.array)[Int.(a.block)...][index...]
end
function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N}
blocks(a.array)[Int.(a.block)...] = blocks(a.array)[Int.(a.block)...]
blocks(a.array)[Int.(a.block)...][index...] = value
return a
end

function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ function Base.setindex!(
blocksparse_setindex!(a, value, I...)
return a
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
blocksize = ntuple(dim -> length(axes(a, dim)[I[dim]]), N)
if size(value) blocksize
throw(
DimensionMismatch(
"Trying to set block $(Block(Int.(I)...)), which has a size $blocksize, with data of size $(size(value)).",
),
)
end
blocks(a)[Int.(I)...] = value
return a
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using BlockArrays: Block, BlockSlices
using BlockArrays: BlockArrays, Block, BlockSlices, viewblock

function blocksparse_view(a, I...)
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
Expand All @@ -22,3 +22,19 @@ function Base.view(
)
return blocksparse_view(a, I)
end

# Specialized code for getting the view of a block.
function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Block{N}
) where {N}
return viewblock(a, Tuple(block)...)
end
function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
) where {N}
I = CartesianIndex(Int.(block))
if I stored_indices(blocks(a))
return blocks(a)[I]
end
return BlockView(a, block)
end
12 changes: 6 additions & 6 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using BlockArrays:
using Compat: @compat
using LinearAlgebra: mul!
using NDTensors.BlockSparseArrays:
@view!, BlockSparseArray, block_nstored, block_reshape, view!
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
Expand Down Expand Up @@ -362,10 +362,10 @@ include("TestBlockSparseArraysUtils.jl")
b = @view a[Block(2, 2)]
@test size(b) == (3, 4)
for i in parentindices(b)
@test i isa BlockSlice{<:Block{1}}
@test i isa Base.OneTo{Int}
end
@test parentindices(b)[1] == BlockSlice(Block(2), 3:5)
@test parentindices(b)[2] == BlockSlice(Block(2), 4:7)
@test parentindices(b)[1] == 1:3
@test parentindices(b)[2] == 1:4

a = BlockSparseArray{elt}([2, 3], [3, 4])
b = @view a[Block(2, 2)[1:2, 2:2]]
Expand All @@ -392,9 +392,9 @@ include("TestBlockSparseArraysUtils.jl")

a = BlockSparseArray{elt}([2, 3], [3, 4])
b = @views a[Block(2, 2)][1:2, 2:3]
@test b isa SubArray{<:Any,<:Any,<:BlockSparseArray}
@test b isa SubArray{<:Any,<:Any,<:BlockView}
for i in parentindices(b)
@test i isa BlockSlice{<:BlockIndexRange{1}}
@test i isa UnitRange{Int}
end
x = randn(elt, 2, 2)
b .= x
Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/lib/GradedAxes/src/unitrangedual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ end
using NDTensors.LabelledNumbers: LabelledNumbers, label
LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a)))

using NDTensors.LabelledNumbers: LabelledUnitRange
# The Base version of `length(::AbstractUnitRange)` drops the label.
function Base.length(a::UnitRangeDual{<:Any,<:LabelledUnitRange})
return dual(length(nondual(a)))
end
function Base.iterate(a::UnitRangeDual, i)
i == last(a) && return nothing
return dual.(iterate(nondual(a), i))
end
# TODO: Is this a good definition?
Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a

using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))

using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock
BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a))
BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
Expand Down
3 changes: 3 additions & 0 deletions NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer}
return OrdinalRange{T,T}(unlabel(a))
end

# TODO: Is this a good definition?
Base.unitrange(a::LabelledUnitRange) = a

for f in [:first, :getindex, :last, :length, :step]
@eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a))
end
Expand Down

0 comments on commit 2985e9b

Please sign in to comment.