diff --git a/dask_cuda/__init__.py b/dask_cuda/__init__.py index d9a775ff..d6ac9655 100644 --- a/dask_cuda/__init__.py +++ b/dask_cuda/__init__.py @@ -15,7 +15,6 @@ from ._version import __git_commit__, __version__ from .cuda_worker import CUDAWorker from .explicit_comms.dataframe.shuffle import ( - get_default_shuffle_method, patch_shuffle_expression, ) from .local_cuda_cluster import LocalCUDACluster @@ -24,11 +23,6 @@ # Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True` patch_shuffle_expression() -# We have to replace all modules that imports Dask's `get_default_shuffle_method()` -# TODO: introduce a shuffle-algorithm dispatcher in Dask so we don't need this hack -dask.dataframe.shuffle.get_default_shuffle_method = get_default_shuffle_method -dask.dataframe.multi.get_default_shuffle_method = get_default_shuffle_method -dask.bag.core.get_default_shuffle_method = get_default_shuffle_method # Monkey patching Dask to make use of proxify and unproxify in compatibility mode diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 79e54060..600da07d 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -568,31 +568,20 @@ def _use_explicit_comms() -> bool: return False -def get_default_shuffle_method() -> str: - """Return the default shuffle algorithm used by Dask - - This changes the default shuffle algorithm from "p2p" to "tasks" - when explicit comms is enabled. - """ - ret = dask.config.get("dataframe.shuffle.algorithm", None) - if ret is None and _use_explicit_comms(): - return "tasks" - return dask.utils.get_default_shuffle_method() - - def patch_shuffle_expression() -> None: """Patch Dasks Shuffle expression. Notice, this is monkey patched into Dask at dask_cuda - import, and it changes `TaskShuffle._layer` to execute - an explicit-comms shuffle. + import, and it changes `Shuffle._layer` to lower into + an `ECShuffle` expression when the 'explicit-comms' + config is set to `True`. """ import dask_expr - _base_layer = dask_expr._shuffle.TaskShuffle._layer + class ECShuffle(dask_expr._shuffle.TaskShuffle): + """Explicit-Comms Shuffle Expression.""" - def _patched_layer(self): - if _use_explicit_comms(): + def _layer(self): # Execute an explicit-comms shuffle if not hasattr(self, "_ec_shuffled"): on = self.partitioning_index @@ -608,8 +597,20 @@ def _patched_layer(self): for i in range(self.npartitions_out): graph[(self._name, i)] = graph[(shuffled_name, i)] return graph + + _base_lower = dask_expr._shuffle.Shuffle._lower + + def _patched_lower(self): + if self.method in (None, "tasks") and _use_explicit_comms(): + return ECShuffle( + self.frame, + self.partitioning_index, + self.npartitions_out, + self.ignore_index, + self.options, + self.original_partitioning_index, + ) else: - # Use upstream lowering logic - return _base_layer(self) + return _base_lower(self) - dask_expr._shuffle.TaskShuffle._layer = _patched_layer + dask_expr._shuffle.Shuffle._lower = _patched_lower