From 634fcc6380aea018eaf31bf647eaf7f156021ceb Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Nov 2024 18:02:54 -0500 Subject: [PATCH] Add tests --- .../lib/BlockSparseArrays/test/test_basics.jl | 27 ++++++++++++++++ .../test/test_abstractsparsearray.jl | 32 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 4374be541c..32990471a0 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype @test a1' * a2 ≈ Array(a1)' * Array(a2) @test dot(a1, a2) ≈ a1' * a2 end + @testset "cat" begin + a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)])))) + a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)])))) + + a_dest = cat(a1, a2; dims=1) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(3, 2)] == a2[Block(1, 2)] + + a_dest = cat(a1, a2; dims=2) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(1, 4)] == a2[Block(1, 2)] + + a_dest = cat(a1, a2; dims=(1, 2)) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(3, 4)] == a2[Block(1, 2)] + end @testset "TensorAlgebra" begin a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)])))) diff --git a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl index 743f457d43..47cf6668c6 100644 --- a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl +++ b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl @@ -342,6 +342,38 @@ using Test: @test, @testset @test a_dest isa SparseArray{elt} @test SparseArrayInterface.nstored(a_dest) == 2 + # cat + a1 = SparseArray{elt}(2, 3) + a1[1, 2] = 12 + a1[2, 1] = 21 + a2 = SparseArray{elt}(2, 3) + a2[1, 1] = 11 + a2[2, 2] = 22 + + a_dest = cat(a1, a2; dims=1) + @test size(a_dest) == (4, 3) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[3, 1] == a2[1, 1] + @test a_dest[4, 2] == a2[2, 2] + + a_dest = cat(a1, a2; dims=2) + @test size(a_dest) == (2, 6) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[1, 4] == a2[1, 1] + @test a_dest[2, 5] == a2[2, 2] + + a_dest = cat(a1, a2; dims=(1, 2)) + @test size(a_dest) == (4, 6) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[3, 4] == a2[1, 1] + @test a_dest[4, 5] == a2[2, 2] + ## # Sparse matrix of matrix multiplication ## TODO: Make this work, seems to require ## a custom zero constructor.