Skip to content

Commit

Permalink
Add tests for mixed-format sparse-sparse and sparse-dense add.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Nov 15, 2024
1 parent d5d0158 commit 0b910b9
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions sparse/mlir_backend/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@
],
)

parametrize_scipy_fmt = pytest.mark.parametrize(
"format",
["csr", "csc", "coo"],
)

def parametrize_scipy_fmt_with_arg(name: str) -> pytest.MarkDecorator:
return pytest.mark.parametrize(
name,
["csr", "csc", "coo"],
)


parametrize_scipy_fmt = parametrize_scipy_fmt_with_arg("format")


def assert_sps_equal(
Expand Down Expand Up @@ -187,26 +192,28 @@ def test_roundtrip_dense(rng, dtype, shape):


@parametrize_dtypes
@parametrize_scipy_fmt
def test_add(rng, dtype, format):
if format == "coo":
@parametrize_scipy_fmt_with_arg("format1")
@parametrize_scipy_fmt_with_arg("format2")
def test_add(rng, dtype, format1, format2):
if format1 == "coo" or format2 == "coo":
pytest.xfail(reason="https://github.com/llvm/llvm-project/issues/116012")

SHAPE = (100, 50)
DENSITY = 0.5
sampler = generate_sampler(dtype, rng)
sps_arr1 = sps.random_array(
SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
SHAPE, density=DENSITY, format=format1, dtype=dtype, random_state=rng, data_sampler=sampler
)
sps_arr2 = sps.random_array(
SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
SHAPE, density=DENSITY, format=format2, dtype=dtype, random_state=rng, data_sampler=sampler
)

sp_arr1 = sparse.asarray(sps_arr1)
sp_arr2 = sparse.asarray(sps_arr2)

expected = sps_arr1 + sps_arr2
actual = sparse.add(sp_arr1, sp_arr2)
actual_sps = sparse.to_scipy(actual.asformat(sp_arr1.format))
actual_sps = sparse.to_scipy(actual.asformat(sparse.asarray(expected).format))

assert_sps_equal(expected, actual_sps, check_canonical=True)

Expand All @@ -228,6 +235,31 @@ def test_add_dense(rng, dtype, shape):
np.testing.assert_array_equal(expected, actual_np)


@parametrize_dtypes
@parametrize_scipy_fmt
def test_add_dense_sparse(rng, dtype, format):
if format == "coo":
pytest.xfail(reason="https://github.com/llvm/llvm-project/issues/116012")
sampler = generate_sampler(dtype, rng)

SHAPE = (100, 50)
DENSITY = 0.5

np_arr1 = sampler(SHAPE)
sps_arr2 = sps.random_array(
SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
)

sp_arr1 = sparse.asarray(np_arr1)
sp_arr2 = sparse.asarray(sps_arr2)

expected = np_arr1 + sps_arr2
actual = sparse.add(sp_arr1, sp_arr2)
actual_np = sparse.to_numpy(actual.asformat(sp_arr1.format))

np.testing.assert_array_equal(expected, actual_np)


@parametrize_dtypes
def test_csf_format(dtype):
format = sparse.levels.get_storage_format(
Expand Down

0 comments on commit 0b910b9

Please sign in to comment.