Skip to content

Commit

Permalink
drive-by typing (#882)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 30, 2024
1 parent 13dcb7d commit 1f3f420
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
3 changes: 2 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
import numpy as np
import pyarrow as pa
from typing_extensions import Self

from narwhals._arrow.group_by import ArrowGroupBy
Expand All @@ -33,7 +34,7 @@
class ArrowDataFrame:
# --- not in the spec ---
def __init__(
self, native_dataframe: Any, *, backend_version: tuple[int, ...]
self, native_dataframe: pa.Table, *, backend_version: tuple[int, ...]
) -> None:
self._native_frame = native_dataframe
self._implementation = Implementation.PYARROW
Expand Down
5 changes: 3 additions & 2 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType


Expand Down Expand Up @@ -157,7 +158,7 @@ def __invert__(self) -> Self:
def len(self) -> Self:
return reuse_series_implementation(self, "len", returns_scalar=True)

def filter(self, *predicates: Any) -> Self:
def filter(self, *predicates: IntoArrowExpr) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
return reuse_series_implementation(self, "filter", other=expr)
Expand Down Expand Up @@ -228,7 +229,7 @@ def null_count(self) -> Self:
def is_null(self) -> Self:
return reuse_series_implementation(self, "is_null")

def is_between(self, lower_bound: Any, upper_bound: Any, closed: str) -> Any:
def is_between(self, lower_bound: Any, upper_bound: Any, closed: str) -> Self:
return reuse_series_implementation(
self, "is_between", lower_bound, upper_bound, closed
)
Expand Down
13 changes: 10 additions & 3 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
import pyarrow as pa
from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
Expand All @@ -26,7 +27,11 @@

class ArrowSeries:
def __init__(
self, native_series: Any, *, name: str, backend_version: tuple[int, ...]
self,
native_series: pa.ChunkedArray,
*,
name: str,
backend_version: tuple[int, ...],
) -> None:
self._name = name
self._native_series = native_series
Expand Down Expand Up @@ -366,7 +371,9 @@ def all(self) -> bool:

return pc.all(self._native_series) # type: ignore[no-any-return]

def is_between(self, lower_bound: Any, upper_bound: Any, closed: str = "both") -> Any:
def is_between(
self, lower_bound: Any, upper_bound: Any, closed: str = "both"
) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
Expand Down Expand Up @@ -657,7 +664,7 @@ def clip(

return self._from_native_series(arr)

def to_arrow(self: Self) -> Any:
def to_arrow(self: Self) -> pa.Array:
return self._native_series.combine_chunks()

@property
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ exclude = [
"/tests",
"/tpch",
"/utils",
".gitignore",
"CONTRIBUTING.md",
"mkdocs.yml",
"noxfile.py",
"requirements-dev.txt",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 1f3f420

Please sign in to comment.