Skip to content

Commit

Permalink
feat(pyarrow): support objects implementing __arrow_c_stream__ in `…
Browse files Browse the repository at this point in the history
…ibis.memtable`
  • Loading branch information
jcrist authored and cpcloud committed Nov 12, 2024
1 parent 321a382 commit 10bac0f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,24 @@ def test_memtable_construct_from_pyarrow(backend, con, monkeypatch):
)


def test_memtable_construct_from_pyarrow_c_stream(backend, con):
pa = pytest.importorskip("pyarrow")

class Opaque:
def __init__(self, table):
self._table = table

def __arrow_c_stream__(self, *args, **kwargs):
return self._table.__arrow_c_stream__(*args, **kwargs)

table = pa.table({"a": list("abc"), "b": [1, 2, 3]})

t = ibis.memtable(Opaque(table))

res = con.to_pyarrow(t.order_by("a"))
assert res.equals(table)


@pytest.mark.parametrize("lazy", [False, True])
def test_memtable_construct_from_polars(backend, con, lazy):
pl = pytest.importorskip("polars")
Expand Down
58 changes: 43 additions & 15 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,42 +412,55 @@ def memtable(

@lazy_singledispatch
def _memtable(
data: pd.DataFrame | Any,
data: Any,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
import pandas as pd

from ibis.formats.pandas import PandasDataFrameProxy
if hasattr(data, "__arrow_c_stream__"):
# Support objects exposing arrow's PyCapsule interface
import pyarrow as pa

if not isinstance(data, pd.DataFrame):
df = pd.DataFrame(data, columns=columns)
data = pa.table(data)
else:
df = data
import pandas as pd

data = pd.DataFrame(data, columns=columns)
return _memtable(data, columns=columns, schema=schema, name=name)


@_memtable.register("pandas.DataFrame")
def _memtable_from_pandas_dataframe(
data: pd.DataFrame,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
from ibis.formats.pandas import PandasDataFrameProxy

if df.columns.inferred_type != "string":
cols = df.columns
if data.columns.inferred_type != "string":
cols = data.columns
newcols = getattr(
schema,
"names",
(f"col{i:d}" for i in builtins.range(len(cols))),
)
df = df.rename(columns=dict(zip(cols, newcols)))
data = data.rename(columns=dict(zip(cols, newcols)))

if columns is not None:
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
if (provided_col := len(columns)) != (exist_col := len(data.columns)):
raise ValueError(
"Provided `columns` must have an entry for each column in `data`.\n"
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
)

df = df.rename(columns=dict(zip(df.columns, columns)))
data = data.rename(columns=dict(zip(data.columns, columns)))

# verify that the DataFrame has no duplicate column names because ibis
# doesn't allow that
cols = df.columns
cols = data.columns
dupes = [name for name, count in Counter(cols).items() if count > 1]
if dupes:
raise IbisInputError(
Expand All @@ -456,8 +469,8 @@ def _memtable(

op = ops.InMemoryTable(
name=name if name is not None else util.gen_name("pandas_memtable"),
schema=sch.infer(df) if schema is None else schema,
data=PandasDataFrameProxy(df),
schema=sch.infer(data) if schema is None else schema,
data=PandasDataFrameProxy(data),
)
return op.to_expr()

Expand Down Expand Up @@ -499,6 +512,21 @@ def _memtable_from_pyarrow_dataset(
).to_expr()


@_memtable.register("pyarrow.RecordBatchReader")
def _memtable_from_pyarrow_RecordBatchReader(
data: pa.Table,
*,
name: str | None = None,
schema: SchemaLike | None = None,
columns: Iterable[str] | None = None,
):
raise TypeError(
"Creating an `ibis.memtable` from a `pyarrow.RecordBatchReader` would "
"load _all_ data into memory. If you want to do this, please do so "
"explicitly like `ibis.memtable(reader.read_all())`"
)


@_memtable.register("polars.LazyFrame")
def _memtable_from_polars_lazyframe(data: pl.LazyFrame, **kwargs):
return _memtable_from_polars_dataframe(data.collect(), **kwargs)
Expand Down

0 comments on commit 10bac0f

Please sign in to comment.