Skip to content

Commit

Permalink
fix(dask): enable pyarrow conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Oct 14, 2023
1 parent ef5e341 commit d2c5a1e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
self._register_in_memory_tables(expr)

def _define_udf_translation_rules(self, expr):
if self.supports_in_memory_tables:
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def compile(
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping
from pathlib import Path

# Make sure that the pandas backend options have been loaded
ibis.pandas # noqa: B018
Expand All @@ -29,6 +30,7 @@
class Backend(BasePandasBackend):
name = "dask"
backend_table_type = dd.DataFrame
supports_in_memory_tables = False

def do_connect(
self,
Expand Down Expand Up @@ -133,3 +135,8 @@ def _convert_object(cls, obj: dd.DataFrame) -> dd.DataFrame:

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).persist())

def read_delta(
self, source: str | Path, table_name: str | None = None, **kwargs: Any
):
raise NotImplementedError(self.name)
8 changes: 4 additions & 4 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,6 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
def _clean_up_cached_table(self, op):
del self.dictionary[op.name]


class Backend(BasePandasBackend):
name = "pandas"

def to_pyarrow(
self,
expr: ir.Expr,
Expand Down Expand Up @@ -264,6 +260,10 @@ def to_pyarrow_batches(
pa_table.schema, pa_table.to_batches(max_chunksize=chunk_size)
)


class Backend(BasePandasBackend):
name = "pandas"

def execute(self, query, params=None, limit="default", **kwargs):
from ibis.backends.pandas.core import execute_and_reset

Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/tests/test_dataframe_interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
]


