Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 17, 2024
1 parent 20df12e commit 32bcb22
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,16 @@ def patch_shuffle_expression() -> None:
import, and it changes `TaskShuffle._layer` to execute
an explicit-comms shuffle.
"""
from dask_expr._collection import new_collection
from dask_expr._shuffle import TaskShuffle
import dask_expr

_base_layer = TaskShuffle._layer
_base_layer = dask_expr._shuffle.TaskShuffle._layer

def _layer(self):
def _patched_layer(self):
if _use_explicit_comms():
# Execute an explicit-comms shuffle
if not hasattr(self, "_ec_shuffled"):
on = self.partitioning_index
df = new_collection(self.frame)
df = dask_expr._collection.new_collection(self.frame)
self._ec_shuffled = shuffle(
df,
[on] if isinstance(on, str) else on,
Expand All @@ -613,4 +612,4 @@ def _layer(self):
# Use upstream lowering logic
return _base_layer(self)

TaskShuffle._layer = _layer
dask_expr._shuffle.TaskShuffle._layer = _patched_layer

0 comments on commit 32bcb22

Please sign in to comment.