Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Nov 27, 2024
2 parents d515633 + 07cbc49 commit f810f1f
Show file tree
Hide file tree
Showing 54 changed files with 9,872 additions and 9,641 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ jobs:
name: artifact
path: dist
- name: Publish package on TestPyPi
uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0
uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
with:
repository-url: https://test.pypi.org/legacy/
- name: Publish package on PyPi
uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0
uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
environments: ${{ matrix.environment }}

- name: Start Docker Compose
uses: isbang/compose-action@e5813a5909aca4ae36058edae58f6e52b9c971f8
uses: isbang/compose-action@f1ca7fefe3627c2dab0ae1db43a106d82740245e
with:
compose-file: docker-compose.yaml

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/update-lockfiles.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
set -euo pipefail
pixi update --json | pixi exec pixi-diff-to-markdown >> diff.md
- name: Create pull request
uses: peter-evans/create-pull-request@v6
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update pixi lockfile
Expand Down
144 changes: 71 additions & 73 deletions generate_col_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
from collections.abc import Iterable
from types import NoneType

from pydiverse.transform._internal.backend.polars import PolarsImpl
from pydiverse.transform._internal.ops.core import NoExprMethod, Operator
from pydiverse.transform._internal.tree.dtypes import (
from pydiverse.transform._internal.ops import ops
from pydiverse.transform._internal.ops.op import Operator
from pydiverse.transform._internal.ops.signature import Signature
from pydiverse.transform._internal.tree.types import (
Dtype,
Template,
Tvar,
pdt_type_to_python,
)
from pydiverse.transform._internal.tree.registry import Signature

col_expr_path = "./src/pydiverse/transform/_internal/tree/col_expr.py"
fns_path = "./src/pydiverse/transform/_internal/pipe/functions.py"
reg = PolarsImpl.registry
namespaces = ["str", "dt"]
rversions = {
COL_EXPR_PATH = "./src/pydiverse/transform/_internal/tree/col_expr.py"
FNS_PATH = "./src/pydiverse/transform/_internal/pipe/functions.py"

NAMESPACES = ["str", "dt"]

RVERSIONS = {
"__add__",
"__sub__",
"__mul__",
Expand All @@ -31,19 +32,18 @@
}


def format_param(name: str, dtype: Dtype) -> str:
if dtype.vararg:
return f"*{name}"
return name
def add_vararg_star(formatted_args: str) -> str:
last_arg = "*" + formatted_args.split(", ")[-1]
return ", ".join(formatted_args.split(", ")[:-1] + [last_arg])


def type_annotation(param: Dtype, specialize_generic: bool) -> str:
if not specialize_generic or isinstance(param, Template):
def type_annotation(dtype: Dtype, specialize_generic: bool) -> str:
if (not specialize_generic and not dtype.const) or isinstance(dtype, Tvar):
return "ColExpr"
if param.const:
python_type = pdt_type_to_python(param)
if dtype.const:
python_type = pdt_type_to_python(dtype)
return python_type.__name__ if python_type is not NoneType else "None"
return f"ColExpr[{param.__class__.__name__}]"
return f"ColExpr[{dtype.__class__.__name__}]"


def generate_fn_decl(
Expand All @@ -53,19 +53,24 @@ def generate_fn_decl(
name = op.name

defaults: Iterable = (
op.defaults if op.defaults is not None else (... for _ in op.arg_names)
op.default_values
if op.default_values is not None
else (... for _ in op.param_names)
)

annotated_args = ", ".join(
f"{format_param(name, param)}: "
+ type_annotation(param, specialize_generic)
name
+ ": "
+ type_annotation(dtype, specialize_generic)
+ (f" = {default_val}" if default_val is not ... else "")
for param, name, default_val in zip(
sig.params, op.arg_names, defaults, strict=True
for dtype, name, default_val in zip(
sig.types, op.param_names, defaults, strict=True
)
)
if sig.is_vararg:
annotated_args = add_vararg_star(annotated_args)

if op.context_kwargs is not None:
if len(op.context_kwargs) > 0:
context_kwarg_annotation = {
"partition_by": "Col | ColName | Iterable[Col | ColName]",
"arrange": "ColExpr | Iterable[ColExpr]",
Expand All @@ -77,9 +82,9 @@ def generate_fn_decl(
for kwarg in op.context_kwargs
)

if len(sig.params) == 0 or not sig.params[-1].vararg:
if len(sig.types) == 0 or not sig.is_vararg:
annotated_kwargs = "*" + annotated_kwargs
if len(sig.params) > 0:
if len(sig.types) > 0:
annotated_kwargs = ", " + annotated_kwargs
else:
annotated_kwargs = ""
Expand All @@ -93,33 +98,33 @@ def generate_fn_decl(
def generate_fn_body(
op: Operator,
sig: Signature,
arg_names: list[str] | None = None,
param_names: list[str] | None = None,
*,
op_var_name: str,
rversion: bool = False,
):
if arg_names is None:
arg_names = op.arg_names
if param_names is None:
param_names = op.param_names

if rversion:
assert len(arg_names) == 2
assert not any(param.vararg for param in sig.params)
arg_names = list(reversed(arg_names))
assert len(param_names) == 2
assert not sig.is_vararg
param_names = list(reversed(param_names))

args = "".join(
f", {format_param(name, param)}"
for param, name in zip(sig.params, arg_names, strict=True)
)
args = "".join(f", {name}" for name in param_names)
if sig.is_vararg:
args = add_vararg_star(args)

if op.context_kwargs is not None:
kwargs = "".join(f", {kwarg}={kwarg}" for kwarg in op.context_kwargs)
else:
kwargs = ""

return f' return ColFn("{op.name}"{args}{kwargs})\n\n'
return f" return ColFn(ops.{op_var_name}{args}{kwargs})\n\n"


def generate_overloads(
op: Operator, *, name: str | None = None, rversion: bool = False
op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str
):
res = ""
in_namespace = "." in op.name
Expand All @@ -129,22 +134,16 @@ def generate_overloads(
has_overloads = len(op.signatures) > 1
if has_overloads:
for sig in op.signatures:
res += (
"@overload\n"
+ generate_fn_decl(op, Signature.parse(sig), name=name)
+ " ...\n\n"
)
res += "@overload\n" + generate_fn_decl(op, sig, name=name) + " ...\n\n"

res += generate_fn_decl(
op,
Signature.parse(op.signatures[0]),
name=name,
specialize_generic=not has_overloads,
op, op.signatures[0], name=name, specialize_generic=not has_overloads
) + generate_fn_body(
op,
Signature.parse(op.signatures[0]),
["self.arg"] + op.arg_names[1:] if in_namespace else None,
op.signatures[0],
["self.arg"] + op.param_names[1:] if in_namespace else None,
rversion=rversion,
op_var_name=op_var_name,
)

return res
Expand All @@ -154,7 +153,7 @@ def indent(s: str, by: int) -> str:
return "".join(" " * by + line + "\n" for line in s.split("\n"))


with open(col_expr_path, "r+") as file:
with open(COL_EXPR_PATH, "r+") as file:
new_file_contents = ""
in_col_expr_class = False
in_generated_section = False
Expand All @@ -163,7 +162,7 @@ def indent(s: str, by: int) -> str:
"@dataclasses.dataclass(slots=True)\n"
f"class {name.title()}Namespace(FnNamespace):\n"
)
for name in namespaces
for name in NAMESPACES
}

for line in file:
Expand All @@ -172,15 +171,18 @@ def indent(s: str, by: int) -> str:
elif not in_generated_section and line.startswith(" @overload"):
in_generated_section = True
elif in_col_expr_class and line.startswith("class Col"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if isinstance(op, NoExprMethod):
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if not isinstance(op, Operator) or not op.generate_expr_method:
continue

op_overloads = generate_overloads(op)
if op_name in rversions:
op_overloads = generate_overloads(op, op_var_name=op_var_name)
if op.name in RVERSIONS:
op_overloads += generate_overloads(
op, name=f"__r{op_name[2:]}", rversion=True
op,
name=f"__r{op.name[2:]}",
rversion=True,
op_var_name=op_var_name,
)

op_overloads = indent(op_overloads, 4)
Expand All @@ -190,7 +192,7 @@ def indent(s: str, by: int) -> str:
else:
new_file_contents += op_overloads

for name in namespaces:
for name in NAMESPACES:
new_file_contents += (
" @property\n"
f" def {name}(self):\n"
Expand All @@ -203,7 +205,7 @@ def indent(s: str, by: int) -> str:
" arg: ColExpr\n"
)

for name in namespaces:
for name in NAMESPACES:
new_file_contents += namespace_contents[name]

in_generated_section = False
Expand All @@ -216,27 +218,23 @@ def indent(s: str, by: int) -> str:
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {col_expr_path}")
os.system(f"ruff format {COL_EXPR_PATH}")


with open(fns_path, "r+") as file:
with open(FNS_PATH, "r+") as file:
new_file_contents = ""
display_name = {"hmin": "min", "hmax": "max"}

for line in file:
new_file_contents += line
if line.startswith(" return LiteralCol"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if not isinstance(op, NoExprMethod):
continue

new_file_contents += generate_overloads(
op, name=display_name.get(op_name)
)
if line.startswith("# --- from here the code is generated ---"):
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if isinstance(op, Operator) and not op.generate_expr_method:
new_file_contents += generate_overloads(op, op_var_name=op_var_name)
break

file.seek(0)
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {fns_path}")
os.system(f"ruff format {FNS_PATH}")
Loading

0 comments on commit f810f1f

Please sign in to comment.