@pytest.mark.notimpl(["dask", "druid"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.notimpl(
["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
Expand Down Expand Up @@ -60,7 +60,6 @@ def test_dataframe_interchange_no_execute(con, alltypes, mocker):
assert not to_pyarrow.called


@pytest.mark.notimpl(["dask"])
@pytest.mark.notimpl(
["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
Expand All @@ -80,7 +79,7 @@ def test_dataframe_interchange_dataframe_methods_execute(con, alltypes, mocker):
assert to_pyarrow.call_count == 1


@pytest.mark.notimpl(["dask", "druid"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.notimpl(
["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
Expand Down Expand Up @@ -112,7 +111,6 @@ def test_dataframe_interchange_column_methods_execute(con, alltypes, mocker):
assert col2.size() == pa_col2.size()


@pytest.mark.notimpl(["dask"])
@pytest.mark.notimpl(
["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
Expand Down
33 changes: 13 additions & 20 deletions ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@
]

no_limit = [
param(
None, id="nolimit", marks=[pytest.mark.notimpl(["dask", "impala", "pyspark"])]
)
param(None, id="nolimit", marks=[pytest.mark.notimpl(["impala", "pyspark"])])
]

limit_no_limit = limit + no_limit
Expand Down Expand Up @@ -117,7 +115,7 @@ def test_scalar_to_pyarrow_scalar(limit, awards_players):
}


@pytest.mark.notimpl(["dask", "impala", "pyspark", "druid"])
@pytest.mark.notimpl(["impala", "pyspark", "druid"])
def test_table_to_pyarrow_table_schema(con, awards_players):
table = awards_players.to_pyarrow()
assert isinstance(table, pa.Table)
Expand All @@ -136,7 +134,7 @@ def test_table_to_pyarrow_table_schema(con, awards_players):
assert table.schema == expected_schema


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@pytest.mark.notimpl(["impala", "pyspark"])
def test_column_to_pyarrow_table_schema(awards_players):
expr = awards_players.awardID
array = expr.to_pyarrow()
Expand Down Expand Up @@ -193,15 +191,15 @@ def test_to_pyarrow_batches_borked_types(batting):
util.consume(batch_reader)


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@pytest.mark.notimpl(["impala", "pyspark"])
def test_to_pyarrow_memtable(con):
expr = ibis.memtable({"x": [1, 2, 3]})
table = con.to_pyarrow(expr)
assert isinstance(table, pa.Table)
assert len(table) == 3


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@pytest.mark.notimpl(["impala", "pyspark"])
def test_to_pyarrow_batches_memtable(con):
expr = ibis.memtable({"x": [1, 2, 3]})
n = 0
Expand All @@ -212,7 +210,7 @@ def test_to_pyarrow_batches_memtable(con):
assert n == 3


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@pytest.mark.notimpl(["impala", "pyspark"])
def test_table_to_parquet(tmp_path, backend, awards_players):
outparquet = tmp_path / "out.parquet"
awards_players.to_parquet(outparquet)
Expand Down Expand Up @@ -265,9 +263,7 @@ def test_roundtrip_partitioned_parquet(tmp_path, con, backend, awards_players):
backend.assert_frame_equal(reingest.to_pandas(), awards_players.to_pandas())


@pytest.mark.notimpl(
["dask", "impala", "pyspark"], reason="No support for exporting files"
)
@pytest.mark.notimpl(["impala", "pyspark"], reason="No support for exporting files")
@pytest.mark.parametrize("ftype", ["csv", "parquet"])
def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):
"""
Expand All @@ -288,7 +284,7 @@ def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):
assert outfile.is_file()


@pytest.mark.notimpl(["dask", "impala", "pyspark"])
@pytest.mark.notimpl(["impala", "pyspark"])
def test_table_to_csv(tmp_path, backend, awards_players):
outcsv = tmp_path / "out.csv"

Expand All @@ -314,7 +310,6 @@ def test_table_to_csv(tmp_path, backend, awards_players):
["impala"], raises=AttributeError, reason="fetchmany doesn't exist"
),
pytest.mark.notyet(["druid"], raises=sa.exc.ProgrammingError),
pytest.mark.notyet(["dask"], raises=NotImplementedError),
pytest.mark.notyet(["pyspark"], raises=NotImplementedError),
],
),
Expand All @@ -329,7 +324,6 @@ def test_table_to_csv(tmp_path, backend, awards_players):
["druid", "snowflake", "trino"], raises=sa.exc.ProgrammingError
),
pytest.mark.notyet(["oracle"], raises=sa.exc.DatabaseError),
pytest.mark.notyet(["dask"], raises=NotImplementedError),
pytest.mark.notyet(["mssql", "mysql"], raises=sa.exc.OperationalError),
pytest.mark.notyet(["pyspark"], raises=ParseException),
],
Expand Down Expand Up @@ -390,7 +384,6 @@ def test_roundtrip_delta(con, alltypes, tmp_path, monkeypatch):
@pytest.mark.xfail_version(
duckdb=["duckdb<0.8.1"], raises=AssertionError, reason="bug in duckdb"
)
@pytest.mark.notimpl(["dask"], raises=NotImplementedError)
@pytest.mark.notimpl(
["druid"], raises=AttributeError, reason="string type is used for timestamp_col"
)
Expand Down Expand Up @@ -419,7 +412,7 @@ def test_arrow_timestamp_with_time_zone(alltypes):
assert batch.schema.types == expected


@pytest.mark.notimpl(["dask", "druid"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.notimpl(
["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
Expand Down Expand Up @@ -447,7 +440,7 @@ def test_empty_memtable(backend, con):
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["dask", "flink", "impala", "pyspark"])
@pytest.mark.notimpl(["flink", "impala", "pyspark"])
def test_to_pandas_batches_empty_table(backend, con):
t = backend.functional_alltypes.limit(0)
n = t.count().execute()
Expand All @@ -456,7 +449,7 @@ def test_to_pandas_batches_empty_table(backend, con):
assert sum(map(len, t.to_pandas_batches())) == n


@pytest.mark.notimpl(["dask", "druid", "flink", "impala", "pyspark"])
@pytest.mark.notimpl(["druid", "flink", "impala", "pyspark"])
@pytest.mark.parametrize("n", [None, 1])
def test_to_pandas_batches_nonempty_table(backend, con, n):
t = backend.functional_alltypes.limit(n)
Expand All @@ -466,7 +459,7 @@ def test_to_pandas_batches_nonempty_table(backend, con, n):
assert sum(map(len, t.to_pandas_batches())) == n


@pytest.mark.notimpl(["dask", "flink", "impala", "pyspark"])
@pytest.mark.notimpl(["flink", "impala", "pyspark"])
@pytest.mark.parametrize("n", [None, 0, 1, 2])
def test_to_pandas_batches_column(backend, con, n):
t = backend.functional_alltypes.limit(n).timestamp_col
Expand All @@ -476,7 +469,7 @@ def test_to_pandas_batches_column(backend, con, n):
assert sum(map(len, t.to_pandas_batches())) == n


@pytest.mark.notimpl(["dask", "druid", "flink", "impala", "pyspark"])
@pytest.mark.notimpl(["druid", "flink", "impala", "pyspark"])
def test_to_pandas_batches_scalar(backend, con):
t = backend.functional_alltypes.timestamp_col.max()
expected = t.execute()
Expand Down

0 comments on commit d2c5a1e

Please sign in to comment.