Skip to content

Commit

Permalink
fix(python): Fix interchange protocol boolean buffer size (#12177)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Nov 1, 2023
1 parent 5c72df9 commit daac19f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
20 changes: 11 additions & 9 deletions py-polars/polars/interchange/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ def bufsize(self) -> int:

if dtype[0] == DtypeKind.STRING:
return self._data.str.len_bytes().sum() # type: ignore[return-value]

n_bits = self._data.len() * dtype[1]

result, rest = divmod(n_bits, 8)
# Round up to the nearest byte
if rest:
return result + 1
else:
return result
elif dtype[0] == DtypeKind.BOOL:
offset, length, _pointer = self._data._s.get_ptr()
n_bits = offset + length
n_bytes, rest = divmod(n_bits, 8)
# Round up to the nearest byte
if rest:
return n_bytes + 1
else:
return n_bytes

return self._data.len() * (dtype[1] // 8)

@property
def ptr(self) -> int:
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/interchange/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_init_invalid_input() -> None:
(pl.Series(["a", "b", "a", "c", "a"], dtype=pl.Categorical), 20),
(pl.Series([True, False], dtype=pl.Boolean), 1),
(pl.Series([True] * 9, dtype=pl.Boolean), 2),
(pl.Series([True] * 9, dtype=pl.Boolean)[5:], 2),
],
)
def test_bufsize(data: pl.Series, expected: int) -> None:
Expand Down
19 changes: 14 additions & 5 deletions py-polars/tests/unit/interchange/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pyarrow as pa
import pyarrow.interchange
import pytest
from hypothesis import given, note
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal
Expand All @@ -28,10 +28,9 @@
]


@given(dataframes(allowed_dtypes=protocol_dtypes, excluded_dtypes=[pl.Boolean]))
@given(dataframes(allowed_dtypes=protocol_dtypes))
def test_roundtrip_pyarrow_parametric(df: pl.DataFrame) -> None:
dfi = df.__dataframe__()
note(f"n_chunks: {dfi.num_chunks()}")
df_pa = pa.interchange.from_dataframe(dfi)
with pl.StringCache():
result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment]
Expand All @@ -41,13 +40,12 @@ def test_roundtrip_pyarrow_parametric(df: pl.DataFrame) -> None:
@given(
dataframes(
allowed_dtypes=protocol_dtypes,
excluded_dtypes=[pl.Boolean, pl.Categorical],
excluded_dtypes=[pl.Categorical],
chunked=False,
)
)
def test_roundtrip_pyarrow_zero_copy_parametric(df: pl.DataFrame) -> None:
dfi = df.__dataframe__(allow_copy=False)
note(f"n_chunks: {dfi.num_chunks()}")
df_pa = pa.interchange.from_dataframe(dfi, allow_copy=False)
result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment]
assert_frame_equal(result, df, categorical_as_str=True)
Expand Down Expand Up @@ -100,3 +98,14 @@ def test_roundtrip_pyarrow_boolean() -> None:
result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment]

assert_frame_equal(result, df)


def test_roundtrip_pyarrow_boolean_midbyte_slice() -> None:
    """Round-trip a sliced boolean column through the pyarrow interchange path.

    Builds a nine-element all-False boolean Series, drops its first three
    values, and checks that converting the resulting frame to pyarrow via the
    dataframe interchange protocol and back yields an equal frame.
    """
    # Slicing at index 3 of 9 values is what the test name calls a "midbyte"
    # slice: presumably the remaining data starts at a non-zero bit offset.
    original = pl.Series("a", [False] * 9)[3:].to_frame()

    # polars -> interchange object -> pyarrow table
    arrow_table = pa.interchange.from_dataframe(original.__dataframe__())

    # pyarrow table -> polars
    round_tripped: pl.DataFrame = pl.from_arrow(arrow_table)  # type: ignore[assignment]

    assert_frame_equal(round_tripped, original)

0 comments on commit daac19f

Please sign in to comment.