Skip to content

Commit

Permalink
introduce ExplicitCommsShuffle wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 17, 2024
1 parent 1af30be commit 3cf956b
Showing 1 changed file with 47 additions and 10 deletions.
57 changes: 47 additions & 10 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,24 +583,61 @@ def get_default_shuffle_method() -> str:
def patch_shuffle_expression() -> None:
"""Patch Dask's Shuffle expression.

This is monkey patched into Dask at dask_cuda import. It changes
``Shuffle._lower`` to wrap the original shuffle expression in
``ExplicitCommsShuffle``, so that explicit-comms shuffling is
applied when the 'explicit-comms' config is enabled.
"""
from dask_expr._collection import new_collection
from dask_expr._expr import Expr
from dask_expr._shuffle import Shuffle as DXShuffle

class ExplicitCommsShuffle(Expr):
    """Expression that runs a wrapped shuffle expression with explicit-comms.

    The single ``wrapped`` parameter is a one-element list holding the
    original shuffle expression. The explicit-comms ``shuffle`` itself is
    only executed when the task graph is requested (see ``_layer``).
    """

    _parameters = ["wrapped"]

    @property
    def original(self):
        """Return the wrapped shuffle expression (sole element of ``wrapped``)."""
        assert len(self.wrapped) == 1, f"Unexpected parameters: {self.wrapped[1:]}"
        return self.wrapped[0]

    @property
    def _npartitions_out(self):
        """Number of partitions this expression reports.

        Uses the wrapped expression's ``npartitions_out`` because that is
        the value handed to ``shuffle`` in ``_layer``. Falls back to the
        input frame's partition count when ``npartitions_out`` is unset.
        """
        npartitions = self.original.npartitions_out
        if npartitions is None:
            # NOTE(review): assumes ``shuffle`` keeps the input partition
            # count when given ``None`` -- confirm against its signature.
            return self.original.frame.npartitions
        return npartitions

    @property
    def _meta(self):
        """Metadata is taken unchanged from the wrapped expression's frame."""
        return self.original.frame._meta

    def _lower(self):
        """Return ``None``: this expression defines no further lowering."""
        return None

    def _divisions(self):
        """Return all-``None`` divisions, one more than the output partitions.

        The count must match the partitions produced in ``_layer``;
        otherwise the alias keys built there would either miss output
        partitions or point at keys that do not exist.
        """
        return (None,) * (self._npartitions_out + 1)

    def _layer(self):
        """Build the task graph for the explicit-comms shuffle.

        The shuffled collection is created once and cached on the
        instance under ``self._name``. The returned graph is a copy of
        that collection's graph plus one alias key per partition mapping
        ``(self._name, i)`` to the shuffled collection's ``i``-th key.
        """
        if not hasattr(self, "_shuffle_cache"):
            self._shuffle_cache = {}
        try:
            expr = self._shuffle_cache[self._name]
        except KeyError:
            on = self.original.partitioning_index
            expr = shuffle(
                new_collection(self.original.frame),
                [on] if isinstance(on, str) else on,
                self.original.npartitions_out,
                self.original.ignore_index,
            )
            self._shuffle_cache[self._name] = expr
        # Copy so the alias keys do not mutate the cached collection's graph
        graph = expr.dask.copy()
        graph.update(
            {(self._name, i): (expr._name, i) for i in range(self.npartitions)}
        )
        return graph

_base_lower = DXShuffle._lower

def _lower(self):
    """Replacement for ``Shuffle._lower``.

    When the shuffle method is ``"tasks"`` or ``None`` and
    ``_use_explicit_comms()`` is true, the shuffle expression is wrapped
    in an ``ExplicitCommsShuffle``. Otherwise the original lowering
    (``_base_lower``) is used.
    """
    if self.method in ("tasks", None) and _use_explicit_comms():
        # Wrap the original Shuffle in an ExplicitCommsShuffle
        # (Use list argument to encapsulate dependencies)
        return ExplicitCommsShuffle([self])
    else:
        # Use upstream lowering logic
        return _base_lower(self)
Expand Down

0 comments on commit 3cf956b

Please sign in to comment.