diff --git a/dask_cuda/__init__.py b/dask_cuda/__init__.py
index ed8e6ae9e..dc971797f 100644
--- a/dask_cuda/__init__.py
+++ b/dask_cuda/__init__.py
@@ -7,10 +7,14 @@
 import dask
 import dask.dataframe.core
 import dask.dataframe.shuffle
+import dask.dataframe.multi
 
 from ._version import get_versions
 from .cuda_worker import CUDAWorker
-from .explicit_comms.dataframe.shuffle import get_rearrange_by_column_tasks_wrapper
+from .explicit_comms.dataframe.shuffle import (
+    get_rearrange_by_column_wrapper,
+    get_default_shuffle_algorithm,
+)
 from .local_cuda_cluster import LocalCUDACluster
 from .proxify_device_objects import proxify_decorator, unproxify_decorator
 
@@ -19,12 +23,10 @@
 
 
 # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
-dask.dataframe.shuffle.rearrange_by_column_tasks = (
-    get_rearrange_by_column_tasks_wrapper(
-        dask.dataframe.shuffle.rearrange_by_column_tasks
-    )
+dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
+    dask.dataframe.shuffle.rearrange_by_column
 )
-
+dask.dataframe.multi.get_default_shuffle_algorithm = get_default_shuffle_algorithm
 
 # Monkey patching Dask to make use of proxify and unproxify in compatibility mode
 dask.dataframe.shuffle.shuffle_group = proxify_decorator(
diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py
index 4b240d2f1..a444fce0b 100644
--- a/dask_cuda/explicit_comms/dataframe/shuffle.py
+++ b/dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -9,7 +9,10 @@ from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
 
 import dask
+import dask.config
 import dask.dataframe
+import dask.utils
+import distributed.worker
 from dask.base import tokenize
 from dask.dataframe.core import DataFrame, Series, _concat as dd_concat, new_dd_object
 from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
 
@@ -467,7 +470,7 @@ def shuffle(
     # Step (a):
     df = df.persist()  # Make sure optimizations are apply on the existing graph
-    wait(df)  # Make sure all keys has been materialized on workers
+    wait([df])  # Make sure all keys have been materialized on workers
     name = (
         "explicit-comms-shuffle-"
         f"{tokenize(df, column_names, npartitions, ignore_index)}"
     )
@@ -534,7 +537,7 @@
     # Create a distributed Dataframe from all the pieces
     divs = [None] * (len(dsk) + 1)
     ret = new_dd_object(dsk, name, df_meta, divs).persist()
-    wait(ret)
+    wait([ret])
 
     # Release all temporary dataframes
     for fut in [*shuffle_result.values(), *dsk.values()]:
@@ -542,7 +545,20 @@
     return ret
 
 
-def get_rearrange_by_column_tasks_wrapper(func):
+def _use_explicit_comms() -> bool:
+    """Is explicit-comms enabled and a distributed client available?"""
+    if dask.config.get("explicit-comms", False):
+        try:
+            # Make sure we have an active client.
+            distributed.worker.get_client()
+        except (ImportError, ValueError):
+            pass
+        else:
+            return True
+    return False
+
+
+def get_rearrange_by_column_wrapper(func):
     """Returns a function wrapper that dispatch the shuffle to explicit-comms.
 
     Notice, this is monkey patched into Dask at dask_cuda import
@@ -552,23 +568,30 @@
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if dask.config.get("explicit-comms", False):
-            try:
-                import distributed.worker
-
-                # Make sure we have an activate client.
-                distributed.worker.get_client()
-            except (ImportError, ValueError):
-                pass
-            else:
-                # Convert `*args, **kwargs` to a dict of `keyword -> values`
-                kw = func_sig.bind(*args, **kwargs)
-                kw.apply_defaults()
-                kw = kw.arguments
-                column = kw["column"]
-                if isinstance(column, str):
-                    column = [column]
-                return shuffle(kw["df"], column, kw["npartitions"], kw["ignore_index"])
+        if _use_explicit_comms():
+            # Convert `*args, **kwargs` to a dict of `keyword -> values`
+            kw = func_sig.bind(*args, **kwargs)
+            kw.apply_defaults()
+            kw = kw.arguments
+            # Notice, we only overwrite the default and the "tasks" shuffle
+            # algorithm. The "disk" and "p2p" algorithms are left untouched.
+            if kw["shuffle"] in ("tasks", None):
+                col = kw["col"]
+                if isinstance(col, str):
+                    col = [col]
+                return shuffle(kw["df"], col, kw["npartitions"], kw["ignore_index"])
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def get_default_shuffle_algorithm() -> str:
+    """Return the default shuffle algorithm used by Dask.
+
+    This changes the default shuffle algorithm from "p2p" to "tasks"
+    when explicit comms is enabled.
+    """
+    ret = dask.config.get("dataframe.shuffle.algorithm", None)
+    if ret is None and _use_explicit_comms():
+        return "tasks"
+    return dask.utils.get_default_shuffle_algorithm()
diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py
index 413bf5bdd..624815e75 100644
--- a/dask_cuda/tests/test_explicit_comms.py
+++ b/dask_cuda/tests/test_explicit_comms.py
@@ -183,10 +183,10 @@ def check_shuffle():
         name = "explicit-comms-shuffle"
         ddf = dd.from_pandas(pd.DataFrame({"key": np.arange(10)}), npartitions=2)
         with dask.config.set(explicit_comms=False):
-            res = ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
+            res = ddf.shuffle(on="key", npartitions=4)
             assert all(name not in str(key) for key in res.dask)
         with dask.config.set(explicit_comms=True):
-            res = ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
+            res = ddf.shuffle(on="key", npartitions=4)
             if in_cluster:
                 assert any(name in str(key) for key in res.dask)
             else:  # If not in cluster, we cannot use explicit comms
@@ -200,7 +200,7 @@ def check_shuffle():
         ):
             dask.config.refresh()  # Trigger re-read of the environment variables
             with pytest.raises(ValueError, match="explicit-comms-batchsize"):
-                ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
+                ddf.shuffle(on="key", npartitions=4)
 
     if in_cluster:
         with LocalCluster(
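
For reference, a minimal usage sketch of the behavior this diff installs (the cluster setup, column name, and partition counts below are illustrative, not part of the change). With explicit-comms enabled in the Dask config, an active client, and `shuffle=` left unset or set to "tasks", the patched `rearrange_by_column` routes `DataFrame.shuffle` through the explicit-comms implementation; "disk" and "p2p" still take Dask's own code path:

```python
# Hypothetical sketch, not part of the diff: exercising the monkey-patched
# shuffle path that dask_cuda installs at import time.
import dask
import dask.dataframe as dd
import pandas as pd
from distributed import Client

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    # An active client is required: without one, _use_explicit_comms()
    # returns False and Dask's regular shuffle runs instead.
    with LocalCUDACluster(n_workers=2) as cluster, Client(cluster):
        ddf = dd.from_pandas(pd.DataFrame({"key": range(100)}), npartitions=4)

        # shuffle= is left unset, so the wrapper sees None and dispatches
        # to the explicit-comms shuffle.
        with dask.config.set(explicit_comms=True):
            result = ddf.shuffle(on="key", npartitions=8).compute()
```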