Skip to content

Commit

Permalink
[BlockSparseArrys] Fix nested slicing in Julia 1.11
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 10, 2024
1 parent 49c1202 commit 880e3e3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
end
Base.size(a::NonBlockedArray) = size(a.array)
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
# Views of `NonBlockedArray`/`NonBlockedVector` are eager.
# This fixes an issue in Julia 1.11 where reindexing defaults to using views.
# TODO: Maybe reconsider this design, and allows views to work in slicing.
Base.view(a::NonBlockedArray, I...) = a[I...]
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)
Expand Down Expand Up @@ -81,6 +85,9 @@ function Base.getindex(
)
return i
end
# Views of `BlockIndices` are eager.
# This fixes an issue in Julia 1.11 where reindexing defaults to using views.
Base.view(S::BlockIndices, i) = S[i]

# Used in indexing such as:
# ```julia
Expand Down
40 changes: 13 additions & 27 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,31 +691,21 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
c = @view b[4:8, 4:8]
@test c isa SubArray{<:Any,<:Any,<:BlockSparseArray}
@test size(c) == (5, 5)
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test block_nstored(c) == 2 broken = VERSION > v"1.11-"
@test block_nstored(c) == 2
@test blocksize(c) == (2, 2)
@test blocklengths.(axes(c)) == ([2, 3], [2, 3])
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test size(c[Block(1, 1)]) == (2, 2) broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]] broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test size(c[Block(2, 2)]) == (3, 3) broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]] broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test size(c[Block(2, 1)]) == (3, 2) broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test iszero(c[Block(2, 1)]) broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test size(c[Block(1, 2)]) == (2, 3) broken = VERSION v"1.11-"
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test iszero(c[Block(1, 2)]) broken = VERSION v"1.11-"
@test size(c[Block(1, 1)]) == (2, 2)
@test c[Block(1, 1)] == a[Block(2, 2)[2:3, 2:3]]
@test size(c[Block(2, 2)]) == (3, 3)
@test c[Block(2, 2)] == a[Block(1, 1)[1:3, 1:3]]
@test size(c[Block(2, 1)]) == (3, 2)
@test iszero(c[Block(2, 1)])
@test size(c[Block(1, 2)]) == (2, 3)
@test iszero(c[Block(1, 2)])

x = randn(elt, 3, 3)
c[Block(2, 2)] = x
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test c[Block(2, 2)] == x broken = VERSION v"1.11-"
@test c[Block(2, 2)] == x
@test a[Block(1, 1)[1:3, 1:3]] == x

a = BlockSparseArray{elt}([2, 3], [3, 4])
Expand Down Expand Up @@ -776,17 +766,13 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test copy(b) == a[J, J]
@test blocksize(b) == (2, 2)
@test blocklengths.(axes(b)) == ([4, 4], [4, 4])
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test b[Block(1, 1)] == Array(a)[[7, 8, 5, 6], [7, 8, 5, 6]] broken =
VERSION v"1.11-"
@test b[Block(1, 1)] == Array(a)[[7, 8, 5, 6], [7, 8, 5, 6]]
c = @views b[Block(1, 1)][2:3, 2:3]
@test c == Array(a)[[8, 5], [8, 5]]
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test copy(c) == Array(a)[[8, 5], [8, 5]] broken = VERSION v"1.11-"
@test copy(c) == Array(a)[[8, 5], [8, 5]]
c = @view b[Block(1, 1)[2:3, 2:3]]
@test c == Array(a)[[8, 5], [8, 5]]
# TODO: Fix in Julia 1.11 (https://github.com/ITensor/ITensors.jl/pull/1539).
@test copy(c) == Array(a)[[8, 5], [8, 5]] broken = VERSION v"1.11-"
@test copy(c) == Array(a)[[8, 5], [8, 5]]
end

# TODO: Add more tests of this, it may
Expand Down

0 comments on commit 880e3e3

Please sign in to comment.