From c6950a900e9068abaae3528b227bdf512bd97f1d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 26 Sep 2023 12:51:25 +0200 Subject: [PATCH] test _partitions --- dask_cuda/tests/test_explicit_comms.py | 42 ++++++++++++++++++-------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 1a15370b5..49bd7058a 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -93,7 +93,7 @@ def check_partitions(df, npartitions): return True -def _test_dataframe_shuffle(backend, protocol, n_workers): +def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): if backend == "cudf": cudf = pytest.importorskip("cudf") @@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): if backend == "cudf": df = cudf.DataFrame.from_pandas(df) + if _partitions: + df["_partitions"] = 0 + for input_nparts in range(1, 5): for output_nparts in range(1, 5): ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist( @@ -123,33 +126,46 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): with dask.config.set(explicit_comms_batchsize=batchsize): ddf = explicit_comms_shuffle( ddf, - ["key"], + ["_partitions"] if _partitions else ["key"], npartitions=output_nparts, batchsize=batchsize, ).persist() assert ddf.npartitions == output_nparts - # Check that each partition hashes to the same value - result = ddf.map_partitions( - check_partitions, output_nparts - ).compute() - assert all(result.to_list()) + if _partitions: + # If "_partitions" is the hash key, we expect all but + # the first partition to be empty + assert_eq(ddf.partitions[0].compute(), df) + assert all( + len(ddf.partitions[i].compute()) == 0 + for i in range(1, ddf.npartitions) + ) + else: + + # Check that each partition hashes to the same value + result = ddf.map_partitions( + check_partitions, output_nparts + ).compute() + assert 
all(result.to_list()) - # Check the values (ignoring the row order) - expected = df.sort_values("key") - got = ddf.compute().sort_values("key") - assert_eq(got, expected) + # Check the values (ignoring the row order) + expected = df.sort_values("key") + got = ddf.compute().sort_values("key") + assert_eq(got, expected) @pytest.mark.parametrize("nworkers", [1, 2, 3]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) -def test_dataframe_shuffle(backend, protocol, nworkers): +@pytest.mark.parametrize("_partitions", [True, False]) +def test_dataframe_shuffle(backend, protocol, nworkers, _partitions): if backend == "cudf": pytest.importorskip("cudf") - p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers)) + p = mp.Process( + target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions) + ) p.start() p.join() assert not p.exitcode