diff --git a/pandas/io/_util.py b/pandas/io/_util.py index 9a8c87a738d4c..21203ad036fc6 100644 --- a/pandas/io/_util.py +++ b/pandas/io/_util.py @@ -1,9 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Literal, +) import numpy as np +from pandas._config import using_string_dtype + +from pandas._libs import lib from pandas.compat import pa_version_under18p0 from pandas.compat._optional import import_optional_dependency @@ -12,6 +18,10 @@ if TYPE_CHECKING: from collections.abc import Callable + import pyarrow + + from pandas._typing import DtypeBackend + def _arrow_dtype_mapping() -> dict: pa = import_optional_dependency("pyarrow") @@ -33,7 +43,7 @@ def _arrow_dtype_mapping() -> dict: } -def arrow_string_types_mapper() -> Callable: +def _arrow_string_types_mapper() -> Callable: pa = import_optional_dependency("pyarrow") mapping = { @@ -44,3 +54,31 @@ def arrow_string_types_mapper() -> Callable: mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan) return mapping.get + + +def arrow_table_to_pandas( + table: pyarrow.Table, + dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default, + null_to_int64: bool = False, +) -> pd.DataFrame: + pa = import_optional_dependency("pyarrow") + + types_mapper: type[pd.ArrowDtype] | None | Callable + if dtype_backend == "numpy_nullable": + mapping = _arrow_dtype_mapping() + if null_to_int64: + # Modify the default mapping to also map null to Int64 + # (to match other engines - only for CSV parser) + mapping[pa.null()] = pd.Int64Dtype() + types_mapper = mapping.get + elif dtype_backend == "pyarrow": + types_mapper = pd.ArrowDtype + elif using_string_dtype(): + types_mapper = _arrow_string_types_mapper() + elif dtype_backend is lib.no_default or dtype_backend == "numpy": + types_mapper = None + else: + raise NotImplementedError + + df = table.to_pandas(types_mapper=types_mapper) + return df diff --git a/pandas/io/feather_format.py b/pandas/io/feather_format.py index aaae9857b4fae..7b4c81853eba3 100644 --- a/pandas/io/feather_format.py +++ b/pandas/io/feather_format.py @@ -15,11 +15,10 @@ from pandas.util._decorators import doc from pandas.util._validators import check_dtype_backend -import pandas as pd from pandas.core.api import DataFrame from pandas.core.shared_docs import _shared_docs -from pandas.io._util import arrow_string_types_mapper +from pandas.io._util import arrow_table_to_pandas from pandas.io.common import get_handle if TYPE_CHECKING: @@ -147,16 +146,4 @@ def read_feather( pa_table = feather.read_table( handles.handle, columns=columns, use_threads=bool(use_threads) ) - - if dtype_backend == "numpy_nullable": - from pandas.io._util import _arrow_dtype_mapping - - return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get) - - elif dtype_backend == "pyarrow": - return pa_table.to_pandas(types_mapper=pd.ArrowDtype) - - elif using_string_dtype(): - return pa_table.to_pandas(types_mapper=arrow_string_types_mapper()) - else: - raise NotImplementedError + return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) diff --git a/pandas/io/json/_json.py b/pandas/io/json/_json.py index e9c9f5ba225a5..983780f81043f 100644 --- a/pandas/io/json/_json.py +++ b/pandas/io/json/_json.py @@ -36,7 +36,6 @@ from pandas.core.dtypes.dtypes import PeriodDtype from pandas import ( - ArrowDtype, DataFrame, Index, MultiIndex, @@ -48,6 +47,7 @@ from pandas.core.reshape.concat import concat from pandas.core.shared_docs import _shared_docs +from pandas.io._util import arrow_table_to_pandas from pandas.io.common import ( IOHandles, dedup_names, @@ -940,18 +940,7 @@ def read(self) -> DataFrame | Series: if self.engine == "pyarrow": pyarrow_json = import_optional_dependency("pyarrow.json") pa_table = pyarrow_json.read_json(self.data) - - mapping: type[ArrowDtype] | None | Callable - if self.dtype_backend == "pyarrow": - mapping = ArrowDtype - elif self.dtype_backend == "numpy_nullable": - from pandas.io._util import _arrow_dtype_mapping - - mapping = _arrow_dtype_mapping().get - else: - mapping = None - - return pa_table.to_pandas(types_mapper=mapping) + return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend) elif self.engine == "ujson": if self.lines: if self.chunksize: diff --git a/pandas/io/orc.py b/pandas/io/orc.py index f179dafc919e5..a945f3dc38d35 100644 --- a/pandas/io/orc.py +++ b/pandas/io/orc.py @@ -9,16 +9,13 @@ Literal, ) -from pandas._config import using_string_dtype - from pandas._libs import lib from pandas.compat._optional import import_optional_dependency from pandas.util._validators import check_dtype_backend -import pandas as pd from pandas.core.indexes.api import default_index -from pandas.io._util import arrow_string_types_mapper +from pandas.io._util import arrow_table_to_pandas from pandas.io.common import ( get_handle, is_fsspec_url, @@ -127,21 +124,7 @@ def read_orc( pa_table = orc.read_table( source=source, columns=columns, filesystem=filesystem, **kwargs ) - if dtype_backend is not lib.no_default: - if dtype_backend == "pyarrow": - df = pa_table.to_pandas(types_mapper=pd.ArrowDtype) - else: - from pandas.io._util import _arrow_dtype_mapping - - mapping = _arrow_dtype_mapping() - df = pa_table.to_pandas(types_mapper=mapping.get) - return df - else: - if using_string_dtype(): - types_mapper = arrow_string_types_mapper() - else: - types_mapper = None - return pa_table.to_pandas(types_mapper=types_mapper) + return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) def to_orc( diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index 24415299e799b..116f228faca93 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -15,22 +15,19 @@ filterwarnings, ) -from pandas._config import using_string_dtype - from pandas._libs import lib from pandas.compat._optional import import_optional_dependency from pandas.errors import AbstractMethodError from pandas.util._decorators import doc from pandas.util._validators import check_dtype_backend -import pandas as pd from pandas import ( DataFrame, get_option, ) from pandas.core.shared_docs import _shared_docs -from pandas.io._util import arrow_string_types_mapper +from pandas.io._util import arrow_table_to_pandas from pandas.io.common import ( IOHandles, get_handle, @@ -249,17 +246,6 @@ def read( ) -> DataFrame: kwargs["use_pandas_metadata"] = True - to_pandas_kwargs = {} - if dtype_backend == "numpy_nullable": - from pandas.io._util import _arrow_dtype_mapping - - mapping = _arrow_dtype_mapping() - to_pandas_kwargs["types_mapper"] = mapping.get - elif dtype_backend == "pyarrow": - to_pandas_kwargs["types_mapper"] = pd.ArrowDtype # type: ignore[assignment] - elif using_string_dtype(): - to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper() - path_or_handle, handles, filesystem = _get_path_or_handle( path, filesystem, @@ -280,7 +266,7 @@ def read( "make_block is deprecated", DeprecationWarning, ) - result = pa_table.to_pandas(**to_pandas_kwargs) + result = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) if pa_table.schema.metadata: if b"PANDAS_ATTRS" in pa_table.schema.metadata: diff --git a/pandas/io/parsers/arrow_parser_wrapper.py b/pandas/io/parsers/arrow_parser_wrapper.py index 86bb5f190e403..672672490996d 100644 --- a/pandas/io/parsers/arrow_parser_wrapper.py +++ b/pandas/io/parsers/arrow_parser_wrapper.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING import warnings -from pandas._config import using_string_dtype - from pandas._libs import lib from pandas.compat._optional import import_optional_dependency from pandas.errors import ( @@ -16,18 +14,14 @@ from pandas.core.dtypes.common import pandas_dtype from pandas.core.dtypes.inference import is_integer -import pandas as pd -from pandas import DataFrame - -from pandas.io._util import ( - _arrow_dtype_mapping, - arrow_string_types_mapper, -) +from pandas.io._util import arrow_table_to_pandas from pandas.io.parsers.base_parser import ParserBase if TYPE_CHECKING: from pandas._typing import ReadBuffer + from pandas import DataFrame + class ArrowParserWrapper(ParserBase): """ @@ -293,17 +287,8 @@ def read(self) -> DataFrame: "make_block is deprecated", DeprecationWarning, ) - if dtype_backend == "pyarrow": - frame = table.to_pandas(types_mapper=pd.ArrowDtype) - elif dtype_backend == "numpy_nullable": - # Modify the default mapping to also - # map null to Int64 (to match other engines) - dtype_mapping = _arrow_dtype_mapping() - dtype_mapping[pa.null()] = pd.Int64Dtype() - frame = table.to_pandas(types_mapper=dtype_mapping.get) - elif using_string_dtype(): - frame = table.to_pandas(types_mapper=arrow_string_types_mapper()) + frame = arrow_table_to_pandas( + table, dtype_backend=dtype_backend, null_to_int64=True + ) - else: - frame = table.to_pandas() return self._finalize_pandas_output(frame) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 125ca51a456d8..3c0c5cc64c24c 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -48,10 +48,7 @@ is_object_dtype, is_string_dtype, ) -from pandas.core.dtypes.dtypes import ( - ArrowDtype, - DatetimeTZDtype, -) +from pandas.core.dtypes.dtypes import DatetimeTZDtype from pandas.core.dtypes.missing import isna from pandas import get_option @@ -67,6 +64,8 @@ from pandas.core.internals.construction import convert_object_array from pandas.core.tools.datetimes import to_datetime +from pandas.io._util import arrow_table_to_pandas + if TYPE_CHECKING: from collections.abc import ( Callable, @@ -2208,23 +2207,10 @@ def read_table( else: stmt = f"SELECT {select_list} FROM {table_name}" - mapping: type[ArrowDtype] | None | Callable - if dtype_backend == "pyarrow": - mapping = ArrowDtype - elif dtype_backend == "numpy_nullable": - from pandas.io._util import _arrow_dtype_mapping - - mapping = _arrow_dtype_mapping().get - elif using_string_dtype(): - from pandas.io._util import arrow_string_types_mapper - - mapping = arrow_string_types_mapper() - else: - mapping = None - with self.con.cursor() as cur: cur.execute(stmt) - df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping) + pa_table = cur.fetch_arrow_table() + df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) return _wrap_result_adbc( df, @@ -2292,23 +2278,10 @@ def read_query( if chunksize: raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") - mapping: type[ArrowDtype] | None | Callable - if dtype_backend == "pyarrow": - mapping = ArrowDtype - elif dtype_backend == "numpy_nullable": - from pandas.io._util import _arrow_dtype_mapping - - mapping = _arrow_dtype_mapping().get - elif using_string_dtype(): - from pandas.io._util import arrow_string_types_mapper - - mapping = arrow_string_types_mapper() - else: - mapping = None - with self.con.cursor() as cur: cur.execute(sql) - df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping) + pa_table = cur.fetch_arrow_table() + df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) return _wrap_result_adbc( df, diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 96d63d3fe25e5..7e1220ecee218 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data): adbc_connectable_iris = [ pytest.param("postgresql_adbc_iris", marks=pytest.mark.db), - pytest.param("sqlite_adbc_iris", marks=pytest.mark.db), + "sqlite_adbc_iris", ] adbc_connectable_types = [ pytest.param("postgresql_adbc_types", marks=pytest.mark.db), - pytest.param("sqlite_adbc_types", marks=pytest.mark.db), + "sqlite_adbc_types", ]