diff --git a/python/vegafusion/vegafusion/datasource/__init__.py b/python/vegafusion/vegafusion/datasource/__init__.py
index 94a9a391..cc8d1d0a 100644
--- a/python/vegafusion/vegafusion/datasource/__init__.py
+++ b/python/vegafusion/vegafusion/datasource/__init__.py
@@ -1,3 +1,4 @@
 from .dfi_datasource import DfiDatasource
 from .pandas_datasource import PandasDatasource
+from .pyarrow_datasource import PyArrowDatasource
 from .datasource import Datasource
diff --git a/python/vegafusion/vegafusion/datasource/pyarrow_datasource.py b/python/vegafusion/vegafusion/datasource/pyarrow_datasource.py
new file mode 100644
index 00000000..f068e34d
--- /dev/null
+++ b/python/vegafusion/vegafusion/datasource/pyarrow_datasource.py
@@ -0,0 +1,16 @@
+from typing import Iterable
+import pyarrow as pa
+from .datasource import Datasource
+
+class PyArrowDatasource(Datasource):
+    def __init__(self, dataframe: pa.Table):
+        self._table = dataframe
+
+    def schema(self) -> pa.Schema:
+        return self._table.schema
+
+    def fetch(self, columns: Iterable[str]) -> pa.Table:
+        return pa.Table.from_arrays(
+            [self._table[c] for c in columns],
+            names=list(columns)
+        )
diff --git a/python/vegafusion/vegafusion/runtime.py b/python/vegafusion/vegafusion/runtime.py
index 45f547c0..1551105a 100644
--- a/python/vegafusion/vegafusion/runtime.py
+++ b/python/vegafusion/vegafusion/runtime.py
@@ -4,7 +4,7 @@ from typing import Union
 
 from .connection import SqlConnection
 from .dataset import SqlDataset, DataFrameDataset
-from .datasource import PandasDatasource, DfiDatasource
+from .datasource import PandasDatasource, DfiDatasource, PyArrowDatasource
 from .evaluation import get_mark_group_for_scope
 from .transformer import import_pyarrow_interchange, to_arrow_table
 from .local_tz import get_local_tz
@@ -209,16 +209,6 @@ def _import_or_register_inline_datasets(self, inline_datasets=None):
                 imported_inline_datasets[name] = value
             elif isinstance(value, DataFrameDataset):
                 imported_inline_datasets[name] = value
-            elif isinstance(value, pa.Table):
-                if self._connection is not None:
-                    try:
-                        # Try registering Arrow Table if supported
-                        self._connection.register_arrow(name, value, temporary=True)
-                        continue
-                    except ValueError:
-                        pass
-
-                imported_inline_datasets[name] = DfiDatasource(value)
             elif isinstance(value, pd.DataFrame):
                 if self._connection is not None:
                     try:
@@ -230,7 +220,26 @@ def _import_or_register_inline_datasets(self, inline_datasets=None):
 
                 imported_inline_datasets[name] = PandasDatasource(value)
             elif hasattr(value, "__dataframe__"):
-                imported_inline_datasets[name] = DfiDatasource(value)
+                # Let polars convert to pyarrow since it has broader support than the raw dataframe interchange
+                # protocol, and "This operation is mostly zero copy."
+                try:
+                    import polars as pl
+                    if isinstance(value, pl.DataFrame):
+                        value = value.to_arrow()
+                except ImportError:
+                    pass
+
+                if isinstance(value, pa.Table):
+                    try:
+                        if self._connection is not None:
+                            # Try registering Arrow Table if supported
+                            self._connection.register_arrow(name, value, temporary=True)
+                            continue
+                    except ValueError:
+                        pass
+                    imported_inline_datasets[name] = PyArrowDatasource(value)
+                else:
+                    imported_inline_datasets[name] = DfiDatasource(value)
             else:
                 raise ValueError(f"Unsupported DataFrame type: {type(value)}")
 
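
For reference, a minimal sketch of what the new `PyArrowDatasource` exposes once this change lands. The class and import path come straight from the diff; the table contents and column names below are illustrative only:

```python
import pyarrow as pa
from vegafusion.datasource import PyArrowDatasource

# Illustrative Arrow table (sample data, not part of the PR)
table = pa.table({
    "city": ["Seattle", "Portland", "Seattle"],
    "temp": [52.0, 55.5, 49.1],
})

ds = PyArrowDatasource(table)

# schema() returns the underlying pyarrow schema
print(ds.schema())         # city: string, temp: double

# fetch() projects the requested columns into a new pa.Table
print(ds.fetch(["temp"]))  # single-column table containing "temp"
```

With this change, a polars DataFrame passed via `inline_datasets` is converted with `to_arrow()` and either registered on the SQL connection through `register_arrow` (when supported) or wrapped in `PyArrowDatasource`, instead of always going through the generic `DfiDatasource` interchange path.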