Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add horizontal aggregations and inequality joins #39

Merged
merged 7 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion generate_col_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def indent(s: str, by: int) -> str:

for line in file:
new_file_contents += line
if line.startswith(" return LiteralCol"):
if line.startswith("# --- from here the code is generated ---"):
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if isinstance(op, Operator) and not op.generate_expr_method:
Expand Down
18,345 changes: 12,519 additions & 5,826 deletions pixi.lock

Large diffs are not rendered by default.

41 changes: 21 additions & 20 deletions src/pydiverse/transform/_internal/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Order,
)
from pydiverse.transform._internal.tree.types import Bool, Datetime, Dtype, String
from pydiverse.transform._internal.util.warnings import warn_non_standard


class MsSqlImpl(SqlImpl):
Expand Down Expand Up @@ -194,49 +193,51 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr

@impl(ops.equal, String(), String())
def _eq(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
return (sqa.func.LENGTH(x + "a") == sqa.func.LENGTH(y + "a")) & (
x.collate("Latin1_General_bin") == y
)
return x == y

@impl(ops.not_equal, String(), String())
def _ne(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
return (sqa.func.LENGTH(x + "a") != sqa.func.LENGTH(y + "a")) | (
x.collate("Latin1_General_bin") != y
)
return x != y

@impl(ops.less_than, String(), String())
def _lt(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1)
return (x.collate("Latin1_General_bin") < y_) | (
(sqa.func.LENGTH(x + "a") < sqa.func.LENGTH(y + "a"))
& (x.collate("Latin1_General_bin") == y_)
)
return x < y

@impl(ops.less_equal, String(), String())
def _le(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1)
return (x.collate("Latin1_General_bin") < y_) | (
(sqa.func.LENGTH(x + "a") <= sqa.func.LENGTH(y + "a"))
& (x.collate("Latin1_General_bin") == y_)
)
return x <= y

@impl(ops.greater_than, String(), String())
def _gt(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1)
return (x.collate("Latin1_General_bin") > y_) | (
(sqa.func.LENGTH(x + "a") > sqa.func.LENGTH(y + "a"))
& (x.collate("Latin1_General_bin") == y_)
)
return x > y

@impl(ops.greater_equal, String(), String())
def _ge(x, y):
warn_non_standard(
"MSSQL ignores trailing whitespace when comparing strings",
y_ = sqa.func.SUBSTRING(y, 1, sqa.func.LENGTH(x + "a") - 1)
return (x.collate("Latin1_General_bin") > y_) | (
(sqa.func.LENGTH(x + "a") >= sqa.func.LENGTH(y + "a"))
& (x.collate("Latin1_General_bin") == y_)
)
return x >= y

@impl(ops.str_len)
def _str_length(x):
return sqa.func.LENGTH(x + "a", type_=sqa.Integer()) - 1
return sqa.func.LENGTH(x + "a", type_=sqa.BigInteger()) - 1

@impl(ops.str_replace_all)
def _str_replace_all(x, y, z):
Expand Down
104 changes: 77 additions & 27 deletions src/pydiverse/transform/_internal/backend/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Float,
Float32,
Float64,
Int,
Int8,
Int16,
Int32,
Expand Down Expand Up @@ -63,7 +64,7 @@ def export(
schema_overrides: dict,
) -> Any:
lf, _, select, _ = compile_ast(nd)
lf = lf.select(select)
lf = lf.select(*select)
if isinstance(target, Polars):
if not target.lazy:
lf = lf.collect()
Expand Down Expand Up @@ -208,6 +209,12 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
return pl.lit(expr.val, dtype=polars_type(expr.dtype()))

elif isinstance(expr, Cast):
# Strip surrounding whitespace only when a *string* is being cast to a
# numeric (Int or Float) type. The numeric-target test is parenthesized
# explicitly: `and` binds tighter than `or`, so without the parentheses
# every cast to an Int type would take this branch regardless of the
# source dtype and call `.str.strip()` on a non-string expression.
# NOTE: this rewrites `expr.val` in place before it is compiled.
if (
    expr.target_type <= Int() or expr.target_type <= Float()
) and expr.val.dtype() <= String():
    expr.val = expr.val.str.strip()
compiled = compile_col_expr(expr.val, name_in_df).cast(
    polars_type(expr.target_type)
)
Expand All @@ -221,23 +228,12 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
raise AssertionError


def compile_join_cond(
expr: ColExpr, name_in_df: dict[UUID, str]
) -> list[tuple[pl.Expr, pl.Expr]]:
if isinstance(expr, ColFn):
if expr.op == ops.bool_and:
return compile_join_cond(expr.args[0], name_in_df) + compile_join_cond(
expr.args[1], name_in_df
)
if expr.op == ops.equal:
return [
(
compile_col_expr(expr.args[0], name_in_df),
compile_col_expr(expr.args[1], name_in_df),
)
]

raise AssertionError()
def split_join_cond(expr: ColFn) -> list[ColFn]:
    """Flatten a conjunction of join predicates into its individual terms.

    A node whose operator is ``ops.bool_and`` is expanded recursively, with
    the terms of its first argument placed before those of its second. Any
    other ``ColFn`` is returned unchanged as a one-element list.
    """
    assert isinstance(expr, ColFn)

    is_conjunction = expr.op == ops.bool_and
    if not is_conjunction:
        return [expr]

    terms = split_join_cond(expr.args[0])
    terms.extend(split_join_cond(expr.args[1]))
    return terms


def compile_ast(
Expand Down Expand Up @@ -344,23 +340,71 @@ def has_path_to_leaf_without_agg(expr: ColExpr):

elif isinstance(nd, verbs.Join):
right_df, right_name_in_df, right_select, _ = compile_ast(nd.right)
assert not set(right_name_in_df.keys()) & set(name_in_df.keys())
name_in_df.update(
{uid: name + nd.suffix for uid, name in right_name_in_df.items()}
)
left_on, right_on = zip(*compile_join_cond(nd.on, name_in_df), strict=True)

assert len(partition_by) == 0
select += [col_name + nd.suffix for col_name in right_select]

df = df.join(
right_df.rename({name: name + nd.suffix for name in right_df.columns}),
left_on=left_on,
right_on=right_on,
how=nd.how,
validate=nd.validate,
coalesce=False,
predicates = split_join_cond(nd.on)
right_df = right_df.rename(
{name: name + nd.suffix for name in right_df.columns}
)

if all(pred.op == ops.equal for pred in predicates):
left_on = []
right_on = []
for pred in predicates:
left_on.append(pred.args[0])
right_on.append(pred.args[1])

left_is_left = None
for e in pred.args[0].iter_subtree():
if isinstance(e, Col):
left_is_left = e._uuid not in right_name_in_df
assert e._uuid in name_in_df
break
assert left_is_left is not None

if not left_is_left:
left_on[-1], right_on[-1] = right_on[-1], left_on[-1]

df = df.join(
right_df,
left_on=[compile_col_expr(col, name_in_df) for col in left_on],
right_on=[compile_col_expr(col, name_in_df) for col in right_on],
how=nd.how,
validate=nd.validate,
coalesce=False,
)
else:
# Non-equality join path: the predicates are evaluated with `join_where`,
# and for "left"/"full" joins the rows without a match are re-attached
# afterwards by joining back on temporary row-index columns.
if nd.how in ("left", "full"):
    df = df.with_columns(
        __LEFT_INDEX__=pl.int_range(0, pl.len(), dtype=pl.Int64)
    )
# Compare with `==`: `nd.how in ("full")` is a substring test against the
# string "full" (the parentheses do not make a tuple).
if nd.how == "full":
    right_df = right_df.with_columns(
        __RIGHT_INDEX__=pl.int_range(0, pl.len(), dtype=pl.Int64)
    )

joined = df.join_where(
    right_df, *(compile_col_expr(pred, name_in_df) for pred in predicates)
)

if nd.how in ("left", "full"):
    # Bring back every left row, matched or not, then drop the helper index.
    joined = df.join(joined, on="__LEFT_INDEX__", how="left").drop(
        "__LEFT_INDEX__"
    )
if nd.how == "full":
    # Same for the right side, keyed on the right helper index.
    joined = joined.join(right_df, on="__RIGHT_INDEX__", how="right").drop(
        "__RIGHT_INDEX__"
    )

df = joined

select += [col_name + nd.suffix for col_name in right_select]

elif isinstance(nd, PolarsImpl):
df = nd.df
name_in_df = {col._uuid: col.name for col in nd.cols.values()}
Expand Down Expand Up @@ -566,6 +610,8 @@ def _shift(x, n, fill_value=None):

@impl(ops.is_in)
def _is_in(x, *values, _pdt_args):
if len(values) == 0:
return pl.lit(False)
return pl.any_horizontal(
(x == val if not arg.dtype() <= NullType() else x.is_null())
for val, arg in zip(values, _pdt_args[1:], strict=True)
Expand Down Expand Up @@ -678,3 +724,7 @@ def _is_nan(x):
# Implementation registered for ops.is_not_nan in the polars backend:
# delegates to the argument's own `is_not_nan()` method.
# (x is presumably a polars expression — confirm against the caller.)
@impl(ops.is_not_nan)
def _is_not_nan(x):
return x.is_not_nan()

# Implementation registered for ops.coalesce in the polars backend:
# forwards all positional arguments, in order, to `pl.coalesce`.
@impl(ops.coalesce)
def _coalesce(*x):
return pl.coalesce(*x)
4 changes: 2 additions & 2 deletions src/pydiverse/transform/_internal/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ def _dt_millisecond(x):
sqa.extract("milliseconds", x) % _1000, type_=sqa.Integer()
)

@impl(ops.horizontal_max, String(), ...)
@impl(ops.horizontal_max, String(), String(), ...)
def _horizontal_max(*x):
return sqa.func.GREATEST(*(sqa.collate(e, "POSIX") for e in x))

@impl(ops.horizontal_min, String(), ...)
@impl(ops.horizontal_min, String(), String(), ...)
def _least(*x):
return sqa.func.LEAST(*(sqa.collate(e, "POSIX") for e in x))

Expand Down
4 changes: 4 additions & 0 deletions src/pydiverse/transform/_internal/backend/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,3 +939,7 @@ def _is_inf(x, *, _Impl):
# Implementation registered for ops.is_not_inf in the SQL backend:
# returns the inequality comparison of x with the value of `_Impl.inf()`.
# (_Impl is presumably the concrete backend implementation class, passed
# as a keyword-only argument — confirm against the dispatch code.)
@impl(ops.is_not_inf)
def _is_not_inf(x, *, _Impl):
return x != _Impl.inf()

# Implementation registered for ops.coalesce in the SQL backend:
# forwards all positional arguments, in order, to `sqa.func.coalesce`.
@impl(ops.coalesce)
def _coalesce(*x):
return sqa.func.coalesce(*x)
10 changes: 6 additions & 4 deletions src/pydiverse/transform/_internal/ops/ops/horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

from pydiverse.transform._internal.ops.op import Operator
from pydiverse.transform._internal.ops.signature import Signature
from pydiverse.transform._internal.tree.types import COMPARABLE
from pydiverse.transform._internal.tree.types import COMPARABLE, D


class Horizontal(Operator):
def __init__(self, name: str, *signatures: Signature):
super().__init__(
name, *signatures, param_names=["args"], generate_expr_method=False
name, *signatures, param_names=["arg", "args"], generate_expr_method=False
)


horizontal_max = Horizontal(
"max", *(Signature(dtype, ..., return_type=dtype) for dtype in COMPARABLE)
"max", *(Signature(dtype, dtype, ..., return_type=dtype) for dtype in COMPARABLE)
)

horizontal_min = Horizontal(
"min", *(Signature(dtype, ..., return_type=dtype) for dtype in COMPARABLE)
"min", *(Signature(dtype, dtype, ..., return_type=dtype) for dtype in COMPARABLE)
)

coalesce = Horizontal("coalesce", Signature(D, D, ..., return_type=D))
8 changes: 5 additions & 3 deletions src/pydiverse/transform/_internal/ops/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ def insert(
*,
last_type: Dtype | None = None,
) -> None:
if len(sig) == 1 and last_is_vararg:
assert isinstance(last_type, Dtype)
self.children[last_type] = self
sig = []

if len(sig) == 0:
assert self.data is None
self.data = data
if last_is_vararg:
assert isinstance(last_type, Dtype)
self.children[last_type] = self
return

if sig[0] not in self.children:
Expand Down
Loading
Loading