From 3468fa1f5b9dfcf83a95bcb09fe5a4d8d3808620 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 6 Jun 2024 19:30:48 +0100 Subject: [PATCH] Add more complete type annotations in polars interpreter (#15942) We can check this with: pyright --verifytypes cudf_polars --ignoreexternal Which reports a "type completeness" score of around 94%. This will improve once pylibcudf gets type stubs. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - James Lamb (https://github.com/jameslamb) - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/15942 --- .pre-commit-config.yaml | 2 +- python/cudf_polars/cudf_polars/__init__.py | 5 +- python/cudf_polars/cudf_polars/callback.py | 3 +- .../cudf_polars/containers/dataframe.py | 13 +- python/cudf_polars/cudf_polars/dsl/expr.py | 55 +++++--- python/cudf_polars/cudf_polars/dsl/ir.py | 110 +++++++-------- .../cudf_polars/cudf_polars/dsl/translate.py | 127 ++++++++++++------ python/cudf_polars/cudf_polars/py.typed | 0 .../cudf_polars/testing/asserts.py | 2 +- .../cudf_polars/typing/__init__.py | 91 +++++++++++++ python/cudf_polars/pyproject.toml | 2 - 11 files changed, 287 insertions(+), 123 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/py.typed create mode 100644 python/cudf_polars/cudf_polars/typing/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8865fb48e0d..4cdcac88091 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -134,7 +134,7 @@ repos: - id: rapids-dependency-file-generator args: ["--clean"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.3 + rev: v0.4.8 hooks: - id: ruff files: python/.*$ diff --git a/python/cudf_polars/cudf_polars/__init__.py b/python/cudf_polars/cudf_polars/__init__.py index 74547fe2448..b19a282129a 100644 --- a/python/cudf_polars/cudf_polars/__init__.py +++ b/python/cudf_polars/cudf_polars/__init__.py @@ -10,4 +10,7 @@ from __future__ import annotations -__all__: list[str] = [] +from cudf_polars.callback import execute_with_cudf +from cudf_polars.dsl.translate import translate_ir + +__all__: list[str] = ["execute_with_cudf", "translate_ir"] diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index aabb8498ce2..979087d5273 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -16,6 +16,7 @@ import polars as pl from cudf_polars.dsl.ir import IR + from cudf_polars.typing import NodeTraverser __all__: list[str] = ["execute_with_cudf"] @@ -33,7 +34,7 @@ def _callback( return ir.evaluate(cache={}).to_polars() -def execute_with_cudf(nt, *, raise_on_fail: bool = False) -> None: +def execute_with_cudf(nt: NodeTraverser, *, raise_on_fail: bool = False) -> None: """ A post optimization callback that attempts to execute the plan with cudf. 
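[Editorial note: for reference, the execute_with_cudf entry point re-exported above is handed to polars as a post-optimization callback. A minimal usage sketch, not part of the patch -- the example frame is invented, and post_opt_callback is the unstable hook py-polars exposed at the time this PR was written (it is how cudf_polars' own testing/asserts.py invokes the callback):

    from functools import partial

    import polars as pl

    from cudf_polars import execute_with_cudf

    q = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).select(pl.col("a") + pl.col("b"))
    # Attempt GPU execution of the whole optimized plan; raise_on_fail=True
    # surfaces translation errors instead of silently falling back to CPU.
    result = q.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))
]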
diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index eeaf181be0c..ac7e748095e 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -6,7 +6,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import polars as pl @@ -17,6 +17,7 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence, Set + import pyarrow as pa from typing_extensions import Self import cudf @@ -44,13 +45,13 @@ def copy(self) -> Self: def to_polars(self) -> pl.DataFrame: """Convert to a polars DataFrame.""" - return pl.from_arrow( - plc.interop.to_arrow( - self.table, - [plc.interop.ColumnMetadata(name=c.name) for c in self.columns], - ) + table: pa.Table = plc.interop.to_arrow( + self.table, + [plc.interop.ColumnMetadata(name=c.name) for c in self.columns], ) + return cast(pl.DataFrame, pl.from_arrow(table)) + @cached_property def column_names_set(self) -> frozenset[str]: """Return the column names as a set.""" diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index c7c11cf6c68..6d9435ce373 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -139,14 +139,14 @@ def is_equal(self, other: Any) -> bool: other.children ) - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: """Equality of expressions.""" if type(self) != type(other) or hash(self) != hash(other): return False else: return self.is_equal(other) - def __ne__(self, other) -> bool: + def __ne__(self, other: Any) -> bool: """Inequality of expressions.""" return not self.__eq__(other) @@ -285,6 +285,8 @@ class NamedExpr: # when evaluating expressions themselves, only when constructing # named return values in dataframe (IR) nodes. 
__slots__ = ("name", "value") + value: Expr + name: str def __init__(self, name: str, value: Expr) -> None: self.name = name @@ -298,7 +300,7 @@ def __repr__(self) -> str: """Repr of the expression.""" return f"NamedExpr({self.name}, {self.value}" - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: """Equality of two expressions.""" return ( type(self) is type(other) @@ -306,7 +308,7 @@ def __eq__(self, other) -> bool: and self.value == other.value ) - def __ne__(self, other) -> bool: + def __ne__(self, other: Any) -> bool: """Inequality of expressions.""" return not self.__eq__(other) @@ -344,9 +346,10 @@ def collect_agg(self, *, depth: int) -> AggInfo: class Literal(Expr): __slots__ = ("value",) _non_child = ("dtype", "value") - value: pa.Scalar + value: pa.Scalar[Any] + children: tuple[()] - def __init__(self, dtype: plc.DataType, value: pa.Scalar) -> None: + def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None: super().__init__(dtype) assert value.type == plc.interop.to_arrow(dtype) self.value = value @@ -367,6 +370,7 @@ class Col(Expr): __slots__ = ("name",) _non_child = ("dtype", "name") name: str + children: tuple[()] def __init__(self, dtype: plc.DataType, name: str) -> None: self.dtype = dtype @@ -388,6 +392,8 @@ def collect_agg(self, *, depth: int) -> AggInfo: class Len(Expr): + children: tuple[()] + def do_evaluate( self, df: DataFrame, @@ -410,8 +416,15 @@ def collect_agg(self, *, depth: int) -> AggInfo: class BooleanFunction(Expr): __slots__ = ("name", "options", "children") _non_child = ("dtype", "name", "options") + children: tuple[Expr, ...] - def __init__(self, dtype: plc.DataType, name: str, options: tuple, *children: Expr): + def __init__( + self, + dtype: plc.DataType, + name: pl_expr.BooleanFunction, + options: tuple[Any, ...], + *children: Expr, + ) -> None: super().__init__(dtype) self.options = options self.name = name @@ -610,14 +623,15 @@ def do_evaluate( class StringFunction(Expr): __slots__ = ("name", "options", "children") _non_child = ("dtype", "name", "options") + children: tuple[Expr, ...] def __init__( self, dtype: plc.DataType, name: pl_expr.StringFunction, - options: tuple, + options: tuple[Any, ...], *children: Expr, - ): + ) -> None: super().__init__(dtype) self.options = options self.name = name @@ -661,10 +675,11 @@ def do_evaluate( class Sort(Expr): __slots__ = ("options", "children") _non_child = ("dtype", "options") + children: tuple[Expr] def __init__( self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr - ): + ) -> None: super().__init__(dtype) self.options = options self.children = (column,) @@ -696,6 +711,7 @@ def do_evaluate( class SortBy(Expr): __slots__ = ("options", "children") _non_child = ("dtype", "options") + children: tuple[Expr, ...] 
def __init__( self, @@ -703,7 +719,7 @@ def __init__( options: tuple[bool, tuple[bool], tuple[bool]], column: Expr, *by: Expr, - ): + ) -> None: super().__init__(dtype) self.options = options self.children = (column, *by) @@ -734,8 +750,9 @@ def do_evaluate( class Gather(Expr): __slots__ = ("children",) _non_child = ("dtype",) + children: tuple[Expr, Expr] - def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr): + def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None: super().__init__(dtype) self.children = (values, indices) @@ -775,6 +792,7 @@ def do_evaluate( class Filter(Expr): __slots__ = ("children",) _non_child = ("dtype",) + children: tuple[Expr, Expr] def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr): super().__init__(dtype) @@ -801,8 +819,9 @@ def do_evaluate( class RollingWindow(Expr): __slots__ = ("options", "children") _non_child = ("dtype", "options") + children: tuple[Expr] - def __init__(self, dtype: plc.DataType, options: Any, agg: Expr): + def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None: super().__init__(dtype) self.options = options self.children = (agg,) @@ -811,8 +830,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr): class GroupedRollingWindow(Expr): __slots__ = ("options", "children") _non_child = ("dtype", "options") + children: tuple[Expr, ...] - def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr): + def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None: super().__init__(dtype) self.options = options self.children = (agg, *by) @@ -821,8 +841,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr): class Cast(Expr): __slots__ = ("children",) _non_child = ("dtype",) + children: tuple[Expr] - def __init__(self, dtype: plc.DataType, value: Expr): + def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) @@ -848,6 +869,7 @@ def collect_agg(self, *, depth: int) -> AggInfo: class Agg(Expr): __slots__ = ("name", "options", "op", "request", "children") _non_child = ("dtype", "name", "options") + children: tuple[Expr] def __init__( self, dtype: plc.DataType, name: str, options: Any, value: Expr @@ -1007,7 +1029,7 @@ def _last(self, column: Column) -> Column: def do_evaluate( self, - df, + df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, @@ -1022,6 +1044,7 @@ def do_evaluate( class BinOp(Expr): __slots__ = ("op", "children") _non_child = ("dtype", "op") + children: tuple[Expr, Expr] def __init__( self, diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 0a72cbd9f83..665bbe5be41 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -1,7 +1,5 @@ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 -# TODO: remove need for this -# ruff: noqa: D101 """ DSL nodes for the LogicalPlan of polars. 
@@ -15,11 +13,11 @@ from __future__ import annotations +import dataclasses import itertools import types -from dataclasses import dataclass from functools import cache -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar, NoReturn import pyarrow as pa from typing_extensions import assert_never @@ -34,8 +32,11 @@ from cudf_polars.utils import sorting if TYPE_CHECKING: + from collections.abc import MutableMapping from typing import Literal + from cudf_polars.typing import Schema + __all__ = [ "IR", @@ -91,14 +92,14 @@ def broadcast( ] -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class IR: """Abstract plan node, representing an unevaluated dataframe.""" - schema: dict[str, plc.DataType] + schema: Schema """Mapping from column names to their data types.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """ Evaluate the node and return a dataframe. @@ -123,7 +124,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: raise NotImplementedError -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class PythonScan(IR): """Representation of input from a python function.""" @@ -133,7 +134,7 @@ class PythonScan(IR): """Filter to apply to the constructed dataframe before returning it.""" -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Scan(IR): """Input from files.""" @@ -153,14 +154,14 @@ class Scan(IR): predicate: expr.NamedExpr | None """Mask to apply to the read dataframe.""" - def __post_init__(self): + def __post_init__(self) -> None: """Validate preconditions.""" if self.file_options.n_rows is not None: raise NotImplementedError("row limit in scan") if self.typ not in ("csv", "parquet"): raise NotImplementedError(f"Unhandled scan type: {self.typ}") - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" options = self.file_options with_columns = options.with_columns @@ -172,9 +173,9 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: ) ) elif self.typ == "parquet": - df = DataFrame.from_cudf( - cudf.read_parquet(self.paths, columns=with_columns) - ) + cdf = cudf.read_parquet(self.paths, columns=with_columns) + assert isinstance(cdf, cudf.DataFrame) + df = DataFrame.from_cudf(cdf) else: assert_never(self.typ) if row_index is not None: @@ -208,7 +209,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return df.filter(mask) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Cache(IR): """ Return a cached plan node. @@ -221,7 +222,7 @@ class Cache(IR): value: IR """The unevaluated node to cache.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" try: return cache[self.key] @@ -229,7 +230,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return cache.setdefault(self.key, self.value.evaluate(cache=cache)) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class DataFrameScan(IR): """ Input from an existing polars DataFrame. 
@@ -244,7 +245,7 @@ class DataFrameScan(IR): predicate: expr.NamedExpr | None """Mask to apply.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" pdf = pl.DataFrame._from_pydf(self.df) if self.projection is not None: @@ -270,7 +271,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return df -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Select(IR): """Produce a new dataframe selecting given expressions from an input.""" @@ -279,7 +280,7 @@ class Select(IR): expr: list[expr.NamedExpr] """List of expressions to evaluate to form the new dataframe.""" - def evaluate(self, *, cache: dict[int, DataFrame]): + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) # Handle any broadcasting @@ -287,7 +288,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]): return DataFrame(columns) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Reduce(IR): """ Produce a new dataframe selecting given expressions from an input. @@ -300,7 +301,7 @@ class Reduce(IR): expr: list[expr.NamedExpr] """List of expressions to evaluate to form the new dataframe.""" - def evaluate(self, *, cache: dict[int, DataFrame]): + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) columns = broadcast(*(e.evaluate(df) for e in self.expr)) @@ -308,7 +309,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]): return DataFrame(columns) -def placeholder_column(n: int): +def placeholder_column(n: int) -> plc.Column: """ Produce a placeholder pylibcudf column with NO BACKING DATA. 
@@ -338,7 +339,7 @@ def placeholder_column(n: int): ) -@dataclass(slots=False) +@dataclasses.dataclass(slots=False) class GroupBy(IR): """Perform a groupby.""" @@ -352,6 +353,7 @@ class GroupBy(IR): """Should the order of the input dataframe be maintained?""" options: Any """Options controlling style of groupby.""" + agg_infos: list[expr.AggInfo] = dataclasses.field(init=False) @staticmethod def check_agg(agg: expr.Expr) -> int: @@ -383,7 +385,7 @@ def check_agg(agg: expr.Expr) -> int: else: raise NotImplementedError(f"No handler for {agg=}") - def __post_init__(self): + def __post_init__(self) -> None: """Check whether all the aggregations are implemented.""" if self.options.rolling is None and self.maintain_order: raise NotImplementedError("Maintaining order in groupby") @@ -393,7 +395,7 @@ def __post_init__(self): raise NotImplementedError("Nested aggregations in groupby") self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) keys = broadcast( @@ -438,7 +440,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return DataFrame([*result_keys, *results]).slice(self.options.slice) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Join(IR): """A join of two dataframes.""" @@ -466,7 +468,7 @@ class Join(IR): - coalesce: should key columns be coalesced (only makes sense for outer joins) """ - def __post_init__(self): + def __post_init__(self) -> None: """Validate preconditions.""" if self.options[0] == "cross": raise NotImplementedError("cross join not implemented") @@ -511,7 +513,7 @@ def _joiners( else: assert_never(how) - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" left = self.left.evaluate(cache=cache) right = self.right.evaluate(cache=cache) @@ -577,7 +579,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return result.slice(zlice) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class HStack(IR): """Add new columns to a dataframe.""" @@ -586,7 +588,7 @@ class HStack(IR): columns: list[expr.NamedExpr] """List of expressions to produce new columns.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) columns = [c.evaluate(df) for c in self.columns] @@ -597,7 +599,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return df.with_columns(columns) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Distinct(IR): """Produce a new dataframe with distinct rows.""" @@ -619,7 +621,7 @@ class Distinct(IR): "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY, } - def __init__(self, schema: dict, df: IR, options: Any): + def __init__(self, schema: Schema, df: IR, options: Any) -> None: self.schema = schema self.df = df (keep, subset, maintain_order, zlice) = options @@ -628,7 +630,7 @@ def __init__(self, schema: dict, df: IR, options: Any): self.stable = maintain_order self.zlice = zlice - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" 
df = self.df.evaluate(cache=cache) if self.subset is None: @@ -667,7 +669,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return result.slice(self.zlice) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Sort(IR): """Sort a dataframe.""" @@ -686,12 +688,12 @@ class Sort(IR): def __init__( self, - schema: dict, + schema: Schema, df: IR, by: list[expr.NamedExpr], options: Any, zlice: tuple[int, int] | None, - ): + ) -> None: self.schema = schema self.df = df self.by = by @@ -704,7 +706,7 @@ def __init__( plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key ) - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) sort_keys = broadcast( @@ -736,7 +738,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return DataFrame(columns).slice(self.zlice) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Slice(IR): """Slice a dataframe.""" @@ -747,13 +749,13 @@ class Slice(IR): length: int """Length of the slice.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) return df.slice((self.offset, self.length)) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Filter(IR): """Filter a dataframe with a boolean mask.""" @@ -762,21 +764,21 @@ class Filter(IR): mask: expr.NamedExpr """Expression evaluating to a mask.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows) return df.filter(mask) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Projection(IR): """Select a subset of columns from a dataframe.""" df: IR """Input.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) # This can reorder things. 
@@ -786,7 +788,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: return DataFrame(columns) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class MapFunction(IR): """Apply some function to a dataframe.""" @@ -807,7 +809,7 @@ class MapFunction(IR): ] ) - def __post_init__(self): + def __post_init__(self) -> None: """Validate preconditions.""" if self.name not in MapFunction._NAMES: raise NotImplementedError(f"Unhandled map function {self.name}") @@ -824,7 +826,7 @@ def __post_init__(self): if key_column not in self.df.dfs[0].schema: raise ValueError(f"Key column {key_column} not found") - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" if self.name == "merge_sorted": # merge_sorted operates on Union inputs @@ -876,7 +878,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: raise AssertionError("Should never be reached") -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Union(IR): """Concatenate dataframes vertically.""" @@ -885,13 +887,13 @@ class Union(IR): zlice: tuple[int, int] | None """Optional slice to apply after concatenation.""" - def __post_init__(self): + def __post_init__(self) -> None: """Validated preconditions.""" schema = self.dfs[0].schema if not all(s.schema == schema for s in self.dfs[1:]): raise ValueError("Schema mismatch") - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" # TODO: only evaluate what we need if we have a slice dfs = [df.evaluate(cache=cache) for df in self.dfs] @@ -900,14 +902,14 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: ).slice(self.zlice) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class HConcat(IR): """Concatenate dataframes horizontally.""" dfs: list[IR] """List of inputs.""" - def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" dfs = [df.evaluate(cache=cache) for df in self.dfs] return DataFrame( @@ -915,7 +917,7 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: ) -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class ExtContext(IR): """ Concatenate dataframes horizontally. @@ -928,7 +930,7 @@ class ExtContext(IR): extra: list[IR] """List of extra inputs.""" - def __post_init__(self): + def __post_init__(self) -> NoReturn: """Validate preconditions.""" raise NotImplementedError( "ExtContext will be deprecated, use horizontal concat instead." diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 641176daff4..38107023365 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -16,12 +16,13 @@ import cudf._lib.pylibcudf as plc from cudf_polars.dsl import expr, ir +from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes __all__ = ["translate_ir", "translate_named_expr"] -class set_node(AbstractContextManager): +class set_node(AbstractContextManager[None]): """ Run a block with current node set in the visitor. 
@@ -39,30 +40,36 @@ class set_node(AbstractContextManager): """ __slots__ = ("n", "visitor") + visitor: NodeTraverser + n: int - def __init__(self, visitor, n: int): + def __init__(self, visitor: NodeTraverser, n: int) -> None: self.visitor = visitor self.n = n - def __enter__(self): + def __enter__(self) -> None: n = self.visitor.get_node() self.visitor.set_node(self.n) self.n = n - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.visitor.set_node(self.n) -noop_context: nullcontext = nullcontext() +noop_context: nullcontext[None] = nullcontext() @singledispatch -def _translate_ir(node: Any, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _translate_ir( + node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: raise NotImplementedError(f"Translation for {type(node).__name__}") @_translate_ir.register -def _(node: pl_ir.PythonScan, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.PythonScan( schema, node.options, @@ -73,7 +80,9 @@ def _(node: pl_ir.PythonScan, visitor: Any, schema: dict[str, plc.DataType]) -> @_translate_ir.register -def _(node: pl_ir.Scan, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.Scan( schema, node.scan_type, @@ -86,13 +95,15 @@ def _(node: pl_ir.Scan, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: @_translate_ir.register -def _(node: pl_ir.Cache, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.Cache(schema, node.id_, translate_ir(visitor, n=node.input)) @_translate_ir.register def _( - node: pl_ir.DataFrameScan, visitor: Any, schema: dict[str, plc.DataType] + node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.DataFrameScan( schema, @@ -105,7 +116,9 @@ def _( @_translate_ir.register -def _(node: pl_ir.Select, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.expr] @@ -113,7 +126,9 @@ def _(node: pl_ir.Select, visitor: Any, schema: dict[str, plc.DataType]) -> ir.I @_translate_ir.register -def _(node: pl_ir.GroupBy, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) aggs = [translate_named_expr(visitor, n=e) for e in node.aggs] @@ -129,7 +144,9 @@ def _(node: pl_ir.GroupBy, visitor: Any, schema: dict[str, plc.DataType]) -> ir. @_translate_ir.register -def _(node: pl_ir.Join, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: # Join key dtypes are dependent on the schema of the left and # right inputs, so these must be translated with the relevant # input active. 
@@ -143,7 +160,9 @@ def _(node: pl_ir.Join, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: @_translate_ir.register -def _(node: pl_ir.HStack, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] @@ -151,7 +170,9 @@ def _(node: pl_ir.HStack, visitor: Any, schema: dict[str, plc.DataType]) -> ir.I @_translate_ir.register -def _(node: pl_ir.Reduce, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.expr] @@ -159,7 +180,9 @@ def _(node: pl_ir.Reduce, visitor: Any, schema: dict[str, plc.DataType]) -> ir.I @_translate_ir.register -def _(node: pl_ir.Distinct, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.Distinct( schema, translate_ir(visitor, n=node.input), @@ -168,7 +191,9 @@ def _(node: pl_ir.Distinct, visitor: Any, schema: dict[str, plc.DataType]) -> ir @_translate_ir.register -def _(node: pl_ir.Sort, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) by = [translate_named_expr(visitor, n=e) for e in node.by_column] @@ -176,12 +201,16 @@ def _(node: pl_ir.Sort, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: @_translate_ir.register -def _(node: pl_ir.Slice, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len) @_translate_ir.register -def _(node: pl_ir.Filter, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) mask = translate_named_expr(visitor, n=node.predicate) @@ -190,13 +219,17 @@ def _(node: pl_ir.Filter, visitor: Any, schema: dict[str, plc.DataType]) -> ir.I @_translate_ir.register def _( - node: pl_ir.SimpleProjection, visitor: Any, schema: dict[str, plc.DataType] + node: pl_ir.SimpleProjection, + visitor: NodeTraverser, + schema: dict[str, plc.DataType], ) -> ir.IR: return ir.Projection(schema, translate_ir(visitor, n=node.input)) @_translate_ir.register -def _(node: pl_ir.MapFunction, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: name, *options = node.function return ir.MapFunction( schema, @@ -208,19 +241,25 @@ def _(node: pl_ir.MapFunction, visitor: Any, schema: dict[str, plc.DataType]) -> @_translate_ir.register -def _(node: pl_ir.Union, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.Union( schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options ) @_translate_ir.register -def _(node: pl_ir.HConcat, 
visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs]) @_translate_ir.register -def _(node: pl_ir.ExtContext, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: +def _( + node: pl_ir.ExtContext, visitor: NodeTraverser, schema: dict[str, plc.DataType] +) -> ir.IR: return ir.ExtContext( schema, translate_ir(visitor, n=node.input), @@ -228,7 +267,7 @@ def _(node: pl_ir.ExtContext, visitor: Any, schema: dict[str, plc.DataType]) -> ) -def translate_ir(visitor: Any, *, n: int | None = None) -> ir.IR: +def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: """ Translate a polars-internal IR node to our representation. @@ -249,7 +288,7 @@ def translate_ir(visitor: Any, *, n: int | None = None) -> ir.IR: NotImplementedError If we can't translate the nodes due to unsupported functionality. """ - ctx: AbstractContextManager = ( + ctx: AbstractContextManager[None] = ( set_node(visitor, n) if n is not None else noop_context ) with ctx: @@ -258,7 +297,9 @@ def translate_ir(visitor: Any, *, n: int | None = None) -> ir.IR: return _translate_ir(node, visitor, schema) -def translate_named_expr(visitor: Any, *, n: pl_expr.PyExprIR) -> expr.NamedExpr: +def translate_named_expr( + visitor: NodeTraverser, *, n: pl_expr.PyExprIR +) -> expr.NamedExpr: """ Translate a polars-internal named expression IR object into our representation. @@ -289,12 +330,14 @@ def translate_named_expr(visitor: Any, *, n: pl_expr.PyExprIR) -> expr.NamedExpr @singledispatch -def _translate_expr(node: Any, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _translate_expr( + node: Any, visitor: NodeTraverser, dtype: plc.DataType +) -> expr.Expr: raise NotImplementedError(f"Translation for {type(node).__name__}") @_translate_expr.register -def _(node: pl_expr.Function, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: name, *options = node.function_data options = tuple(options) if isinstance(name, pl_expr.StringFunction): @@ -316,7 +359,7 @@ def _(node: pl_expr.Function, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Window, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby? 
if node.partition_by is None: return expr.RollingWindow( @@ -332,19 +375,19 @@ def _(node: pl_expr.Window, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Literal, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype)) return expr.Literal(dtype, value) @_translate_expr.register -def _(node: pl_expr.Sort, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Sort, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: # TODO: raise in groupby return expr.Sort(dtype, node.options, translate_expr(visitor, n=node.expr)) @_translate_expr.register -def _(node: pl_expr.SortBy, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.SortBy, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.SortBy( dtype, node.sort_options, @@ -354,7 +397,7 @@ def _(node: pl_expr.SortBy, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Gather, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Gather, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.Gather( dtype, translate_expr(visitor, n=node.expr), @@ -363,7 +406,7 @@ def _(node: pl_expr.Gather, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Filter, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Filter, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.Filter( dtype, translate_expr(visitor, n=node.input), @@ -372,7 +415,7 @@ def _(node: pl_expr.Filter, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Cast, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: inner = translate_expr(visitor, n=node.expr) # Push casts into literals so we can handle Cast(Literal(Null)) if isinstance(inner, expr.Literal): @@ -382,12 +425,12 @@ def _(node: pl_expr.Cast, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Column, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.Col(dtype, node.name) @_translate_expr.register -def _(node: pl_expr.Agg, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.Agg( dtype, node.name, @@ -397,7 +440,9 @@ def _(node: pl_expr.Agg, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.BinaryExpr, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _( + node: pl_expr.BinaryExpr, visitor: NodeTraverser, dtype: plc.DataType +) -> expr.Expr: return expr.BinOp( dtype, expr.BinOp._MAPPING[node.op], @@ -407,11 +452,11 @@ def _(node: pl_expr.BinaryExpr, visitor: Any, dtype: plc.DataType) -> expr.Expr: @_translate_expr.register -def _(node: pl_expr.Len, visitor: Any, dtype: plc.DataType) -> expr.Expr: +def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: return expr.Len(dtype) -def translate_expr(visitor: Any, *, n: int) -> expr.Expr: +def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr: """ Translate a polars-internal expression IR into our representation. 
diff --git a/python/cudf_polars/cudf_polars/py.typed b/python/cudf_polars/cudf_polars/py.typed new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 2fbfa971fef..2f19b41cc3a 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -28,7 +28,7 @@ def assert_gpu_result_equal( rtol: float = 1e-05, atol: float = 1e-08, categorical_as_str: bool = False, -): +) -> None: """ Assert that collection of a lazyframe on GPU produces correct results. diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py new file mode 100644 index 00000000000..287c977f4eb --- /dev/null +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Typing utilities for cudf_polars.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Protocol, TypeAlias + +from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir + +import cudf._lib.pylibcudf as plc + +if TYPE_CHECKING: + from typing import Callable + + import polars as pl + +IR: TypeAlias = ( + pl_ir.PythonScan + | pl_ir.Scan + | pl_ir.Cache + | pl_ir.DataFrameScan + | pl_ir.Select + | pl_ir.GroupBy + | pl_ir.Join + | pl_ir.HStack + | pl_ir.Distinct + | pl_ir.Sort + | pl_ir.Slice + | pl_ir.Filter + | pl_ir.SimpleProjection + | pl_ir.MapFunction + | pl_ir.Union + | pl_ir.HConcat + | pl_ir.ExtContext +) + +Expr: TypeAlias = ( + pl_expr.Function + | pl_expr.Window + | pl_expr.Literal + | pl_expr.Sort + | pl_expr.SortBy + | pl_expr.Gather + | pl_expr.Filter + | pl_expr.Cast + | pl_expr.Column + | pl_expr.Agg + | pl_expr.BinaryExpr + | pl_expr.Len + | pl_expr.PyExprIR +) + +Schema: TypeAlias = Mapping[str, plc.DataType] + + +class NodeTraverser(Protocol): + """Abstract protocol for polars NodeTraverser.""" + + def get_node(self) -> int: + """Return current plan node id.""" + ... + + def set_node(self, n: int) -> None: + """Set the current plan node to n.""" + ... + + def view_current_node(self) -> IR: + """Convert current plan node to python rep.""" + ... + + def get_schema(self) -> Mapping[str, pl.DataType]: + """Get the schema of the current plan node.""" + ... + + def get_dtype(self, n: int) -> pl.DataType: + """Get the datatype of the given expression id.""" + ... + + def view_expression(self, n: int) -> Expr: + """Convert the given expression to python rep.""" + ... + + def set_udf( + self, + callback: Callable[[list[str] | None, str | None, int | None], pl.DataFrame], + ) -> None: + """Set the callback replacing the current node in the plan.""" + ... diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index e50ee76a9b9..2faf8c3193f 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -62,8 +62,6 @@ target-version = "py39" fix = true [tool.ruff.lint] -# __init__.py must re-export everything it imports -ignore-init-module-imports = false select = [ "E", # pycodestyle "W", # pycodestyle
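[Editorial note: the set_node context manager annotated in translate.py above is a plain save/restore pattern around the visitor's current plan node. A standalone sketch of the same idea, runnable without polars -- Cursor is an invented stand-in for the NodeTraverser protocol defined in typing/__init__.py:

    from __future__ import annotations

    from contextlib import AbstractContextManager
    from typing import Any


    class Cursor:
        """Toy stand-in for NodeTraverser: just tracks a current node id."""

        def __init__(self) -> None:
            self._node = 0

        def get_node(self) -> int:
            return self._node

        def set_node(self, n: int) -> None:
            self._node = n


    class set_node(AbstractContextManager[None]):
        """Temporarily point the cursor at node n, restoring the old id on exit."""

        def __init__(self, visitor: Cursor, n: int) -> None:
            self.visitor = visitor
            self.n = n

        def __enter__(self) -> None:
            # Swap: stash the previous node id in self.n so __exit__ can restore it.
            old = self.visitor.get_node()
            self.visitor.set_node(self.n)
            self.n = old

        def __exit__(self, *args: Any) -> None:
            self.visitor.set_node(self.n)


    visitor = Cursor()
    with set_node(visitor, 5):
        assert visitor.get_node() == 5
    assert visitor.get_node() == 0  # previous node restored on exit
]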