From 4cfd892f3b4f787110bc3204a7e94edd6cc04f0b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 17:53:50 -0400 Subject: [PATCH 1/9] [BlockSparseArrays] Improve design of block views --- .../BlockArraysExtensions.jl | 4 +++ .../src/abstractblocksparsearray/view.jl | 4 --- .../wrappedabstractblocksparsearray.jl | 21 ++++++++++++++- .../blocksparsearrayinterface.jl | 27 +++++++++++++++++-- 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 04f37f0f18..0bb67b326e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -131,6 +131,10 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice) return blockrange(axis, r.block) end +function blockrange(axis::AbstractUnitRange, r::Block{1}) + return r:r +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl index e2e5c8acb9..bbe586d771 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl @@ -1,9 +1,5 @@ using BlockArrays: BlockIndexRange, BlockRange, BlockSlice, block -function blocksparse_view(a::AbstractArray, index::Block) - return blocks(a)[Int.(Tuple(index))...] -end - # TODO: Define `AnyBlockSparseVector`. function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N} return blocksparse_view(a, index) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 286847992e..618a1794bb 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -55,7 +55,20 @@ function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitR return ArrayLayouts.layout_getindex(a, I...) end -function Base.isassigned(a::BlockSparseArrayLike, index::Vararg{Block}) +function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, block::Block{N}) where {N} + return blocksparse_getindex(a, block) +end +function Base.getindex( + a::BlockSparseArrayLike{<:Any,N}, block::Vararg{Block{1},N} +) where {N} + return blocksparse_getindex(a, block...) +end + +# TODO: Define `issasigned(a, ::Block{N})`. +function Base.isassigned( + a::BlockSparseArrayLike{<:Any,N}, index::Vararg{Block{1},N} +) where {N} + # TODO: Define `blocksparse_isassigned`. return isassigned(blocks(a), Int.(index)...) end @@ -64,6 +77,12 @@ function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N return a end +function Base.setindex!( + a::BlockSparseArrayLike{<:Any,N}, value, I::Vararg{Block{1},N} +) where {N} + a[Block(Int.(I))] = value + return a +end function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::Block{N}) where {N} blocksparse_setindex!(a, value, I) return a diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index b93fb5b9b2..ad5e17dbca 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -19,6 +19,14 @@ function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where return a[findblockindex.(axes(a), I)...] end +function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} + return blocksparse_getindex(a, Tuple(I)...) +end +function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N} + # TODO: Avoid copy if the block isn't stored. + return copy(blocks(a)[Int.(I)...]) +end + # TODO: Implement as `copy(@view a[I...])`, which is then implemented # through `ArrayLayouts.sub_materialize`. using ..SparseArrayInterface: set_getindex_zero_function @@ -67,13 +75,28 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N end function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) where {N} - # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. - i = I.n + blocksparse_setindex!(a, value, Tuple(I)...) + return a +end +function blocksparse_setindex!( + a::AbstractArray{<:Any,N}, value, I::Vararg{Block{1},N} +) where {N} + i = Int.(I) @boundscheck blockcheckbounds(a, i...) + # TODO: Use `blocksizes(a)[i...]` when we upgrade to + # BlockArrays.jl v1. + @assert size(value) == size(view(a, I)) blocks(a)[i...] = value return a end +function blocksparse_view(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} + return blocksparse_view(a, Tuple(I)...) +end +function blocksparse_view(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N} + return SubArray(a, to_indices(a, I)) +end + function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n From 5e987b87e22e086f4b4d9bea0a957a49e9803d72 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 1 Jun 2024 12:42:09 -0400 Subject: [PATCH 2/9] Fix some more block operations --- .../blocksparsearrayinterface.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index ad5e17dbca..8e539946cc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -67,10 +67,11 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N end function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::BlockIndex{N}) where {N} - a_b = view(a, block(I)) + i = Int.(Tuple(block(I))) + a_b = blocks(a)[i...] a_b[I.α...] = value - # Set the block, required if it is structurally zero - a[block(I)] = a_b + # Set the block, required if it is structurally zero. + blocks(a)[i...] = a_b return a end @@ -85,7 +86,7 @@ function blocksparse_setindex!( @boundscheck blockcheckbounds(a, i...) # TODO: Use `blocksizes(a)[i...]` when we upgrade to # BlockArrays.jl v1. - @assert size(value) == size(view(a, I)) + @assert size(value) == size(view(a, I...)) blocks(a)[i...] = value return a end From aca628f74d70b02f34caa36b2c7df1387eefda24 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 11:36:58 -0400 Subject: [PATCH 3/9] Fix some issues introduces in map --- .../src/abstractblocksparsearray/map.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 22f9350605..6a7cb96ac1 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -32,8 +32,14 @@ function SparseArrayInterface.sparse_map!( for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) BI_dest = blockindexrange(a_dest, I) BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) - block_dest = @view a_dest[_block(BI_dest)] - block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) + # TODO: Investigate why this doesn't work: + # block_dest = @view a_dest[_block(BI_dest)] + block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] + # TODO: Investigate why this doesn't work: + # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) + block_srcs = ntuple(length(a_srcs)) do i + return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] + end subblock_dest = @view block_dest[BI_dest.indices...] subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) # TODO: Use `map!!` to handle immutable blocks. From 1ebd30ff1c8407c6b814094f535619c3e682afc6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 11:46:54 -0400 Subject: [PATCH 4/9] Update tests --- .../lib/BlockSparseArrays/test/test_basics.jl | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 565d349c33..193a5fe497 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -225,36 +225,41 @@ include("TestBlockSparseArraysUtils.jl") @test block_nstored(c) == 2 @test Array(c) == 2 * transpose(Array(a)) - ## Broken, need to fix. - a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) - @test_broken a[Block(1), Block(1):Block(2)] + b = a[Block(1), Block(1):Block(2)] + @test size(b) == (2, 7) + @test blocksize(b) == (1, 2) + @test b[Block(1, 1)] == a[Block(1, 1)] + @test b[Block(1, 2)] == a[Block(1, 2)] - # This is outputting only zero blocks. a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) - b = a[Block(2):Block(2), Block(1):Block(2)] - @test_broken block_nstored(b) == 1 - @test_broken b == Array(a)[3:5, 1:end] + b = copy(a) + x = randn(elt, size(@view(a[Block(2, 2)]))) + b[Block(2), Block(2)] = x + @test b[Block(2, 2)] == x a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) b = copy(a) - x = randn(size(@view(a[Block(2, 2)]))) - b[Block(2), Block(2)] = x - @test_broken b[Block(2, 2)] == x + b[Block(1, 1)] .= 1 + # TODO: Use `blocksizes(b)[1, 1]` once we upgrade to + # BlockArrays.jl v1. + @test b[Block(1, 1)] == trues(size(@view(b[Block(1, 1)]))) - # Doesnt' set the block + ## Broken, need to fix. + + # This is outputting only zero blocks. a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) - b = copy(a) - b[Block(1, 1)] .= 1 - @test_broken b[1, 1] == trues(size(@view(b[1, 1]))) + b = a[Block(2):Block(2), Block(1):Block(2)] + @test_broken block_nstored(b) == 1 + @test_broken b == Array(a)[3:5, 1:end] end @testset "LinearAlgebra" begin a1 = BlockSparseArray{elt}([2, 3], [2, 3]) From abc13b15f0e84eecba34477e2ebd9cbd34a0125d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 12:17:51 -0400 Subject: [PATCH 5/9] Fix tests --- .../BlockSparseArraysGradedAxesExt/test/runtests.jl | 4 ++-- .../BlockArraysExtensions/BlockArraysExtensions.jl | 11 +++++++++++ NDTensors/src/lib/GradedAxes/src/unitrangedual.jl | 6 ++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index e66c45aaef..d496ee57c6 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -80,8 +80,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "dual axes" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(dual(r), r) - a[Block(1, 1)] = randn(size(a[Block(1, 1)])) - a[Block(2, 2)] = randn(size(a[Block(2, 2)])) + a[Block(1, 1)] = randn(elt, size(a[Block(1, 1)])) + a[Block(2, 2)] = randn(elt, size(a[Block(2, 2)])) a_dense = Array(a) @test eachindex(a) == CartesianIndices(size(a)) for I in eachindex(a) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 0bb67b326e..56147069a5 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -6,6 +6,7 @@ using BlockArrays: BlockRange, BlockedUnitRange, BlockVector, + BlockSlice, block, blockaxes, blockedrange, @@ -29,6 +30,16 @@ function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange) return only(axes(blockedunitrange_getindices(a, indices))) end +function sub_axis(a::AbstractUnitRange, indices::BlockSlice) + return sub_axis(a, block(indices)) +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::Block) + return sub_axis(a, [indices]) +end + # TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block}) diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 52d8042736..3ce65a4d15 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -54,6 +54,12 @@ function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange) return BlockSlice(b, unlabel(a)) end +using BlockArrays: BlockArrays, BlockSlice +using NDTensors.GradedAxes: UnitRangeDual, dual +function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual) + return BlockSlice(b, dual(r)) +end + using NDTensors.LabelledNumbers: LabelledNumbers, label LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a))) From f1526997d36b79f2ef358ab0c866cf754e18ea57 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 12:37:30 -0400 Subject: [PATCH 6/9] Better error message for setting block with incorrect size, add test --- .../BlockArraysExtensions/BlockArraysExtensions.jl | 12 ++++++++++-- .../blocksparsearrayinterface.jl | 6 +++++- .../src/lib/BlockSparseArrays/test/test_basics.jl | 3 ++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 56147069a5..63db489b84 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -30,8 +30,16 @@ function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange) return only(axes(blockedunitrange_getindices(a, indices))) end -function sub_axis(a::AbstractUnitRange, indices::BlockSlice) - return sub_axis(a, block(indices)) +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockRange{1}}) + return sub_axis(a, indices.block) +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}}) + return sub_axis(a, Block(indices)) end # TODO: Use `GradedAxes.blockedunitrange_getindices`. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 8e539946cc..4b36a5e887 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -86,7 +86,11 @@ function blocksparse_setindex!( @boundscheck blockcheckbounds(a, i...) # TODO: Use `blocksizes(a)[i...]` when we upgrade to # BlockArrays.jl v1. - @assert size(value) == size(view(a, I...)) + if size(value) ≠ size(view(a, I...)) + return throw( + DimensionMismatch("Trying to set a block with an array of the wrong size.") + ) + end blocks(a)[i...] = value return a end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 193a5fe497..52b7a192e8 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -4,7 +4,7 @@ using LinearAlgebra: mul! using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored, block_reshape using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: contract -using Test: @test, @testset, @test_broken +using Test: @test, @test_broken, @test_throws, @testset include("TestBlockSparseArraysUtils.jl") @testset "BlockSparseArrays (eltype=$elt)" for elt in (Float32, Float64, ComplexF32, ComplexF64) @@ -20,6 +20,7 @@ include("TestBlockSparseArraysUtils.jl") @test block_nstored(a) == 0 @test iszero(a) @test all(I -> iszero(a[I]), eachindex(a)) + @test_throws DimensionMismatch a[Block(1, 1)] = randn(elt, 2, 3) a = BlockSparseArray{elt}([2, 3], [2, 3]) a[3, 3] = 33 From 8c677d6dc20158014c261a64f25e682a87984f02 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 12:43:43 -0400 Subject: [PATCH 7/9] [NDTensors] Bump to v0.3.19 --- NDTensors/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 57392bc89e..fc6b36d9f1 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.18" +version = "0.3.19" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" From 64e5597099798817d74a238a0e60731047f428cf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 13:44:32 -0400 Subject: [PATCH 8/9] Add and fix some more tests of slicing --- .../BlockArraysExtensions.jl | 18 +++++++++++++++++- .../abstractblocksparsearray/arraylayouts.jl | 11 +++++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 18 ++++++++++++++++++ .../src/lib/GradedAxes/src/gradedunitrange.jl | 5 +++++ 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 63db489b84..d0710d6209 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -42,10 +42,22 @@ function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}}) return sub_axis(a, Block(indices)) end +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockIndexRange{1}}) + return sub_axis(a, indices.block) +end + # TODO: Use `GradedAxes.blockedunitrange_getindices`. # Outputs a `BlockUnitRange`. function sub_axis(a::AbstractUnitRange, indices::Block) - return sub_axis(a, [indices]) + return only(axes(blockedunitrange_getindices(a, indices))) +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::BlockIndexRange) + return only(axes(blockedunitrange_getindices(a, indices))) end # TODO: Use `GradedAxes.blockedunitrange_getindices`. @@ -154,6 +166,10 @@ function blockrange(axis::AbstractUnitRange, r::Block{1}) return r:r end +function blockrange(axis::AbstractUnitRange, r::BlockIndexRange) + return Block(r):Block(r) +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl index c9cbf33cdc..dfe4052aee 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl @@ -16,7 +16,18 @@ end # Materialize a SubArray view. function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes) + # TODO: Make more generic for GPU. a_dest = BlockSparseArray{eltype(a)}(axes) a_dest .= a return a_dest end + +# Materialize a SubArray view. +function ArrayLayouts.sub_materialize( + layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}} +) + # TODO: Make more generic for GPU. + a_dest = Array{eltype(a)}(undef, length.(axes)) + a_dest .= a + return a_dest +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 52b7a192e8..d8d7191fbd 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -252,6 +252,24 @@ include("TestBlockSparseArraysUtils.jl") # BlockArrays.jl v1. @test b[Block(1, 1)] == trues(size(@view(b[Block(1, 1)]))) + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) + x = randn(elt, 1, 2) + @view(a[Block(2, 2)])[1:1, 1:2] = x + @test @view(a[Block(2, 2)])[1:1, 1:2] == x + @test a[Block(2, 2)][1:1, 1:2] == x + + # TODO: This is broken, fix! + @test_broken a[3:3, 4:5] == x + + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) + x = randn(elt, 1, 2) + @views a[Block(2, 2)][1:1, 1:2] = x + @test @view(a[Block(2, 2)])[1:1, 1:2] == x + @test a[Block(2, 2)][1:1, 1:2] == x + + # TODO: This is broken, fix! + @test_broken a[3:3, 4:5] == x + ## Broken, need to fix. # This is outputting only zero blocks. diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 8a19b6e04f..6e414cbc8a 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -221,6 +221,11 @@ function blockedunitrange_getindices( return mortar(map(index -> a[index], indices)) end +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices(a::BlockedUnitRange, indices::Block{1}) + return a[indices] +end + # TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices) return error("Not implemented.") From e948603946ba13cfeb5b43a1a7b9377f2c525ce9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 3 Jun 2024 14:50:00 -0400 Subject: [PATCH 9/9] Fix tests --- .../src/lib/GradedAxes/src/unitrangedual.jl | 5 ++++ .../LabelledNumbers/src/labelledinteger.jl | 29 +++++++++++++++++++ .../src/lib/LabelledNumbers/test/runtests.jl | 24 +++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 3ce65a4d15..8f1a86f1b8 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -38,6 +38,11 @@ function unitrangedual_getindices_blocks(a, indices) return mortar([dual(b) for b in blocks(a_indices)]) end +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1}) + return a[indices] +end + function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl index f5e2d58f3d..323d252b0c 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl @@ -86,3 +86,32 @@ Base.:-(x::LabelledInteger) = labelled_minus(x) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. Base.hash(x::LabelledInteger, h::UInt64) = labelled_hash(x, h) + +using Random: AbstractRNG, default_rng +default_eltype() = Float64 +for f in [:rand, :randn] + @eval begin + function Base.$f( + rng::AbstractRNG, + elt::Type{<:Number}, + dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}, + ) + return a = $f(rng, elt, unlabel.(dims)) + end + function Base.$f( + rng::AbstractRNG, + elt::Type{<:Number}, + dim1::LabelledInteger, + dims::Vararg{LabelledInteger}, + ) + return $f(rng, elt, (dim1, dims...)) + end + Base.$f(elt::Type{<:Number}, dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) = + $f(default_rng(), elt, dims) + Base.$f(elt::Type{<:Number}, dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = + $f(elt, (dim1, dims...)) + Base.$f(dims::Tuple{LabelledInteger,Vararg{LabelledInteger}}) = + $f(default_eltype(), dims) + Base.$f(dim1::LabelledInteger, dims::Vararg{LabelledInteger}) = $f((dim1, dims...)) + end +end diff --git a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl index cf3f87e86d..6fc1ac4231 100644 --- a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl +++ b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl @@ -1,4 +1,5 @@ @eval module $(gensym()) +using LinearAlgebra: norm using NDTensors.LabelledNumbers: LabelledInteger, islabelled, label, labelled, unlabel using Test: @test, @testset @testset "LabelledNumbers" begin @@ -48,6 +49,29 @@ using Test: @test, @testset @test one(typeof(x)) == true @test !islabelled(one(typeof(x))) end + @testset "randn" begin + d = labelled(2, "x") + + a = randn(Float32, d, d) + @test eltype(a) === Float32 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = rand(Float32, d, d) + @test eltype(a) === Float32 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = randn(d, d) + @test eltype(a) === Float64 + @test size(a) == (2, 2) + @test norm(a) > 0 + + a = rand(d, d) + @test eltype(a) === Float64 + @test size(a) == (2, 2) + @test norm(a) > 0 + end @testset "Labelled array ($a)" for a in (collect(2:5), 2:5) x = labelled(a, "x") @test eltype(x) == LabelledInteger{Int,String}