
REF: centralize pyarrow Table to pandas conversions and types_mapper handling #60324

Merged
42 changes: 40 additions & 2 deletions pandas/io/_util.py
@@ -1,9 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Literal,
)

import numpy as np

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat import pa_version_under18p0
from pandas.compat._optional import import_optional_dependency

@@ -12,6 +18,10 @@
if TYPE_CHECKING:
from collections.abc import Callable

import pyarrow

from pandas._typing import DtypeBackend


def _arrow_dtype_mapping() -> dict:
pa = import_optional_dependency("pyarrow")
@@ -33,7 +43,7 @@ def _arrow_dtype_mapping() -> dict:
}


def arrow_string_types_mapper() -> Callable:
def _arrow_string_types_mapper() -> Callable:
pa = import_optional_dependency("pyarrow")

mapping = {
@@ -44,3 +54,31 @@ def arrow_string_types_mapper() -> Callable:
mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)

return mapping.get


def arrow_table_to_pandas(
table: pyarrow.Table,
dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
null_to_int64: bool = False,
) -> pd.DataFrame:
pa = import_optional_dependency("pyarrow")

types_mapper: type[pd.ArrowDtype] | None | Callable
if dtype_backend == "numpy_nullable":
mapping = _arrow_dtype_mapping()
if null_to_int64:
Member:

I realize this is for compatibility, but is it a feature or a bug that the CSV reader does this?

Member Author:

No idea; I didn't look in detail at why this happens in the CSV reader, and for now I just wanted to keep the same behaviour. (This is admittedly an ugly keyword, but that is the trade-off of centralizing the conversion as I am doing here: it is not otherwise possible to change the behaviour from the CSV side.)
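
For context, a minimal sketch of the behaviour being preserved (the inline CSV data is made up for illustration, and the dtype expectations assume a recent pandas):

```python
from io import StringIO

import pandas as pd

# Column "b" is entirely empty, so there are no values to infer a type from
data = "a,b\n1,\n2,\n"

# Default C engine: under the nullable backend the empty column comes back
# as Int64
c_df = pd.read_csv(StringIO(data), dtype_backend="numpy_nullable")
print(c_df["b"].dtype)  # Int64

# null_to_int64=True keeps the pyarrow engine consistent with that by
# mapping pa.null() to pd.Int64Dtype()
pa_df = pd.read_csv(StringIO(data), engine="pyarrow", dtype_backend="numpy_nullable")
print(pa_df["b"].dtype)  # Int64
```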

# Modify the default mapping to also map null to Int64
# (to match other engines - only for CSV parser)
mapping[pa.null()] = pd.Int64Dtype()
types_mapper = mapping.get
elif dtype_backend == "pyarrow":
types_mapper = pd.ArrowDtype
elif using_string_dtype():
types_mapper = _arrow_string_types_mapper()
elif dtype_backend is lib.no_default or dtype_backend == "numpy":
types_mapper = None
else:
raise NotImplementedError

df = table.to_pandas(types_mapper=types_mapper)
return df
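
Taken together, this gives every pyarrow-backed reader a single dispatch point. A minimal usage sketch of the new helper (the table contents are made up for illustration):

```python
import pyarrow as pa

from pandas.io._util import arrow_table_to_pandas

table = pa.table({"a": [1, 2, None], "b": pa.nulls(3)})

# No backend specified: numpy-backed result (pyarrow strings map to the
# pandas string dtype when that option is enabled)
df_default = arrow_table_to_pandas(table)

# Nullable extension dtypes; null_to_int64=True additionally maps
# pa.null() to Int64, which only the CSV parser asks for
df_nullable = arrow_table_to_pandas(
    table, dtype_backend="numpy_nullable", null_to_int64=True
)
print(df_nullable.dtypes)  # a: Int64, b: Int64

# ArrowDtype-backed columns
df_arrow = arrow_table_to_pandas(table, dtype_backend="pyarrow")
```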
17 changes: 2 additions & 15 deletions pandas/io/feather_format.py
@@ -15,11 +15,10 @@
from pandas.util._decorators import doc
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.api import DataFrame
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import get_handle

if TYPE_CHECKING:
@@ -147,16 +146,4 @@ def read_feather(
pa_table = feather.read_table(
handles.handle, columns=columns, use_threads=bool(use_threads)
)

if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)

elif dtype_backend == "pyarrow":
return pa_table.to_pandas(types_mapper=pd.ArrowDtype)

elif using_string_dtype():
return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
else:
raise NotImplementedError
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
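
Each call site now collapses to a single line, with no user-facing change; the JSON, ORC, and Parquet readers below follow the same pattern. A quick round-trip sketch (the file name is hypothetical, and the printed dtypes are expectations rather than verified output):

```python
import pandas as pd

df = pd.DataFrame({"x": [1.5, None], "y": ["a", "b"]})
df.to_feather("example.feather")  # hypothetical file name

# Internally this now routes through arrow_table_to_pandas
back = pd.read_feather("example.feather", dtype_backend="numpy_nullable")
print(back.dtypes)  # x: Float64, y: string
```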
15 changes: 2 additions & 13 deletions pandas/io/json/_json.py
@@ -36,7 +36,6 @@
from pandas.core.dtypes.dtypes import PeriodDtype

from pandas import (
ArrowDtype,
DataFrame,
Index,
MultiIndex,
@@ -48,6 +47,7 @@
from pandas.core.reshape.concat import concat
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
dedup_names,
@@ -940,18 +940,7 @@ def read(self) -> DataFrame | Series:
if self.engine == "pyarrow":
pyarrow_json = import_optional_dependency("pyarrow.json")
pa_table = pyarrow_json.read_json(self.data)

mapping: type[ArrowDtype] | None | Callable
if self.dtype_backend == "pyarrow":
mapping = ArrowDtype
elif self.dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
else:
mapping = None

return pa_table.to_pandas(types_mapper=mapping)
return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
elif self.engine == "ujson":
if self.lines:
if self.chunksize:
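The pyarrow JSON engine gets the same one-liner; a small sketch (pyarrow's JSON reader wants binary input and requires lines=True):

```python
from io import BytesIO

import pandas as pd

data = BytesIO(b'{"a": 1}\n{"a": 2}\n')

df = pd.read_json(data, lines=True, engine="pyarrow", dtype_backend="numpy_nullable")
print(df["a"].dtype)  # Int64
```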
21 changes: 2 additions & 19 deletions pandas/io/orc.py
@@ -9,16 +9,13 @@
Literal,
)

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas.core.indexes.api import default_index

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
get_handle,
is_fsspec_url,
@@ -127,21 +124,7 @@ def read_orc(
pa_table = orc.read_table(
source=source, columns=columns, filesystem=filesystem, **kwargs
)
if dtype_backend is not lib.no_default:
if dtype_backend == "pyarrow":
df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
else:
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
df = pa_table.to_pandas(types_mapper=mapping.get)
return df
else:
if using_string_dtype():
types_mapper = arrow_string_types_mapper()
else:
types_mapper = None
return pa_table.to_pandas(types_mapper=types_mapper)
return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)


def to_orc(
18 changes: 2 additions & 16 deletions pandas/io/parquet.py
@@ -15,22 +15,19 @@
filterwarnings,
)

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import AbstractMethodError
from pandas.util._decorators import doc
from pandas.util._validators import check_dtype_backend

import pandas as pd
from pandas import (
DataFrame,
get_option,
)
from pandas.core.shared_docs import _shared_docs

from pandas.io._util import arrow_string_types_mapper
from pandas.io._util import arrow_table_to_pandas
from pandas.io.common import (
IOHandles,
get_handle,
@@ -249,17 +246,6 @@ def read(
) -> DataFrame:
kwargs["use_pandas_metadata"] = True

to_pandas_kwargs = {}
if dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping()
to_pandas_kwargs["types_mapper"] = mapping.get
elif dtype_backend == "pyarrow":
to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment]
elif using_string_dtype():
to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()

path_or_handle, handles, filesystem = _get_path_or_handle(
path,
filesystem,
@@ -280,7 +266,7 @@
"make_block is deprecated",
DeprecationWarning,
)
result = pa_table.to_pandas(**to_pandas_kwargs)
result = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

if pa_table.schema.metadata:
if b"PANDAS_ATTRS" in pa_table.schema.metadata:
27 changes: 6 additions & 21 deletions pandas/io/parsers/arrow_parser_wrapper.py
@@ -3,8 +3,6 @@
from typing import TYPE_CHECKING
import warnings

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import (
@@ -16,18 +14,14 @@
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.inference import is_integer

import pandas as pd
from pandas import DataFrame

from pandas.io._util import (
_arrow_dtype_mapping,
arrow_string_types_mapper,
)
from pandas.io._util import arrow_table_to_pandas
from pandas.io.parsers.base_parser import ParserBase

if TYPE_CHECKING:
from pandas._typing import ReadBuffer

from pandas import DataFrame


class ArrowParserWrapper(ParserBase):
"""
@@ -293,17 +287,8 @@ def read(self) -> DataFrame:
"make_block is deprecated",
DeprecationWarning,
)
if dtype_backend == "pyarrow":
frame = table.to_pandas(types_mapper=pd.ArrowDtype)
elif dtype_backend == "numpy_nullable":
# Modify the default mapping to also
# map null to Int64 (to match other engines)
dtype_mapping = _arrow_dtype_mapping()
dtype_mapping[pa.null()] = pd.Int64Dtype()
frame = table.to_pandas(types_mapper=dtype_mapping.get)
elif using_string_dtype():
frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
frame = arrow_table_to_pandas(
table, dtype_backend=dtype_backend, null_to_int64=True
)

else:
frame = table.to_pandas()
return self._finalize_pandas_output(frame)
41 changes: 7 additions & 34 deletions pandas/io/sql.py
@@ -48,10 +48,7 @@
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
DatetimeTZDtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna

from pandas import get_option
@@ -67,6 +64,8 @@
from pandas.core.internals.construction import convert_object_array
from pandas.core.tools.datetimes import to_datetime

from pandas.io._util import arrow_table_to_pandas

if TYPE_CHECKING:
from collections.abc import (
Callable,
@@ -2208,23 +2207,10 @@ def read_table(
else:
stmt = f"SELECT {select_list} FROM {table_name}"

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(stmt)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
@@ -2292,23 +2278,10 @@ def read_query(
if chunksize:
raise NotImplementedError("'chunksize' is not implemented for ADBC drivers")

mapping: type[ArrowDtype] | None | Callable
if dtype_backend == "pyarrow":
mapping = ArrowDtype
elif dtype_backend == "numpy_nullable":
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

with self.con.cursor() as cur:
cur.execute(sql)
df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
pa_table = cur.fetch_arrow_table()
df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)

return _wrap_result_adbc(
df,
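Both ADBC paths (read_table and read_query) now feed the fetched Arrow table through the same helper. A sketch of the user-facing call, assuming the adbc-driver-sqlite package is installed (the printed dtypes are expectations, not verified output):

```python
import adbc_driver_sqlite.dbapi as sqlite_dbapi

import pandas as pd

with sqlite_dbapi.connect(":memory:") as conn:
    df = pd.read_sql_query("SELECT 1 AS a, 'x' AS b", conn, dtype_backend="pyarrow")

print(df.dtypes)  # a: int64[pyarrow], b: string[pyarrow]
```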
4 changes: 2 additions & 2 deletions pandas/tests/io/test_sql.py
@@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data):

adbc_connectable_iris = [
pytest.param("postgresql_adbc_iris", marks=pytest.mark.db),
pytest.param("sqlite_adbc_iris", marks=pytest.mark.db),
"sqlite_adbc_iris",
]

adbc_connectable_types = [
pytest.param("postgresql_adbc_types", marks=pytest.mark.db),
pytest.param("sqlite_adbc_types", marks=pytest.mark.db),
"sqlite_adbc_types",
]

