diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 227b67d1..79e54060 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -587,17 +587,16 @@ def patch_shuffle_expression() -> None: import, and it changes `TaskShuffle._layer` to execute an explicit-comms shuffle. """ - from dask_expr._collection import new_collection - from dask_expr._shuffle import TaskShuffle + import dask_expr - _base_layer = TaskShuffle._layer + _base_layer = dask_expr._shuffle.TaskShuffle._layer - def _layer(self): + def _patched_layer(self): if _use_explicit_comms(): # Execute an explicit-comms shuffle if not hasattr(self, "_ec_shuffled"): on = self.partitioning_index - df = new_collection(self.frame) + df = dask_expr._collection.new_collection(self.frame) self._ec_shuffled = shuffle( df, [on] if isinstance(on, str) else on, @@ -613,4 +612,4 @@ def _layer(self): # Use upstream lowering logic return _base_layer(self) - TaskShuffle._layer = _layer + dask_expr._shuffle.TaskShuffle._layer = _patched_layer