Add multi-partition Shuffle operation to cuDF Polars #17744

Merged 23 commits on Jan 29, 2025.

Commits (23)
89392c0  try importing dask_expr from dask.dataframe (rjzamora, Jan 9, 2025)
2a6821d  Merge remote-tracking branch 'upstream/branch-25.02' into dask-expr-m… (rjzamora, Jan 9, 2025)
5743030  Merge branch 'branch-25.02' into dask-expr-migration (rjzamora, Jan 11, 2025)
7d36d3b  Merge remote-tracking branch 'upstream/branch-25.02' into dask-expr-m… (rjzamora, Jan 14, 2025)
88e078d  update the error message (rjzamora, Jan 14, 2025)
1f77ec4  add basic shuffle support (rjzamora, Jan 14, 2025)
8c52fde  major revision (rjzamora, Jan 15, 2025)
0886ab7  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Jan 15, 2025)
f714a51  roll back unrelated changes (rjzamora, Jan 15, 2025)
677ef36  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Jan 22, 2025)
6b0b9f1  address some code review (rjzamora, Jan 22, 2025)
4da24b1  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Jan 22, 2025)
c7b81e3  check the result (rjzamora, Jan 22, 2025)
fd6e39c  fix test (rjzamora, Jan 22, 2025)
ecba98d  Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars… (rjzamora, Jan 23, 2025)
f02c146  simplify Shuffle (only handle hash-based partitioning for now) (rjzamora, Jan 23, 2025)
8604e1b  remove multi-child validation (rjzamora, Jan 23, 2025)
86fad9d  Merge remote-tracking branch 'upstream/branch-25.04' into cudf-polars… (rjzamora, Jan 27, 2025)
9624396  address code review (rjzamora, Jan 27, 2025)
264fcfd  avoid shuffling single partition (rjzamora, Jan 27, 2025)
82f9c78  fix test bug (rjzamora, Jan 27, 2025)
a502f71  turn do_evaluate back into a no-op (rjzamora, Jan 28, 2025)
08b4db5  Merge branch 'branch-25.04' into cudf-polars-multi-shuffle (rjzamora, Jan 28, 2025)
26 changes: 16 additions & 10 deletions python/cudf_polars/cudf_polars/experimental/base.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Multi-partition base classes."""
 
@@ -12,20 +12,26 @@
 from collections.abc import Iterator, Sequence
 
 from cudf_polars.containers import DataFrame
+from cudf_polars.dsl.expr import NamedExpr
 from cudf_polars.dsl.nodebase import Node
 
 
 class PartitionInfo:
-    """
-    Partitioning information.
-
-    This class only tracks the partition count (for now).
-    """
-
-    __slots__ = ("count",)
-
-    def __init__(self, count: int):
+    """Partitioning information."""
+
+    __slots__ = ("count", "partitioned_on")
+    count: int
+    """Partition count."""
+    partitioned_on: tuple[NamedExpr, ...]
+    """Columns the data is hash-partitioned on."""
+
+    def __init__(
+        self,
+        count: int,
+        partitioned_on: tuple[NamedExpr, ...] = (),
+    ):
         self.count = count
+        self.partitioned_on = partitioned_on
 
     def keys(self, node: Node) -> Iterator[tuple[str, int]]:
         """Return the partitioned keys for a given node."""
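
For context, the sketch below (not part of the diff) shows how the extended PartitionInfo is meant to be used. It assumes an installed cudf_polars with this change applied; the commented-out line is hypothetical, since constructing a real NamedExpr requires a schema.

# Illustrative only: exercising the extended PartitionInfo from this diff.
from cudf_polars.experimental.base import PartitionInfo

# Previously only the partition count was tracked:
pi = PartitionInfo(count=4)
assert pi.count == 4
assert pi.partitioned_on == ()  # default: not hash-partitioned on any keys

# After a hash shuffle, lowering records the shuffle keys so that a later
# shuffle on the same keys can be elided (see shuffle.py below):
# pi = PartitionInfo(count=4, partitioned_on=(some_named_expr,))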
5 changes: 3 additions & 2 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Multi-partition Dask execution."""
 
@@ -10,7 +10,8 @@
 from typing import TYPE_CHECKING, Any
 
 import cudf_polars.experimental.io
-import cudf_polars.experimental.select  # noqa: F401
+import cudf_polars.experimental.select
+import cudf_polars.experimental.shuffle  # noqa: F401
 from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
 from cudf_polars.dsl.traversal import CachingVisitor, traversal
 from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
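
The shuffle import above exists purely for its registration side effects. Here is a minimal, self-contained sketch of that pattern, assuming the dispatch functions behave like functools.singledispatch; the names below are stand-ins, not the real cudf-polars API.

# Sketch of side-effect registration, assuming singledispatch-style dispatch.
from functools import singledispatch


@singledispatch
def lower_ir_node(ir, rec):
    # Fallback when no handler is registered for this IR type
    raise NotImplementedError(type(ir))


class Shuffle:  # stand-in for cudf_polars.experimental.shuffle.Shuffle
    ...


@lower_ir_node.register(Shuffle)
def _(ir, rec):
    # Importing the defining module runs this decorator, which is why
    # parallel.py imports experimental.shuffle purely for side effects.
    return ir, {}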
204 changes: 204 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/shuffle.py
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Shuffle Logic."""

from __future__ import annotations

import json
import operator
from typing import TYPE_CHECKING, Any

import pyarrow as pa

import pylibcudf as plc

from cudf_polars.containers import DataFrame
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.base import _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node

if TYPE_CHECKING:
    from collections.abc import Hashable, MutableMapping

    from cudf_polars.dsl.expr import NamedExpr
    from cudf_polars.experimental.dispatch import LowerIRTransformer
    from cudf_polars.experimental.parallel import PartitionInfo
    from cudf_polars.typing import Schema


class Shuffle(IR):
    """
    Shuffle multi-partition data.

    Notes
    -----
    Only hash-based partitioning is supported (for now).
    """

    __slots__ = ("keys", "options")
    _non_child = ("schema", "keys", "options")
    keys: tuple[NamedExpr, ...]
    """Keys to shuffle on."""
    options: dict[str, Any]
    """Shuffling options."""

    def __init__(
        self,
        schema: Schema,
        keys: tuple[NamedExpr, ...],
        options: dict[str, Any],
        df: IR,
    ):
        self.schema = schema
        self.keys = keys
        self.options = options
        self._non_child_args = (schema, keys, options)
        self.children = (df,)

    def get_hashable(self) -> Hashable:
        """Hashable representation of the node."""
        return (
            type(self),
            tuple(self.schema.items()),
            self.keys,
            json.dumps(self.options),
            self.children,
        )

    @classmethod
    def do_evaluate(
        cls,
        schema: Schema,
        keys: tuple[NamedExpr, ...],
        options: dict[str, Any],
        df: DataFrame,
    ):  # pragma: no cover
        """Evaluate and return a dataframe."""
        # Single-partition Shuffle evaluation is a no-op
        return df


def _partition_dataframe(
    df: DataFrame,
    keys: tuple[NamedExpr, ...],
    count: int,
) -> dict[int, DataFrame]:
    """
    Partition an input DataFrame for shuffling.

    Notes
    -----
    This utility only supports hash partitioning (for now).

    Parameters
    ----------
    df
        DataFrame to partition.
    keys
        Shuffle key(s).
    count
        Total number of output partitions.

    Returns
    -------
    A dictionary mapping between int partition indices and
    DataFrame fragments.
    """
    # Hash the specified keys to calculate the output
    # partition for each row
    partition_map = plc.binaryop.binary_operation(
        plc.hashing.murmurhash3_x86_32(
            DataFrame([expr.evaluate(df) for expr in keys]).table
        ),
        plc.interop.from_arrow(pa.scalar(count, type="uint32")),
        plc.binaryop.BinaryOperator.PYMOD,
        plc.types.DataType(plc.types.TypeId.UINT32),
    )

    # Apply partitioning
    t, offsets = plc.partitioning.partition(
        df.table,
        partition_map,
        count,
    )

    # Split and return the partitioned result
    return {
        i: DataFrame.from_table(
            split,
            df.column_names,
        )
        for i, split in enumerate(plc.copying.split(t, offsets[1:-1]))
    }


def _simple_shuffle_graph(
    name_out: str,
    name_in: str,
    keys: tuple[NamedExpr, ...],
    count_in: int,
    count_out: int,
) -> MutableMapping[Any, Any]:
    """Make a simple all-to-all shuffle graph."""
    split_name = f"split-{name_out}"
    inter_name = f"inter-{name_out}"

    graph: MutableMapping[Any, Any] = {}
    for part_out in range(count_out):
        _concat_list = []
        for part_in in range(count_in):
            graph[(split_name, part_in)] = (
                _partition_dataframe,
                (name_in, part_in),
                keys,
                count_out,
            )
            _concat_list.append((inter_name, part_out, part_in))
            graph[_concat_list[-1]] = (
                operator.getitem,
                (split_name, part_in),
                part_out,
            )
        graph[(name_out, part_out)] = (_concat, _concat_list)
    return graph


@lower_ir_node.register(Shuffle)
def _(
    ir: Shuffle, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    # Simple lower_ir_node handling for the default hash-based shuffle.
    # More-complex logic (e.g. joining and sorting) should
    # be handled separately.
    from cudf_polars.experimental.parallel import PartitionInfo

    (child,) = ir.children

    new_child, pi = rec(child)
    if pi[new_child].count == 1 or ir.keys == pi[new_child].partitioned_on:
        # Already shuffled
        return new_child, pi
    new_node = ir.reconstruct([new_child])
    pi[new_node] = PartitionInfo(
        # Default shuffle preserves partition count
        count=pi[new_child].count,
        # Add partitioned_on info
        partitioned_on=ir.keys,
    )
    return new_node, pi


@generate_ir_tasks.register(Shuffle)
def _(
    ir: Shuffle, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    # Use a simple all-to-all shuffle graph.

    # TODO: Optionally use rapidsmp.
    return _simple_shuffle_graph(
        get_key_name(ir),
        get_key_name(ir.children[0]),
        ir.keys,
        partition_info[ir.children[0]].count,
        partition_info[ir].count,
    )
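
To make the all-to-all structure concrete, the following self-contained sketch builds the same graph shape for two input and two output partitions, using toy stand-ins for _partition_dataframe and _concat; it only illustrates the task-graph layout, not real shuffle execution.

# Self-contained sketch of the graph shape for count_in=2, count_out=2.
import operator


def split(df, keys, count):  # toy stand-in: round-robin rows into `count` buckets
    return {i: df[i::count] for i in range(count)}


def concat(frames):  # toy stand-in: flatten a list of fragments
    return [row for frame in frames for row in frame]


graph = {}
name_in, name_out = "df", "shuffled"
for part_out in range(2):
    concat_list = []
    for part_in in range(2):
        # One split task per input partition (re-keyed harmlessly per part_out)
        graph[(f"split-{name_out}", part_in)] = (split, (name_in, part_in), ("x",), 2)
        # One getitem task per (output, input) pair: 2 x 2 = 4 transfers
        concat_list.append((f"inter-{name_out}", part_out, part_in))
        graph[concat_list[-1]] = (
            operator.getitem,
            (f"split-{name_out}", part_in),
            part_out,
        )
    # Each output partition concatenates its fragment from every input
    graph[(name_out, part_out)] = (concat, concat_list)

# 2 splits + 4 getitems + 2 concats = 8 tasks for a 2 -> 2 shuffle
assert len(graph) == 8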
66 changes: 66 additions & 0 deletions python/cudf_polars/tests/experimental/test_shuffle.py
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl
from polars.testing import assert_frame_equal

from cudf_polars import Translator
from cudf_polars.dsl.expr import Col, NamedExpr
from cudf_polars.experimental.parallel import evaluate_dask, lower_ir_graph
from cudf_polars.experimental.shuffle import Shuffle


@pytest.fixture(scope="module")
def engine():
    return pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options={"max_rows_per_partition": 4},
    )


@pytest.fixture(scope="module")
def df():
    return pl.LazyFrame(
        {
            "x": [1, 2, 3, 4, 5, 6, 7],
            "y": [1, 1, 1, 1, 1, 1, 1],
            "z": ["a", "b", "c", "d", "e", "f", "g"],
        }
    )


def test_hash_shuffle(df, engine):
    # Extract translated IR
    qir = Translator(df._ldf.visit(), engine).translate_ir()

    # Add first Shuffle node
    keys = (NamedExpr("x", Col(qir.schema["x"], "x")),)
    options = {}
    qir1 = Shuffle(qir.schema, keys, options, qir)

    # Add second Shuffle node (on the same keys)
    qir2 = Shuffle(qir.schema, keys, options, qir1)

    # Check that sequential shuffles on the same keys
    # are replaced with a single shuffle node
    partition_info = lower_ir_graph(qir2)[1]
    assert len([node for node in partition_info if isinstance(node, Shuffle)]) == 1

    # Add a third Shuffle node (on different keys)
    keys2 = (NamedExpr("z", Col(qir.schema["z"], "z")),)
    qir3 = Shuffle(qir2.schema, keys2, options, qir2)

    # Check that we have an additional shuffle
    # node after shuffling on different keys
    partition_info = lower_ir_graph(qir3)[1]
    assert len([node for node in partition_info if isinstance(node, Shuffle)]) == 2

    # Check that Dask evaluation works
    result = evaluate_dask(qir3).to_polars()
    expect = df.collect(engine="cpu")
    assert_frame_equal(result, expect, check_row_order=False)
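
For completeness, here is a hedged end-to-end sketch of the engine configuration the fixtures above exercise. Whether a given query actually lowers to a Shuffle node depends on lowering rules outside this PR; this only shows how the experimental executor is enabled.

# Hypothetical usage of the dask-experimental executor (mirrors the fixtures).
import polars as pl

engine = pl.GPUEngine(
    raise_on_fail=True,
    executor="dask-experimental",
    executor_options={"max_rows_per_partition": 4},
)

df = pl.LazyFrame({"x": [1, 2, 3, 4, 5, 6, 7], "y": [1] * 7})
result = df.select(pl.col("x") + pl.col("y")).collect(engine=engine)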