From 4aa35aa5adcd9a2064b39bd214666508e5c66154 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 26 Sep 2023 10:32:57 +0200 Subject: [PATCH] shuffle_task() now returns a dict mapping partition IDs to dataframes --- dask_cuda/explicit_comms/dataframe/shuffle.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 0ca1c48ee..854115fe0 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -328,7 +328,7 @@ async def shuffle_task( ignore_index: bool, num_rounds: int, batchsize: int, -) -> List[DataFrame]: +) -> Dict[int, DataFrame]: """Explicit-comms shuffle task This function is running on each worker participating in the shuffle. @@ -360,8 +360,8 @@ async def shuffle_task( Returns ------- - partitions: list of DataFrames - List of dataframe-partitions + partitions: dict + dict that maps each Partition ID to a dataframe-partition """ proxify = get_proxify(s["worker"]) @@ -387,14 +387,13 @@ async def shuffle_task( ) # Finally, we concatenate the output dataframes into the final output partitions - ret = [] + ret = {} while out_part_id_to_dataframe_list: - ret.append( - proxify( - dd_concat( - out_part_id_to_dataframe_list.popitem()[1], - ignore_index=ignore_index, - ) + part_id, dataframe_list = out_part_id_to_dataframe_list.popitem() + ret[part_id] = proxify( + dd_concat( + dataframe_list, + ignore_index=ignore_index, ) ) # For robustness, we yield this task to give Dask a chance to do bookkeeping @@ -529,9 +528,12 @@ def shuffle( dsk = {} for rank in ranks: - for i, part_id in enumerate(rank_to_out_part_ids[rank]): + for part_id in rank_to_out_part_ids[rank]: dsk[(name, part_id)] = c.client.submit( - getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]] + getitem, + shuffle_result[rank], + part_id, + workers=[c.worker_addresses[rank]], ) # Create a distributed Dataframe from all the pieces