Skip to content

Commit

Permalink
use futures_of
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Mar 20, 2024
1 parent c11c47c commit dc32127
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
import dask.utils
import distributed.worker
from dask.base import tokenize
from dask.dataframe.core import DataFrame, Series, _concat as dd_concat
from dask.dataframe import DataFrame, Series
from dask.dataframe.core import _concat as dd_concat
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker

from .. import comms
from dask_cuda.utils import _make_collection

from .. import comms

T = TypeVar("T")


Expand Down Expand Up @@ -469,8 +471,9 @@ def shuffle(
npartitions = df.npartitions

# Step (a):
df = df.persist() # Make sure optimizations are apply on the existing graph
df = df.persist() # Make sure optimizations are applied on the existing graph
wait([df])  # Make sure all keys have been materialized on workers
persisted_keys = [f.key for f in c.client.futures_of(df)]
name = (
"explicit-comms-shuffle-"
f"{tokenize(df, column_names, npartitions, ignore_index)}"
Expand All @@ -480,7 +483,7 @@ def shuffle(
# Stage all keys of `df` on the workers and cancel them, which makes it possible
# for the shuffle to free memory as the partitions of `df` are consumed.
# See CommsContext.stage_keys() for a description of staging.
rank_to_inkeys = c.stage_keys(name=name, keys=df.__dask_keys__())
rank_to_inkeys = c.stage_keys(name=name, keys=persisted_keys)
c.client.cancel(df)

# Get batchsize
Expand Down

0 comments on commit dc32127

Please sign in to comment.