Skip to content

Commit

Permalink
Add more complicated rewrite test
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Oct 17, 2024
1 parent 00c92b5 commit f68d914
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

from functools import singledispatch

import pylibcudf as plc

import polars as pl
Expand All @@ -16,6 +18,7 @@
reuse_if_unchanged,
traversal,
)
from cudf_polars.typing import ExprTransformer, IRTransformer


def make_expr(dt, n1, n2):
Expand Down Expand Up @@ -155,3 +158,71 @@ def replace_scan(node, rec):
expect = q.collect()

assert_frame_equal(result, expect, check_row_order=False)


def test_rewrite_names_and_ops():
    """Check a combined expression-level and IR-level rewrite of a query.

    The query is translated to IR, every ``Select`` node has its
    expressions rewritten (columns renamed via a mapping, ``ADD`` binary
    operations replaced by ``MUL``), and the evaluated result is compared
    against the equivalent query written out by hand in polars.
    """
    frame = pl.LazyFrame(
        {"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7], "d": [7, 8, 9]}
    )

    difference = pl.col("a") - (pl.col("b") + pl.col("c") * 2)
    query = frame.select(difference, pl.col("d")).sort("d")

    # Hand-written form of the query after the rewrite below: a -> d, c -> d,
    # and addition replaced by multiplication. The output column keeps the
    # name "a" because only the expression values are rewritten, not the
    # names of the NamedExprs in the Select.
    rewritten = pl.col("d") - (pl.col("b") * pl.col("d") * 2)
    expect = frame.select(rewritten.alias("a"), pl.col("d")).sort("d").collect()

    qir = translate_ir(query._ldf.visit())

    # --- expression rewriter -------------------------------------------
    @singledispatch
    def _map_expr(e: expr.Expr, fn: ExprTransformer) -> expr.Expr:
        raise NotImplementedError("Unhandled")

    @_map_expr.register(expr.Col)
    def _rename_col(col: expr.Col, fn: ExprTransformer):
        # Columns listed in the mapping are rebuilt under their new name;
        # all others are returned untouched.
        renames = fn.state["mapping"]
        if col.name not in renames:
            return col
        return type(col)(col.dtype, renames[col.name])

    @_map_expr.register(expr.BinOp)
    def _add_to_mul(binop: expr.BinOp, fn: ExprTransformer):
        is_addition = binop.op == plc.binaryop.BinaryOperator.ADD
        if not is_addition:
            return reuse_if_unchanged(binop, fn)
        # Additions become multiplications over the recursively
        # rewritten operands.
        operands = [fn(child) for child in binop.children]
        return type(binop)(
            binop.dtype, plc.binaryop.BinaryOperator.MUL, *operands
        )

    # Any other expression type is handled by reuse_if_unchanged.
    _map_expr.register(expr.Expr, reuse_if_unchanged)

    # --- IR rewriter ---------------------------------------------------
    @singledispatch
    def _map_ir(node: ir.IR, fn: IRTransformer) -> ir.IR:
        raise NotImplementedError("Unhandled")

    @_map_ir.register(ir.Select)
    def _rewrite_select(node: ir.Select, fn: IRTransformer):
        # Rewrite each selected expression's value (keeping its output
        # name), then recurse into the single child node.
        expr_mapper = fn.state["expr_mapper"]
        new_exprs = [
            expr.NamedExpr(named.name, expr_mapper(named.value))
            for named in node.exprs
        ]
        new_child = fn(node.children[0])
        return type(node)(
            node.schema, new_exprs, node.should_broadcast, new_child
        )

    # Any other IR node type is handled by reuse_if_unchanged.
    _map_ir.register(ir.IR, reuse_if_unchanged)

    expr_rewriter = CachingVisitor(
        _map_expr, state={"mapping": {"a": "d", "c": "d"}}
    )
    rewriter = CachingVisitor(_map_ir, state={"expr_mapper": expr_rewriter})

    new_ir = rewriter(qir)

    got = new_ir.evaluate(cache={}).to_polars()

    assert_frame_equal(expect, got)

0 comments on commit f68d914

Please sign in to comment.