Skip to content

Commit

Permalink
fix: Fixes arrow support for df[:, list[int|str]] (#923)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored Sep 8, 2024
1 parent 525d92d commit ee8c62a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
9 changes: 6 additions & 3 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ def __getitem__(
and len(item) == 2
and isinstance(item[1], (list, tuple))
):
return self._from_native_frame(
self._native_frame.take(item[0]).select(item[1])
)
if item[0] == slice(None):
selected_rows = self._native_frame
else:
selected_rows = self._native_frame.take(item[0])

return self._from_native_frame(selected_rows.select(item[1]))

elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
Expand Down
10 changes: 9 additions & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,14 @@ def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]

@overload
Expand All @@ -606,7 +610,7 @@ def __getitem__(
| slice
| Sequence[int]
| tuple[Sequence[int], str | int]
| tuple[Sequence[int], Sequence[int] | Sequence[str] | slice],
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Extract column or slice of DataFrame.
Expand All @@ -623,6 +627,10 @@ def __getitem__(
a `Series`.
- `df[[0, 1], [0, 1, 2]]` extracts the first two rows and the first three columns
and returns a `DataFrame`
- `df[:, [0, 1, 2]]` extracts all rows from the first three columns and returns a
`DataFrame`.
- `df[:, ['a', 'c']]` extracts all rows and columns `'a'` and `'c'` and returns a
`DataFrame`.
- `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and
returns a `DataFrame`
- `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame`
Expand Down
4 changes: 4 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,15 @@ class DataFrame(NwDataFrame[IntoDataFrameT]):
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...

@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
Expand Down
6 changes: 6 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def test_slice_slice_columns(constructor_eager: Any) -> None:
result = df[[0, 1], 1:]
expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]}
compare_dicts(result, expected)
result = df[:, ["b", "d"]]
expected = {"b": [4, 5, 6], "d": [1, 4, 2]}
compare_dicts(result, expected)
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
Expand Down

0 comments on commit ee8c62a

Please sign in to comment.