Skip to content

Commit

Permalink
Explicit-comms: preserve partition IDs (#1240)
Browse files Browse the repository at this point in the history
`shuffle_task()` now returns a dict mapping partition IDs to dataframes`

Fixes #1239

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #1240
  • Loading branch information
madsbk authored Sep 26, 2023
1 parent 8f1840f commit 6bd4ba4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 26 deletions.
26 changes: 14 additions & 12 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def shuffle_task(
ignore_index: bool,
num_rounds: int,
batchsize: int,
) -> List[DataFrame]:
) -> Dict[int, DataFrame]:
"""Explicit-comms shuffle task
This function is running on each worker participating in the shuffle.
Expand Down Expand Up @@ -360,8 +360,8 @@ async def shuffle_task(
Returns
-------
partitions: list of DataFrames
List of dataframe-partitions
partitions: dict
dict that maps each Partition ID to a dataframe-partition
"""

proxify = get_proxify(s["worker"])
Expand All @@ -387,14 +387,13 @@ async def shuffle_task(
)

# Finally, we concatenate the output dataframes into the final output partitions
ret = []
ret = {}
while out_part_id_to_dataframe_list:
ret.append(
proxify(
dd_concat(
out_part_id_to_dataframe_list.popitem()[1],
ignore_index=ignore_index,
)
part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
ret[part_id] = proxify(
dd_concat(
dataframe_list,
ignore_index=ignore_index,
)
)
# For robustness, we yield this task to give Dask a chance to do bookkeeping
Expand Down Expand Up @@ -529,9 +528,12 @@ def shuffle(

dsk = {}
for rank in ranks:
for i, part_id in enumerate(rank_to_out_part_ids[rank]):
for part_id in rank_to_out_part_ids[rank]:
dsk[(name, part_id)] = c.client.submit(
getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
getitem,
shuffle_result[rank],
part_id,
workers=[c.worker_addresses[rank]],
)

# Create a distributed Dataframe from all the pieces
Expand Down
43 changes: 29 additions & 14 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def check_partitions(df, npartitions):
return True


def _test_dataframe_shuffle(backend, protocol, n_workers):
def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
if backend == "cudf":
cudf = pytest.importorskip("cudf")

Expand All @@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
df = cudf.DataFrame.from_pandas(df)

if _partitions:
df["_partitions"] = 0

for input_nparts in range(1, 5):
for output_nparts in range(1, 5):
ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist(
Expand All @@ -123,33 +126,45 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
with dask.config.set(explicit_comms_batchsize=batchsize):
ddf = explicit_comms_shuffle(
ddf,
["key"],
["_partitions"] if _partitions else ["key"],
npartitions=output_nparts,
batchsize=batchsize,
).persist()

assert ddf.npartitions == output_nparts

# Check that each partition hashes to the same value
result = ddf.map_partitions(
check_partitions, output_nparts
).compute()
assert all(result.to_list())

# Check the values (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
assert_eq(got, expected)
if _partitions:
# If "_partitions" is the hash key, we expect all but
# the first partition to be empty
assert_eq(ddf.partitions[0].compute(), df)
assert all(
len(ddf.partitions[i].compute()) == 0
for i in range(1, ddf.npartitions)
)
else:
# Check that each partition hashes to the same value
result = ddf.map_partitions(
check_partitions, output_nparts
).compute()
assert all(result.to_list())

# Check the values (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
def test_dataframe_shuffle(backend, protocol, nworkers):
@pytest.mark.parametrize("_partitions", [True, False])
def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
pytest.importorskip("cudf")

p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers))
p = mp.Process(
target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
)
p.start()
p.join()
assert not p.exitcode
Expand Down

0 comments on commit 6bd4ba4

Please sign in to comment.