diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 0ca1c48ee..854115fe0 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -328,7 +328,7 @@ async def shuffle_task( ignore_index: bool, num_rounds: int, batchsize: int, -) -> List[DataFrame]: +) -> Dict[int, DataFrame]: """Explicit-comms shuffle task This function is running on each worker participating in the shuffle. @@ -360,8 +360,8 @@ async def shuffle_task( Returns ------- - partitions: list of DataFrames - List of dataframe-partitions + partitions: dict + dict that maps each Partition ID to a dataframe-partition """ proxify = get_proxify(s["worker"]) @@ -387,14 +387,13 @@ async def shuffle_task( ) # Finally, we concatenate the output dataframes into the final output partitions - ret = [] + ret = {} while out_part_id_to_dataframe_list: - ret.append( - proxify( - dd_concat( - out_part_id_to_dataframe_list.popitem()[1], - ignore_index=ignore_index, - ) + part_id, dataframe_list = out_part_id_to_dataframe_list.popitem() + ret[part_id] = proxify( + dd_concat( + dataframe_list, + ignore_index=ignore_index, ) ) # For robustness, we yield this task to give Dask a chance to do bookkeeping @@ -529,9 +528,12 @@ def shuffle( dsk = {} for rank in ranks: - for i, part_id in enumerate(rank_to_out_part_ids[rank]): + for part_id in rank_to_out_part_ids[rank]: dsk[(name, part_id)] = c.client.submit( - getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]] + getitem, + shuffle_result[rank], + part_id, + workers=[c.worker_addresses[rank]], ) # Create a distributed Dataframe from all the pieces diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 1a15370b5..ae4e3332c 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -93,7 +93,7 @@ def check_partitions(df, npartitions): return True -def _test_dataframe_shuffle(backend, protocol, n_workers): +def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): if backend == "cudf": cudf = pytest.importorskip("cudf") @@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): if backend == "cudf": df = cudf.DataFrame.from_pandas(df) + if _partitions: + df["_partitions"] = 0 + for input_nparts in range(1, 5): for output_nparts in range(1, 5): ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist( @@ -123,33 +126,45 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): with dask.config.set(explicit_comms_batchsize=batchsize): ddf = explicit_comms_shuffle( ddf, - ["key"], + ["_partitions"] if _partitions else ["key"], npartitions=output_nparts, batchsize=batchsize, ).persist() assert ddf.npartitions == output_nparts - # Check that each partition hashes to the same value - result = ddf.map_partitions( - check_partitions, output_nparts - ).compute() - assert all(result.to_list()) - - # Check the values (ignoring the row order) - expected = df.sort_values("key") - got = ddf.compute().sort_values("key") - assert_eq(got, expected) + if _partitions: + # If "_partitions" is the hash key, we expect all but + # the first partition to be empty + assert_eq(ddf.partitions[0].compute(), df) + assert all( + len(ddf.partitions[i].compute()) == 0 + for i in range(1, ddf.npartitions) + ) + else: + # Check that each partition hashes to the same value + result = ddf.map_partitions( + check_partitions, output_nparts + ).compute() + assert all(result.to_list()) + + # Check the values (ignoring the row order) + expected = df.sort_values("key") + got = ddf.compute().sort_values("key") + assert_eq(got, expected) @pytest.mark.parametrize("nworkers", [1, 2, 3]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) -def test_dataframe_shuffle(backend, protocol, nworkers): +@pytest.mark.parametrize("_partitions", [True, False]) +def test_dataframe_shuffle(backend, protocol, nworkers, _partitions): if backend == "cudf": pytest.importorskip("cudf") - p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers)) + p = mp.Process( + target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions) + ) p.start() p.join() assert not p.exitcode