Skip to content

Commit

Permalink
Explicit-comms: update monkey patching of Dask (#1135)
Browse files Browse the repository at this point in the history
Closes #1134 by wrapping `rearrange_by_column()` instead of `rearrange_by_column_tasks()`.  

This also has the bonus that we avoid a re-partition when the shuffle changes number of partitions: https://github.com/dask/dask/blob/945f4e8b7646228aff34da07ffaa52f1b73aa1e0/dask/dataframe/shuffle.py#L510

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #1135
  • Loading branch information
madsbk authored Feb 28, 2023
1 parent af0f3ef commit 92190af
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 29 deletions.
14 changes: 8 additions & 6 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import dask
import dask.dataframe.core
import dask.dataframe.shuffle
import dask.dataframe.multi

from ._version import get_versions
from .cuda_worker import CUDAWorker
from .explicit_comms.dataframe.shuffle import get_rearrange_by_column_tasks_wrapper
from .explicit_comms.dataframe.shuffle import (
get_rearrange_by_column_wrapper,
get_default_shuffle_algorithm,
)
from .local_cuda_cluster import LocalCUDACluster
from .proxify_device_objects import proxify_decorator, unproxify_decorator

Expand All @@ -19,12 +23,10 @@


# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
dask.dataframe.shuffle.rearrange_by_column_tasks = (
get_rearrange_by_column_tasks_wrapper(
dask.dataframe.shuffle.rearrange_by_column_tasks
)
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
dask.dataframe.shuffle.rearrange_by_column
)

dask.dataframe.multi.get_default_shuffle_algorithm = get_default_shuffle_algorithm

# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
dask.dataframe.shuffle.shuffle_group = proxify_decorator(
Expand Down
63 changes: 43 additions & 20 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar

import dask
import dask.config
import dask.dataframe
import dask.utils
import distributed.worker
from dask.base import tokenize
from dask.dataframe.core import DataFrame, Series, _concat as dd_concat, new_dd_object
from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
Expand Down Expand Up @@ -467,7 +470,7 @@ def shuffle(

# Step (a):
df = df.persist() # Make sure optimizations are apply on the existing graph
wait(df) # Make sure all keys has been materialized on workers
wait([df]) # Make sure all keys has been materialized on workers
name = (
"explicit-comms-shuffle-"
f"{tokenize(df, column_names, npartitions, ignore_index)}"
Expand Down Expand Up @@ -534,15 +537,28 @@ def shuffle(
# Create a distributed Dataframe from all the pieces
divs = [None] * (len(dsk) + 1)
ret = new_dd_object(dsk, name, df_meta, divs).persist()
wait(ret)
wait([ret])

# Release all temporary dataframes
for fut in [*shuffle_result.values(), *dsk.values()]:
fut.release()
return ret


def get_rearrange_by_column_tasks_wrapper(func):
def _use_explicit_comms() -> bool:
"""Is explicit-comms and available?"""
if dask.config.get("explicit-comms", False):
try:
# Make sure we have an activate client.
distributed.worker.get_client()
except (ImportError, ValueError):
pass
else:
return True
return False


def get_rearrange_by_column_wrapper(func):
"""Returns a function wrapper that dispatch the shuffle to explicit-comms.
Notice, this is monkey patched into Dask at dask_cuda import
Expand All @@ -552,23 +568,30 @@ def get_rearrange_by_column_tasks_wrapper(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if dask.config.get("explicit-comms", False):
try:
import distributed.worker

# Make sure we have an activate client.
distributed.worker.get_client()
except (ImportError, ValueError):
pass
else:
# Convert `*args, **kwargs` to a dict of `keyword -> values`
kw = func_sig.bind(*args, **kwargs)
kw.apply_defaults()
kw = kw.arguments
column = kw["column"]
if isinstance(column, str):
column = [column]
return shuffle(kw["df"], column, kw["npartitions"], kw["ignore_index"])
if _use_explicit_comms():
# Convert `*args, **kwargs` to a dict of `keyword -> values`
kw = func_sig.bind(*args, **kwargs)
kw.apply_defaults()
kw = kw.arguments
# Notice, we only overwrite the default and the "tasks" shuffle
# algorithm. The "disk" and "p2p" algorithm, we don't touch.
if kw["shuffle"] in ("tasks", None):
col = kw["col"]
if isinstance(col, str):
col = [col]
return shuffle(kw["df"], col, kw["npartitions"], kw["ignore_index"])
return func(*args, **kwargs)

return wrapper


def get_default_shuffle_algorithm() -> str:
"""Return the default shuffle algorithm used by Dask
This changes the default shuffle algorithm from "p2p" to "tasks"
when explicit comms is enabled.
"""
ret = dask.config.get("dataframe.shuffle.algorithm", None)
if ret is None and _use_explicit_comms():
return "tasks"
return dask.utils.get_default_shuffle_algorithm()
6 changes: 3 additions & 3 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ def check_shuffle():
name = "explicit-comms-shuffle"
ddf = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
with dask.config.set(explicit_comms=False):
res = ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
res = ddf.shuffle(on="key", npartitions=4)
assert all(name not in str(key) for key in res.dask)
with dask.config.set(explicit_comms=True):
res = ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
res = ddf.shuffle(on="key", npartitions=4)
if in_cluster:
assert any(name in str(key) for key in res.dask)
else: # If not in cluster, we cannot use explicit comms
Expand All @@ -200,7 +200,7 @@ def check_shuffle():
):
dask.config.refresh() # Trigger re-read of the environment variables
with pytest.raises(ValueError, match="explicit-comms-batchsize"):
ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
ddf.shuffle(on="key", npartitions=4)

if in_cluster:
with LocalCluster(
Expand Down

0 comments on commit 92190af

Please sign in to comment.