Skip to content

Commit

Permalink
REF: centralize pyarrow Table to pandas conversions and types_mapper …
Browse files Browse the repository at this point in the history
…handling (#60324)
  • Loading branch information
jorisvandenbossche authored Nov 15, 2024
1 parent ee3c18f commit 12d6f60
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 122 deletions.
42 changes: 40 additions & 2 deletions pandas/io/_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Literal,
)

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat import pa_version_under18p0
from pandas.compat._optional import import_optional_dependency

Expand All @@ -12,6 +18,10 @@
if TYPE_CHECKING:
from collections.abc import Callable

import pyarrow

from pandas._typing import DtypeBackend


def _arrow_dtype_mapping() -> dict:
pa = import_optional_dependency("pyarrow")
Expand All @@ -33,7 +43,7 @@ def _arrow_dtype_mapping() -> dict:
}


def arrow_string_types_mapper() -> Callable:
def _arrow_string_types_mapper() -> Callable:
pa = import_optional_dependency("pyarrow")

mapping = {
Expand All @@ -44,3 +54,31 @@ def arrow_string_types_mapper() -> Callable:
mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)

return mapping.get


def arrow_table_to_pandas(
    table: pyarrow.Table,
    dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
    null_to_int64: bool = False,
) -> pd.DataFrame:
    """
    Convert a pyarrow Table to a pandas DataFrame, honoring ``dtype_backend``.

    Centralizes the ``types_mapper`` selection shared by the pyarrow-backed
    IO readers (feather, parquet, orc, json, csv, sql).

    Parameters
    ----------
    table : pyarrow.Table
        Table to convert.
    dtype_backend : {"numpy_nullable", "pyarrow", "numpy"} or lib.no_default
        Backend for the resulting dtypes. With the default (``no_default``)
        or ``"numpy"``, no explicit mapper is passed unless the global
        string-dtype option is enabled, in which case pyarrow string types
        map to the pandas string dtype.
    null_to_int64 : bool, default False
        If True and ``dtype_backend="numpy_nullable"``, additionally map the
        pyarrow null type to Int64 (to match other engines - only used by
        the CSV parser).

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    NotImplementedError
        If ``dtype_backend`` is not one of the supported values.
    """
    pa = import_optional_dependency("pyarrow")

    types_mapper: type[pd.ArrowDtype] | None | Callable
    if dtype_backend == "numpy_nullable":
        mapping = _arrow_dtype_mapping()
        if null_to_int64:
            # Modify the default mapping to also map null to Int64
            # (to match other engines - only for CSV parser)
            mapping[pa.null()] = pd.Int64Dtype()
        types_mapper = mapping.get
    elif dtype_backend == "pyarrow":
        types_mapper = pd.ArrowDtype
    elif using_string_dtype():
        types_mapper = _arrow_string_types_mapper()
    elif dtype_backend is lib.no_default or dtype_backend == "numpy":
        types_mapper = None
    else:
        # Invalid values should normally be rejected earlier by
        # check_dtype_backend; raise with the offending value for clarity.
        raise NotImplementedError(
            f"dtype_backend {dtype_backend!r} is not supported"
        )

    return table.to_pandas(types_mapper=types_mapper)
17 changes: 2 additions & 15 deletions pandas/io/feather_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
from pandas.util._decorators import doc
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.api import DataFrame
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import get_handle

if TYPE_CHECKING:
Expand Down Expand Up @@ -147,16 +146,4 @@ def read_feather(
pa_table = feather.read_table(
handles.handle, columns=columns, use_threads=bool(use_threads)
)

if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)

elif dtype_backend == "pyarrow":
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)

elif using_string_dtype():
return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
else:
raise NotImplementedError
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
15 changes: 2 additions & 13 deletions pandas/io/json/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from pandas.core.dtypes.dtypes import PeriodDtype

