Skip to content

Commit

Permalink
avoid get_default_shuffle_method patching by patching Shuffle._lower
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 18, 2024
1 parent 1c0d974 commit a7b20f7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
6 changes: 0 additions & 6 deletions dask_cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ._version import __git_commit__, __version__
from .cuda_worker import CUDAWorker
from .explicit_comms.dataframe.shuffle import (
get_default_shuffle_method,
patch_shuffle_expression,
)
from .local_cuda_cluster import LocalCUDACluster
Expand All @@ -24,11 +23,6 @@

# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
patch_shuffle_expression()
# We have to replace every module that imports Dask's `get_default_shuffle_method()`
# TODO: introduce a shuffle-algorithm dispatcher in Dask so we don't need this hack
dask.dataframe.shuffle.get_default_shuffle_method = get_default_shuffle_method
dask.dataframe.multi.get_default_shuffle_method = get_default_shuffle_method
dask.bag.core.get_default_shuffle_method = get_default_shuffle_method


# Monkey patching Dask to make use of proxify and unproxify in compatibility mode
Expand Down
41 changes: 21 additions & 20 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,31 +568,20 @@ def _use_explicit_comms() -> bool:
return False


def get_default_shuffle_method() -> str:
    """Pick the default shuffle algorithm, preferring "tasks" for explicit comms.

    Returns "tasks" when the "dataframe.shuffle.algorithm" config key is
    unset and explicit comms is enabled. In every other case the choice is
    delegated to ``dask.utils.get_default_shuffle_method()``.
    """
    configured = dask.config.get("dataframe.shuffle.algorithm", None)
    # Defer to Dask whenever the user configured an algorithm explicitly
    # or explicit comms is not in use.
    if configured is not None or not _use_explicit_comms():
        return dask.utils.get_default_shuffle_method()
    return "tasks"


def patch_shuffle_expression() -> None:
"""Patch Dasks Shuffle expression.
Notice, this is monkey patched into Dask at dask_cuda
import, and it changes `TaskShuffle._layer` to execute
an explicit-comms shuffle.
import, and it changes `Shuffle._lower` to lower into
an `ECShuffle` expression when the 'explicit-comms'
config is set to `True`.
"""
import dask_expr

_base_layer = dask_expr._shuffle.TaskShuffle._layer
class ECShuffle(dask_expr._shuffle.TaskShuffle):
"""Explicit-Comms Shuffle Expression."""

def _patched_layer(self):
if _use_explicit_comms():
def _layer(self):
# Execute an explicit-comms shuffle
if not hasattr(self, "_ec_shuffled"):
on = self.partitioning_index
Expand All @@ -608,8 +597,20 @@ def _patched_layer(self):
for i in range(self.npartitions_out):
graph[(self._name, i)] = graph[(shuffled_name, i)]
return graph

_base_lower = dask_expr._shuffle.Shuffle._lower

def _patched_lower(self):
if self.method in (None, "tasks") and _use_explicit_comms():
return ECShuffle(
self.frame,
self.partitioning_index,
self.npartitions_out,
self.ignore_index,
self.options,
self.original_partitioning_index,
)
else:
# Use upstream lowering logic
return _base_layer(self)
return _base_lower(self)

dask_expr._shuffle.TaskShuffle._layer = _patched_layer
dask_expr._shuffle.Shuffle._lower = _patched_lower

0 comments on commit a7b20f7

Please sign in to comment.