Commit
shuffle_task() now returns a dict mapping partition IDs to dataframes
madsbk committed Sep 26, 2023
1 parent ec80f97 commit 4aa35aa
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -328,7 +328,7 @@ async def shuffle_task(
     ignore_index: bool,
     num_rounds: int,
     batchsize: int,
-) -> List[DataFrame]:
+) -> Dict[int, DataFrame]:
     """Explicit-comms shuffle task

     This function is running on each worker participating in the shuffle.
@@ -360,8 +360,8 @@ async def shuffle_task(

     Returns
     -------
-    partitions: list of DataFrames
-        List of dataframe-partitions
+    partitions: dict
+        dict that maps each Partition ID to a dataframe-partition
     """

     proxify = get_proxify(s["worker"])
@@ -387,14 +387,13 @@
         )

     # Finally, we concatenate the output dataframes into the final output partitions
-    ret = []
+    ret = {}
     while out_part_id_to_dataframe_list:
-        ret.append(
-            proxify(
-                dd_concat(
-                    out_part_id_to_dataframe_list.popitem()[1],
-                    ignore_index=ignore_index,
-                )
+        part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
+        ret[part_id] = proxify(
+            dd_concat(
+                dataframe_list,
+                ignore_index=ignore_index,
             )
         )
     # For robustness, we yield this task to give Dask a chance to do bookkeeping
@@ -529,9 +528,12 @@ def shuffle(

     dsk = {}
     for rank in ranks:
-        for i, part_id in enumerate(rank_to_out_part_ids[rank]):
+        for part_id in rank_to_out_part_ids[rank]:
             dsk[(name, part_id)] = c.client.submit(
-                getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
+                getitem,
+                shuffle_result[rank],
+                part_id,
+                workers=[c.worker_addresses[rank]],
             )

     # Create a distributed Dataframe from all the pieces
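
For orientation, the worker-side change (the @@ -387 hunk above) can be sketched in isolation. This is a simplified illustration, not the module's actual code: plain pandas.concat stands in for the proxify()/dd_concat() helpers used in shuffle.py, and the sample partition IDs are made up. The point is that each concatenated output partition is now stored under its partition ID, so nothing depends on the order in which popitem() drains the input mapping.

# Illustrative sketch only; the real shuffle_task wraps each result in
# proxify() and concatenates with dd_concat().
from typing import Dict, List

import pandas as pd


def concat_output_partitions(
    out_part_id_to_dataframe_list: Dict[int, List[pd.DataFrame]],
) -> Dict[int, pd.DataFrame]:
    ret: Dict[int, pd.DataFrame] = {}
    while out_part_id_to_dataframe_list:
        # popitem() order is arbitrary, but that no longer matters because
        # the result is keyed by partition ID rather than by list position.
        part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
        ret[part_id] = pd.concat(dataframe_list, ignore_index=True)
    return ret


# Hypothetical usage: partition 3 stays addressable by its ID, whatever the order.
pieces = {0: [pd.DataFrame({"x": [1]})], 3: [pd.DataFrame({"x": [2]}), pd.DataFrame({"x": [3]})]}
parts = concat_output_partitions(pieces)
assert list(parts[3]["x"]) == [2, 3]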

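On the scheduling side (the @@ -529 hunk), each output partition is now fetched from a worker's shuffle result by its partition ID instead of by its enumerate() position. A minimal sketch of that lookup, using a made-up in-memory dict in place of the worker-side result that the submitted getitem call actually receives:

from operator import getitem

# Hypothetical stand-in for the dict a single rank might now return.
shuffle_result_for_rank = {2: "partition-2", 5: "partition-5"}

# Before this commit the third argument to getitem was a positional index i;
# now it is the partition ID itself, so no enumerate() bookkeeping is needed.
partition = getitem(shuffle_result_for_rank, 5)
assert partition == "partition-5"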