Skip to content

Commit

Permalink
Add partition-wise Select support to cuDF-Polars (#17495)
Browse files Browse the repository at this point in the history
Adds multi-partition (partition-wise) `Select` support following the same design as #17441

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17495
  • Loading branch information
rjzamora authored Dec 18, 2024
1 parent 1f55d80 commit 0ba1eb9
Show file tree
Hide file tree
Showing 17 changed files with 146 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
self.dtype = dtype
self.name = name
self.options = options
self.is_pointwise = False
self.children = children
if name not in Agg._SUPPORTED:
raise NotImplementedError(
Expand Down
7 changes: 6 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ class ExecutionContext(IntEnum):
class Expr(Node["Expr"]):
"""An abstract expression object."""

__slots__ = ("dtype",)
__slots__ = ("dtype", "is_pointwise")
dtype: plc.DataType
"""Data type of the expression."""
is_pointwise: bool
"""Whether this expression acts pointwise on its inputs."""
# This annotation is needed because of https://github.com/python/mypy/issues/17981
_non_child: ClassVar[tuple[str, ...]] = ("dtype",)
"""Names of non-child data (not Exprs) for reconstruction."""
Expand Down Expand Up @@ -164,6 +166,7 @@ def __init__(self, dtype: plc.DataType, error: str) -> None:
self.dtype = dtype
self.error = error
self.children = ()
self.is_pointwise = True


class NamedExpr:
Expand Down Expand Up @@ -243,6 +246,7 @@ class Col(Expr):
def __init__(self, dtype: plc.DataType, name: str) -> None:
self.dtype = dtype
self.name = name
self.is_pointwise = True
self.children = ()

def do_evaluate(
Expand Down Expand Up @@ -280,6 +284,7 @@ def __init__(
self.dtype = dtype
self.index = index
self.table_ref = table_ref
self.is_pointwise = True
self.children = (column,)

def do_evaluate(
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
op = BinOp._BOOL_KLEENE_MAPPING.get(op, op)
self.op = op
self.children = (left, right)
self.is_pointwise = True
if not plc.binaryop.is_supported_operation(
self.dtype, left.dtype, right.dtype, op
):
Expand Down
8 changes: 8 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = self.name not in (
BooleanFunction.Name.All,
BooleanFunction.Name.Any,
BooleanFunction.Name.IsDuplicated,
BooleanFunction.Name.IsFirstDistinct,
BooleanFunction.Name.IsLastDistinct,
BooleanFunction.Name.IsUnique,
)
if self.name is BooleanFunction.Name.IsIn and not all(
c.dtype == self.children[0].dtype for c in self.children
):
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = True
if self.name not in self._COMPONENT_MAP:
raise NotImplementedError(f"Temporal function {self.name}")

Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
assert value.type == plc.interop.to_arrow(dtype)
self.value = value
self.children = ()
self.is_pointwise = True

def do_evaluate(
self,
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
self.children = ()
self.is_pointwise = True

def get_hashable(self) -> Hashable:
"""Compute a hash of the column."""
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
self.dtype = dtype
self.options = options
self.children = (agg,)
self.is_pointwise = False
raise NotImplementedError("Rolling window not implemented")


Expand All @@ -35,4 +36,5 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> N
self.dtype = dtype
self.options = options
self.children = (agg, *by)
self.is_pointwise = False
raise NotImplementedError("Grouped rolling window not implemented")
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Gather(Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
self.dtype = dtype
self.children = (values, indices)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -71,6 +72,7 @@ class Filter(Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
self.dtype = dtype
self.children = (values, indices)
self.is_pointwise = True

def do_evaluate(
self,
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self.dtype = dtype
self.options = options
self.children = (column,)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
self.dtype = dtype
self.options = options
self.children = (column, *by)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = True
self._validate_input()

def _validate_input(self):
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/ternary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
) -> None:
self.dtype = dtype
self.children = (when, then, otherwise)
self.is_pointwise = True

def do_evaluate(
self,
Expand Down
10 changes: 10 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Cast(Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
self.dtype = dtype
self.children = (value,)
self.is_pointwise = True
if not dtypes.can_cast(value.dtype, self.dtype):
raise NotImplementedError(
f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
Expand Down Expand Up @@ -63,6 +64,7 @@ class Len(Expr):
def __init__(self, dtype: plc.DataType) -> None:
self.dtype = dtype
self.children = ()
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -147,6 +149,14 @@ def __init__(
self.name = name
self.options = options
self.children = children
self.is_pointwise = self.name not in (
"cum_min",
"cum_max",
"cum_prod",
"cum_sum",
"drop_nulls",
"unique",
)

if self.name not in UnaryFunction._supported_fns:
raise NotImplementedError(f"Unary function {name=}")
Expand Down
14 changes: 7 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cudf_polars.typing import U_contra, V_co

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, MutableMapping
from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence

from cudf_polars.typing import GenericTransformer, NodeT

Expand All @@ -23,22 +23,22 @@
]


def traversal(node: NodeT) -> Generator[NodeT, None, None]:
def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
"""
Pre-order traversal of nodes in an expression.
Parameters
----------
node
Root of expression to traverse.
nodes
Roots of expressions to traverse.
Yields
------
Unique nodes in the expression, parent before child, children
Unique nodes in the expressions, parent before child, children
in-order from left to right.
"""
seen = {node}
lifo = [node]
seen = set(nodes)
lifo = list(nodes)

while lifo:
node = lifo.pop()
Expand Down
12 changes: 9 additions & 3 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

import cudf_polars.experimental.io # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Projection, Union
import cudf_polars.experimental.io
import cudf_polars.experimental.select # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def task_graph(
"""
graph = reduce(
operator.or_,
(generate_ir_tasks(node, partition_info) for node in traversal(ir)),
(generate_ir_tasks(node, partition_info) for node in traversal([ir])),
)

key_name = get_key_name(ir)
Expand Down Expand Up @@ -226,6 +227,8 @@ def _lower_ir_pwise(

lower_ir_node.register(Projection, _lower_ir_pwise)
lower_ir_node.register(Cache, _lower_ir_pwise)
lower_ir_node.register(Filter, _lower_ir_pwise)
lower_ir_node.register(HStack, _lower_ir_pwise)


def _generate_ir_tasks_pwise(
Expand All @@ -245,3 +248,6 @@ def _generate_ir_tasks_pwise(

generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Filter, _generate_ir_tasks_pwise)
generate_ir_tasks.register(HStack, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Select, _generate_ir_tasks_pwise)
36 changes: 36 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel Select Logic."""

from __future__ import annotations

from typing import TYPE_CHECKING

from cudf_polars.dsl.ir import Select
from cudf_polars.dsl.traversal import traversal
from cudf_polars.experimental.dispatch import lower_ir_node

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.experimental.parallel import LowerIRTransformer


@lower_ir_node.register(Select)
def _(
ir: Select, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
child, partition_info = rec(ir.children[0])
pi = partition_info[child]
if pi.count > 1 and not all(
expr.is_pointwise for expr in traversal([e.value for e in ir.exprs])
):
# TODO: Handle non-pointwise expressions.
raise NotImplementedError(
f"Selection {ir} does not support multiple partitions."
)
new_node = ir.reconstruct([child])
partition_info[new_node] = pi
return new_node, partition_info
6 changes: 3 additions & 3 deletions python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ def test_traversal_unique():
dt = plc.DataType(plc.TypeId.INT8)

e1 = make_expr(dt, "a", "a")
unique_exprs = list(traversal(e1))
unique_exprs = list(traversal([e1]))

assert len(unique_exprs) == 2
assert set(unique_exprs) == {expr.Col(dt, "a"), e1}
assert unique_exprs == [e1, expr.Col(dt, "a")]

e2 = make_expr(dt, "a", "b")
unique_exprs = list(traversal(e2))
unique_exprs = list(traversal([e2]))

assert len(unique_exprs) == 3
assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e2}
assert unique_exprs == [e2, expr.Col(dt, "a"), expr.Col(dt, "b")]

e3 = make_expr(dt, "b", "a")
unique_exprs = list(traversal(e3))
unique_exprs = list(traversal([e3]))

assert len(unique_exprs) == 3
assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e3}
Expand Down
54 changes: 54 additions & 0 deletions python/cudf_polars/tests/experimental/test_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def engine():
return pl.GPUEngine(
raise_on_fail=True,
executor="dask-experimental",
executor_options={"max_rows_per_partition": 3},
)


@pytest.fixture(scope="module")
def df():
return pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [1, 1, 1, 1, 1, 1, 1],
}
)


def test_select(df, engine):
query = df.select(
pl.col("a") + pl.col("b"), (pl.col("a") * 2 + pl.col("b")).alias("d")
)
assert_gpu_result_equal(query, engine=engine)


def test_select_reduce_raises(df, engine):
query = df.select(
(pl.col("a") + pl.col("b")).max(),
(pl.col("a") * 2 + pl.col("b")).alias("d").mean(),
)
with pytest.raises(
pl.exceptions.ComputeError,
match="NotImplementedError",
):
assert_gpu_result_equal(query, engine=engine)


def test_select_with_cse_no_agg(df, engine):
expr = pl.col("a") + pl.col("a")
query = df.select(expr, (expr * 2).alias("b"), ((expr * 2) + 10).alias("c"))
assert_gpu_result_equal(query, engine=engine)

0 comments on commit 0ba1eb9

Please sign in to comment.