Remove recursion in task spec (#1158)
fjetter authored Nov 19, 2024
1 parent ef6f27f commit e406551
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions dask_expr/_merge.py
@@ -680,7 +680,7 @@ def _layer(self) -> dict:
             _barrier_key_left,
             p2p_barrier,
             token_left,
-            transfer_keys_left,
+            *transfer_keys_left,
             spec=DataFrameShuffleSpec(
                 id=token_left,
                 npartitions=self.npartitions,
@@ -698,7 +698,7 @@ def _layer(self) -> dict:
             _barrier_key_right,
             p2p_barrier,
             token_right,
-            transfer_keys_right,
+            *transfer_keys_right,
             spec=DataFrameShuffleSpec(
                 id=token_right,
                 npartitions=self.npartitions,
2 changes: 1 addition & 1 deletion dask_expr/_shuffle.py
@@ -592,7 +592,7 @@ def _layer(self):
             _barrier_key,
             p2p_barrier,
             token,
-            transfer_keys,
+            *transfer_keys,
             spec=DataFrameShuffleSpec(
                 id=shuffle_id,
                 npartitions=self.npartitions_out,
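In both the _merge.py and _shuffle.py hunks above, the change is the same: the barrier task now receives the transfer keys as unpacked positional arguments rather than as one Python list. A plain container argument forces the task spec to recurse into it to find the task references it holds, which is presumably the recursion the commit title refers to; splatting the keys makes each reference a direct, top-level argument. A minimal sketch of the difference, assuming Task(key, func, *args) and TaskRef from dask._task_spec; the combine function and key names are made up for illustration and are not part of this commit:

    from dask._task_spec import Task, TaskRef

    def combine(token, *parts):
        # Hypothetical stand-in for p2p_barrier: just collects its inputs.
        return token, list(parts)

    # References to previously defined transfer tasks (names are illustrative).
    transfer_keys = [TaskRef("transfer-0"), TaskRef("transfer-1")]

    # Old style: the whole Python list is a single argument, so the task spec
    # must recurse into the container to discover the TaskRef dependencies.
    barrier_old = Task("barrier", combine, "token", transfer_keys)

    # New style (this commit): the references are unpacked as separate
    # positional arguments, so no recursive container traversal is needed.
    barrier_new = Task("barrier", combine, "token", *transfer_keys)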
4 changes: 2 additions & 2 deletions dask_expr/io/io.py
@@ -6,7 +6,7 @@

 import numpy as np
 import pyarrow as pa
-from dask._task_spec import Task
+from dask._task_spec import List, Task
 from dask.dataframe import methods
 from dask.dataframe._pyarrow import to_pyarrow_string
 from dask.dataframe.core import apply_and_enforce, is_dataframe_like, make_meta
@@ -135,7 +135,7 @@ def _task(self, name: Key, index: int) -> Task:
         bucket = self._fusion_buckets[index]
         # FIXME: This will likely require a wrapper
         return Task(
-            name, methods.concat, [expr._filtered_task(name, i) for i in bucket]
+            name, methods.concat, List(*(expr._filtered_task(name, i) for i in bucket))
        )

    @functools.cached_property
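The io.py change applies the same idea where a list really is the intended argument: instead of handing methods.concat the raw list comprehension, which Task would otherwise have to convert to its internal container form recursively, the subtasks are wrapped in the task spec's explicit List container up front. A rough sketch of the pattern, assuming List and Task from dask._task_spec as used in the hunk above; the inner tasks and names are illustrative stand-ins, not the real dask_expr code:

    from dask._task_spec import List, Task

    def concat_parts(parts):
        # Stand-in for methods.concat: receives the materialized list of parts.
        return parts

    # Illustrative inner tasks; in the real code these come from
    # expr._filtered_task(name, i) for each i in the fusion bucket.
    inner = [Task(f"part-{i}", lambda i=i: i) for i in range(3)]

    # Old style: a raw Python list argument, converted recursively by the spec.
    fused_old = Task("fused", concat_parts, inner)

    # New style (this commit): an explicit List container built up front.
    fused_new = Task("fused", concat_parts, List(*inner))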
2 changes: 1 addition & 1 deletion dask_expr/io/tests/test_distributed.py
@@ -63,4 +63,4 @@ def test_pickle_size(tmpdir, filesystem):
     df = read_parquet(tmpdir, filesystem=filesystem)
     from distributed.protocol import dumps

-    assert len(b"".join(dumps(df.optimize().dask))) <= 9000
+    assert len(b"".join(dumps(df.optimize().dask))) <= 9100
