diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py index 624a9bd87ea..2ba483c7b2d 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py @@ -40,6 +40,7 @@ def __init__( self.dtype = dtype self.name = name self.options = options + self.is_pointwise = False self.children = children if name not in Agg._SUPPORTED: raise NotImplementedError( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/base.py b/python/cudf_polars/cudf_polars/dsl/expressions/base.py index 4c7ae007070..8ba3f9f407c 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/base.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/base.py @@ -36,9 +36,11 @@ class ExecutionContext(IntEnum): class Expr(Node["Expr"]): """An abstract expression object.""" - __slots__ = ("dtype",) + __slots__ = ("dtype", "is_pointwise") dtype: plc.DataType """Data type of the expression.""" + is_pointwise: bool + """Whether this expression acts pointwise on its inputs.""" # This annotation is needed because of https://github.com/python/mypy/issues/17981 _non_child: ClassVar[tuple[str, ...]] = ("dtype",) """Names of non-child data (not Exprs) for reconstruction.""" @@ -164,6 +166,7 @@ def __init__(self, dtype: plc.DataType, error: str) -> None: self.dtype = dtype self.error = error self.children = () + self.is_pointwise = True class NamedExpr: @@ -243,6 +246,7 @@ class Col(Expr): def __init__(self, dtype: plc.DataType, name: str) -> None: self.dtype = dtype self.name = name + self.is_pointwise = True self.children = () def do_evaluate( @@ -280,6 +284,7 @@ def __init__( self.dtype = dtype self.index = index self.table_ref = table_ref + self.is_pointwise = True self.children = (column,) def do_evaluate( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py index 245bdbefe88..556847b4738 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py @@ -42,6 +42,7 @@ def __init__( op = BinOp._BOOL_KLEENE_MAPPING.get(op, op) self.op = op self.children = (left, right) + self.is_pointwise = True if not plc.binaryop.is_supported_operation( self.dtype, left.dtype, right.dtype, op ): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index 5aa35ead127..d5ca22dd8d5 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -81,6 +81,14 @@ def __init__( self.options = options self.name = name self.children = children + self.is_pointwise = self.name not in ( + BooleanFunction.Name.All, + BooleanFunction.Name.Any, + BooleanFunction.Name.IsDuplicated, + BooleanFunction.Name.IsFirstDistinct, + BooleanFunction.Name.IsLastDistinct, + BooleanFunction.Name.IsUnique, + ) if self.name is BooleanFunction.Name.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index c2dddfd9940..0c3159c73d6 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -114,6 +114,7 @@ def __init__( self.options = options self.name = name self.children = children + self.is_pointwise = True if self.name not in self._COMPONENT_MAP: raise NotImplementedError(f"Temporal function {self.name}") diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py index 7eba0c110ab..8528e66c69c 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py @@ -38,6 +38,7 @@ def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None: assert value.type == plc.interop.to_arrow(dtype) self.value = value self.children = () + self.is_pointwise = True def do_evaluate( self, @@ -65,6 +66,7 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None: data = value.to_arrow() self.value = data.cast(dtypes.downcast_arrow_lists(data.type)) self.children = () + self.is_pointwise = True def get_hashable(self) -> Hashable: """Compute a hash of the column.""" diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py index 48c37d101f4..d4616d5d00a 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py @@ -24,6 +24,7 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None: self.dtype = dtype self.options = options self.children = (agg,) + self.is_pointwise = False raise NotImplementedError("Rolling window not implemented") @@ -35,4 +36,5 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> N self.dtype = dtype self.options = options self.children = (agg, *by) + self.is_pointwise = False raise NotImplementedError("Grouped rolling window not implemented") diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py index 12326740f74..93ecd026eaf 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py @@ -30,6 +30,7 @@ class Gather(Expr): def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None: self.dtype = dtype self.children = (values, indices) + self.is_pointwise = False def do_evaluate( self, @@ -71,6 +72,7 @@ class Filter(Expr): def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr): self.dtype = dtype self.children = (values, indices) + self.is_pointwise = True def do_evaluate( self, diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py index 99512e2ef52..189f109e1a2 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py @@ -32,6 +32,7 @@ def __init__( self.dtype = dtype self.options = options self.children = (column,) + self.is_pointwise = False def do_evaluate( self, @@ -71,6 +72,7 @@ def __init__( self.dtype = dtype self.options = options self.children = (column, *by) + self.is_pointwise = False def do_evaluate( self, diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 124a6e8d71c..256840c1f3d 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -106,6 +106,7 @@ def __init__( self.options = options self.name = name self.children = children + self.is_pointwise = True self._validate_input() def _validate_input(self): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py index d2b5d6bae29..120ca8edce0 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py @@ -34,6 +34,7 @@ def __init__( ) -> None: self.dtype = dtype self.children = (when, then, otherwise) + self.is_pointwise = True def do_evaluate( self, diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py index 10caaff6811..3336c901e7f 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py @@ -33,6 +33,7 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: self.dtype = dtype self.children = (value,) + self.is_pointwise = True if not dtypes.can_cast(value.dtype, self.dtype): raise NotImplementedError( f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}" @@ -63,6 +64,7 @@ class Len(Expr): def __init__(self, dtype: plc.DataType) -> None: self.dtype = dtype self.children = () + self.is_pointwise = False def do_evaluate( self, @@ -147,6 +149,14 @@ def __init__( self.name = name self.options = options self.children = children + self.is_pointwise = self.name not in ( + "cum_min", + "cum_max", + "cum_prod", + "cum_sum", + "drop_nulls", + "unique", + ) if self.name not in UnaryFunction._supported_fns: raise NotImplementedError(f"Unary function {name=}") diff --git a/python/cudf_polars/cudf_polars/dsl/traversal.py b/python/cudf_polars/cudf_polars/dsl/traversal.py index b3248dae93c..9c45a68812a 100644 --- a/python/cudf_polars/cudf_polars/dsl/traversal.py +++ b/python/cudf_polars/cudf_polars/dsl/traversal.py @@ -10,7 +10,7 @@ from cudf_polars.typing import U_contra, V_co if TYPE_CHECKING: - from collections.abc import Callable, Generator, Mapping, MutableMapping + from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence from cudf_polars.typing import GenericTransformer, NodeT @@ -23,22 +23,22 @@ ] -def traversal(node: NodeT) -> Generator[NodeT, None, None]: +def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]: """ Pre-order traversal of nodes in an expression. Parameters ---------- - node - Root of expression to traverse. + nodes + Roots of expressions to traverse. Yields ------ - Unique nodes in the expression, parent before child, children + Unique nodes in the expressions, parent before child, children in-order from left to right. """ - seen = {node} - lifo = [node] + seen = set(nodes) + lifo = list(nodes) while lifo: node = lifo.pop() diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index e5884f1c574..6843ed9ee2e 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -9,8 +9,9 @@ from functools import reduce from typing import TYPE_CHECKING, Any -import cudf_polars.experimental.io # noqa: F401 -from cudf_polars.dsl.ir import IR, Cache, Projection, Union +import cudf_polars.experimental.io +import cudf_polars.experimental.select # noqa: F401 +from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union from cudf_polars.dsl.traversal import CachingVisitor, traversal from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name from cudf_polars.experimental.dispatch import ( @@ -112,7 +113,7 @@ def task_graph( """ graph = reduce( operator.or_, - (generate_ir_tasks(node, partition_info) for node in traversal(ir)), + (generate_ir_tasks(node, partition_info) for node in traversal([ir])), ) key_name = get_key_name(ir) @@ -226,6 +227,8 @@ def _lower_ir_pwise( lower_ir_node.register(Projection, _lower_ir_pwise) lower_ir_node.register(Cache, _lower_ir_pwise) +lower_ir_node.register(Filter, _lower_ir_pwise) +lower_ir_node.register(HStack, _lower_ir_pwise) def _generate_ir_tasks_pwise( @@ -245,3 +248,6 @@ def _generate_ir_tasks_pwise( generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise) generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise) +generate_ir_tasks.register(Filter, _generate_ir_tasks_pwise) +generate_ir_tasks.register(HStack, _generate_ir_tasks_pwise) +generate_ir_tasks.register(Select, _generate_ir_tasks_pwise) diff --git a/python/cudf_polars/cudf_polars/experimental/select.py b/python/cudf_polars/cudf_polars/experimental/select.py new file mode 100644 index 00000000000..5f79384b569 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/select.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Parallel Select Logic.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from cudf_polars.dsl.ir import Select +from cudf_polars.dsl.traversal import traversal +from cudf_polars.experimental.dispatch import lower_ir_node + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from cudf_polars.dsl.ir import IR + from cudf_polars.experimental.base import PartitionInfo + from cudf_polars.experimental.parallel import LowerIRTransformer + + +@lower_ir_node.register(Select) +def _( + ir: Select, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + child, partition_info = rec(ir.children[0]) + pi = partition_info[child] + if pi.count > 1 and not all( + expr.is_pointwise for expr in traversal([e.value for e in ir.exprs]) + ): + # TODO: Handle non-pointwise expressions. + raise NotImplementedError( + f"Selection {ir} does not support multiple partitions." + ) + new_node = ir.reconstruct([child]) + partition_info[new_node] = pi + return new_node, partition_info diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index 9755994c419..9fcca2e290e 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -32,21 +32,21 @@ def test_traversal_unique(): dt = plc.DataType(plc.TypeId.INT8) e1 = make_expr(dt, "a", "a") - unique_exprs = list(traversal(e1)) + unique_exprs = list(traversal([e1])) assert len(unique_exprs) == 2 assert set(unique_exprs) == {expr.Col(dt, "a"), e1} assert unique_exprs == [e1, expr.Col(dt, "a")] e2 = make_expr(dt, "a", "b") - unique_exprs = list(traversal(e2)) + unique_exprs = list(traversal([e2])) assert len(unique_exprs) == 3 assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e2} assert unique_exprs == [e2, expr.Col(dt, "a"), expr.Col(dt, "b")] e3 = make_expr(dt, "b", "a") - unique_exprs = list(traversal(e3)) + unique_exprs = list(traversal([e3])) assert len(unique_exprs) == 3 assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e3} diff --git a/python/cudf_polars/tests/experimental/test_select.py b/python/cudf_polars/tests/experimental/test_select.py new file mode 100644 index 00000000000..7dfe6ead148 --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_select.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(scope="module") +def engine(): + return pl.GPUEngine( + raise_on_fail=True, + executor="dask-experimental", + executor_options={"max_rows_per_partition": 3}, + ) + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7], + "b": [1, 1, 1, 1, 1, 1, 1], + } + ) + + +def test_select(df, engine): + query = df.select( + pl.col("a") + pl.col("b"), (pl.col("a") * 2 + pl.col("b")).alias("d") + ) + assert_gpu_result_equal(query, engine=engine) + + +def test_select_reduce_raises(df, engine): + query = df.select( + (pl.col("a") + pl.col("b")).max(), + (pl.col("a") * 2 + pl.col("b")).alias("d").mean(), + ) + with pytest.raises( + pl.exceptions.ComputeError, + match="NotImplementedError", + ): + assert_gpu_result_equal(query, engine=engine) + + +def test_select_with_cse_no_agg(df, engine): + expr = pl.col("a") + pl.col("a") + query = df.select(expr, (expr * 2).alias("b"), ((expr * 2) + 10).alias("c")) + assert_gpu_result_equal(query, engine=engine)