Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 13, 2024
1 parent c8979ed commit 634fcc6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
27 changes: 27 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test a1' * a2 Array(a1)' * Array(a2)
@test dot(a1, a2) a1' * a2
end
@testset "cat" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))

a_dest = cat(a1, a2; dims=1)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=2)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=(1, 2))
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
end
@testset "TensorAlgebra" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,38 @@ using Test: @test, @testset
@test a_dest isa SparseArray{elt}
@test SparseArrayInterface.nstored(a_dest) == 2

# cat
a1 = SparseArray{elt}(2, 3)
a1[1, 2] = 12
a1[2, 1] = 21
a2 = SparseArray{elt}(2, 3)
a2[1, 1] = 11
a2[2, 2] = 22

a_dest = cat(a1, a2; dims=1)
@test size(a_dest) == (4, 3)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 1] == a2[1, 1]
@test a_dest[4, 2] == a2[2, 2]

a_dest = cat(a1, a2; dims=2)
@test size(a_dest) == (2, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[1, 4] == a2[1, 1]
@test a_dest[2, 5] == a2[2, 2]

a_dest = cat(a1, a2; dims=(1, 2))
@test size(a_dest) == (4, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 4] == a2[1, 1]
@test a_dest[4, 5] == a2[2, 2]

## # Sparse matrix of matrix multiplication
## TODO: Make this work, seems to require
## a custom zero constructor.
Expand Down

0 comments on commit 634fcc6

Please sign in to comment.