from pandas import (
ArrowDtype,
DataFrame,
Index,
MultiIndex,
Expand All @@ -48,6 +47,7 @@
from pandas.core.reshape.concat import concat
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
dedup_names,
Expand Down Expand Up @@ -940,18 +940,7 @@ def read(self) -> DataFrame | Series:
if self.engine == "pyarrow":
pyarrow_json = import_optional_dependency("pyarrow.json")
pa_table = pyarrow_json.read_json(self.data)

mapping: type[ArrowDtype] | None | Callable
if self.dtype_backend == "pyarrow":
mapping = ArrowDtype
elif self.dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
else:
mapping = None

return pa_table.to_pandas(types_mapper=mapping)
return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
elif self.engine == "ujson":
if self.lines:
if self.chunksize:
Expand Down
21 changes: 2 additions & 19 deletions pandas/io/orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
Literal,
)

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.indexes.api import default_index

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
get_handle,
is_fsspec_url,
Expand Down Expand Up @@ -127,21 +124,7 @@ def read_orc(
pa_table = orc.read_table(
source=source, columns=columns, filesystem=filesystem, **kwargs
)
if dtype_backend is not lib.no_default:
if dtype_backend == "pyarrow":
df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
else:
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
df = pa_table.to_pandas(types_mapper=mapping.get)
return df
else:
if using_string_dtype():
types_mapper = arrow_string_types_mapper()
else:
types_mapper = None
return pa_table.to_pandas(types_mapper=types_mapper)
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)


def to_orc(
Expand Down
18 changes: 2 additions & 16 deletions pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,19 @@
filterwarnings,
)

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import AbstractMethodError
from pandas.util._decorators import doc
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas import (
DataFrame,
get_option,
)
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
get_handle,
Expand Down Expand Up @@ -249,17 +246,6 @@ def read(
) -> DataFrame:
kwargs["use_pandas_metadata"] = True

to_pandas_kwargs = {}
if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
to_pandas_kwargs["types_mapper"] = mapping.get
elif dtype_backend == "pyarrow":
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment]
elif using_string_dtype():
to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()

path_or_handle, handles, filesystem = _get_path_or_handle(
path,
filesystem,
Expand All @@ -280,7 +266,7 @@ def read(
"make_block is deprecated",
DeprecationWarning,
)
result = pa_table.to_pandas(**to_pandas_kwargs)
result = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

if pa_table.schema.metadata:
if b"PANDAS_ATTRS" in pa_table.schema.metadata:
Expand Down
27 changes: 6 additions & 21 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import TYPE_CHECKING
import warnings

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import (
Expand All @@ -16,18 +14,14 @@
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.inference import is_integer

import pandas as pd
from pandas import DataFrame

from pandas.io._util import (
_arrow_dtype_mapping,
arrow_string_types_mapper,
)
from pandas.io._util import arrow_table_to_pandas
from pandas.io.parsers.base_parser import ParserBase

if TYPE_CHECKING:
from pandas._typing import ReadBuffer

from pandas import DataFrame


class ArrowParserWrapper(ParserBase):
"""
Expand Down Expand Up @@ -293,17 +287,8 @@ def read(self) -> DataFrame:
"make_block is deprecated",
DeprecationWarning,
)
if dtype_backend == "pyarrow":
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
elif dtype_backend == "numpy_nullable":
# Modify the default mapping to also
# map null to Int64 (to match other engines)
dtype_mapping = _arrow_dtype_mapping()
dtype_mapping[pa.null()] = pd.Int64Dtype()
frame = table.to_pandas(types_mapper=dtype_mapping.get)
elif using_string_dtype():
frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
frame = arrow_table_to_pandas(
table, dtype_backend=dtype_backend, null_to_int64=True
)

else:
frame = table.to_pandas()
return self._finalize_pandas_output(frame)
41 changes: 7 additions & 34 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
DatetimeTZDtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna

from pandas import get_option
Expand All @@ -67,6 +64,8 @@
from pandas.core.internals.construction import convert_object_array
from pandas.core.tools.datetimes import to_datetime

from pandas.io._util import arrow_table_to_pandas

if TYPE_CHECKING:
from collections.abc import (
Callable,
Expand Down Expand Up @@ -2208,23 +2207,10 @@ def read_table(
else:
stmt = f"SELECT {select_list} FROM {table_name}"

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(stmt)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
Expand Down Expand Up @@ -2292,23 +2278,10 @@ def read_query(
if chunksize:
raise NotImplementedError("'chunksize' is not implemented for ADBC drivers")

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(sql)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data):

adbc_connectable_iris = [
pytest.param("postgresql_adbc_iris", marks=pytest.mark.db),
pytest.param("sqlite_adbc_iris", marks=pytest.mark.db),
"sqlite_adbc_iris",
]

adbc_connectable_types = [
pytest.param("postgresql_adbc_types", marks=pytest.mark.db),
pytest.param("sqlite_adbc_types", marks=pytest.mark.db),
"sqlite_adbc_types",
]


Expand Down

0 comments on commit 12d6f60

Please sign in to comment.