Increase minimum supported pyarrow to 7.0 (dask#10024)
jrbourbeau authored Mar 8, 2023
1 parent 809804c commit c36fe08
Showing 11 changed files with 50 additions and 142 deletions.
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.10.yaml
@@ -32,7 +32,7 @@ dependencies:
   - fsspec
   # Pin until sqlalchemy 2 support is added https://github.com/dask/dask/issues/9896
   - sqlalchemy>=1.4.0,<2
-  - pyarrow
+  - pyarrow=10
   - coverage
   - jsonschema
   # other -- IO
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.11.yaml
@@ -31,7 +31,7 @@ dependencies:
   - fsspec
   # Pin until sqlalchemy 2 support is added https://github.com/dask/dask/issues/9896
   - sqlalchemy>=1.4.0,<2
-  - pyarrow>=10
+  - pyarrow>=11
   - coverage
   - jsonschema
   # # other -- IO
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.8.yaml
@@ -32,7 +32,7 @@ dependencies:
   - fsspec
   # Pin until sqlalchemy 2 support is added https://github.com/dask/dask/issues/9896
   - sqlalchemy>=1.4.0,<2
-  - pyarrow=4.0
+  - pyarrow=7
   - coverage
   - jsonschema
   # other -- IO
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
@@ -32,7 +32,7 @@ dependencies:
   - fsspec
   # Pin until sqlalchemy 2 support is added https://github.com/dask/dask/issues/9896
   - sqlalchemy>=1.4.0,<2
-  - pyarrow
+  - pyarrow=9
   - coverage
   - jsonschema
   # other -- IO
16 changes: 1 addition & 15 deletions dask/bytes/tests/test_s3.py
@@ -8,7 +8,6 @@
from functools import partial

import pytest
from packaging.version import parse as parse_version

s3fs = pytest.importorskip("s3fs")
boto3 = pytest.importorskip("boto3")
@@ -441,25 +440,12 @@ def test_modification_time_read_bytes(s3, s3so):
@pytest.mark.parametrize("engine", ["pyarrow", "fastparquet"])
@pytest.mark.parametrize("metadata_file", [True, False])
def test_parquet(s3, engine, s3so, metadata_file):
import s3fs

dd = pytest.importorskip("dask.dataframe")
pd = pytest.importorskip("pandas")
np = pytest.importorskip("numpy")

lib = pytest.importorskip(engine)
lib_version = parse_version(lib.__version__)
if engine == "pyarrow" and lib_version < parse_version("0.13.1"):
pytest.skip("pyarrow < 0.13.1 not supported for parquet")
if (
engine == "pyarrow"
and lib_version.major == 2
and parse_version(s3fs.__version__) > parse_version("0.5.0")
):
pytest.skip("#7056 - new s3fs not supported before pyarrow 3.0")
pytest.importorskip(engine)

url = "s3://%s/test.parquet" % test_bucket_name

data = pd.DataFrame(
{
"i32": np.arange(1000, dtype=np.int32),
6 changes: 0 additions & 6 deletions dask/dataframe/io/orc/core.py
@@ -2,7 +2,6 @@

from fsspec.core import get_fs_token_paths
from fsspec.utils import stringify_path
from packaging.version import parse as parse_version

from dask.base import compute_as_if_collection, tokenize
from dask.dataframe.backends import dataframe_creation_dispatch
@@ -57,13 +56,8 @@ def __call__(self, parts):
def _get_engine(engine, write=False):
# Get engine
if engine == "pyarrow":
import pyarrow as pa

from dask.dataframe.io.orc.arrow import ArrowORCEngine

if write and parse_version(pa.__version__) < parse_version("4.0.0"):
raise ValueError("to_orc is not supported for pyarrow<4.0.0")

return ArrowORCEngine
elif not isinstance(engine, ORCEngine):
raise TypeError("engine must be 'pyarrow', or an ORCEngine object")
99 changes: 35 additions & 64 deletions dask/dataframe/io/parquet/arrow.py
@@ -9,7 +9,12 @@
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from packaging.version import parse as parse_version

# Check PyArrow version for feature support
from fsspec.core import expand_paths_if_needed, stringify_path
from fsspec.implementations.arrow import ArrowFSWrapper
from pyarrow import dataset as pa_ds
from pyarrow import fs as pa_fs

from dask import config
from dask.base import tokenize
@@ -31,19 +36,6 @@
from dask.delayed import Delayed
from dask.utils import getargspec, natural_sort_key

# Check PyArrow version for feature support
_pa_version = parse_version(pa.__version__)
from fsspec.core import expand_paths_if_needed, stringify_path
from fsspec.implementations.arrow import ArrowFSWrapper
from pyarrow import dataset as pa_ds
from pyarrow import fs as pa_fs

subset_stats_supported = _pa_version > parse_version("2.0.0")
pre_buffer_supported = _pa_version >= parse_version("5.0.0")
partitioning_supported = _pa_version >= parse_version("5.0.0")
nan_is_null_supported = _pa_version >= parse_version("6.0.0")
del _pa_version

PYARROW_NULLABLE_DTYPE_MAPPING = {
pa.int8(): pd.Int8Dtype(),
pa.int16(): pd.Int16Dtype(),
@@ -246,11 +238,7 @@ def _read_table_from_path(
# "pre-caching" method isn't already specified in `precache_options`
# (The distinct fsspec and pyarrow optimizations will conflict)
pre_buffer_default = precache_options.get("method", None) is None
pre_buffer = (
{"pre_buffer": read_kwargs.pop("pre_buffer", pre_buffer_default)}
if pre_buffer_supported
else {}
)
pre_buffer = {"pre_buffer": read_kwargs.pop("pre_buffer", pre_buffer_default)}

with _open_input_files(
[path],
@@ -285,37 +273,31 @@ def _get_rg_statistics(row_group, col_names):
statistics for all columns.
"""

if subset_stats_supported:
row_group_schema = {
col_name: i for i, col_name in enumerate(row_group.schema.names)
}

def name_stats(column_name):
col = row_group.metadata.column(row_group_schema[column_name])
row_group_schema = {
col_name: i for i, col_name in enumerate(row_group.schema.names)
}

stats = col.statistics
if stats is None or not stats.has_min_max:
return None, None
def name_stats(column_name):
col = row_group.metadata.column(row_group_schema[column_name])

name = col.path_in_schema
field_index = row_group.schema.get_field_index(name)
if field_index < 0:
return None, None
stats = col.statistics
if stats is None or not stats.has_min_max:
return None, None

return col.path_in_schema, {
"min": stats.min,
"max": stats.max,
"null_count": stats.null_count,
}
name = col.path_in_schema
field_index = row_group.schema.get_field_index(name)
if field_index < 0:
return None, None

return {
name: stats
for name, stats in map(name_stats, col_names)
if stats is not None
return col.path_in_schema, {
"min": stats.min,
"max": stats.max,
"null_count": stats.null_count,
}

else:
return row_group.statistics
return {
name: stats for name, stats in map(name_stats, col_names) if stats is not None
}


def _need_fragments(filters, partition_keys):
@@ -341,7 +323,6 @@ def _filters_to_expression(filters, propagate_null=False, nan_is_null=True):
# handling is resolved.
# See: https://github.com/dask/dask/issues/9845

nan_kwargs = dict(nan_is_null=nan_is_null) if nan_is_null_supported else {}
if isinstance(filters, pa_ds.Expression):
return filters

@@ -360,9 +341,9 @@ def convert_single_predicate(col, op, val):
# Handle null-value comparison
if val is None or (nan_is_null and val is np.nan):
if op == "is":
return field.is_null(**nan_kwargs)
return field.is_null(nan_is_null=nan_is_null)
elif op == "is not":
return ~field.is_null(**nan_kwargs)
return ~field.is_null(nan_is_null=nan_is_null)
else:
raise ValueError(
f'"{(col, op, val)}" is not a supported predicate '
@@ -392,7 +373,7 @@ def convert_single_predicate(col, op, val):

# (Optionally) Avoid null-value propagation
if not propagate_null and op in ("!=", "not in"):
return field.is_null(**nan_kwargs) | expr
return field.is_null(nan_is_null=nan_is_null) | expr
return expr

disjunction_members = []
@@ -1022,18 +1003,10 @@ def _collect_dataset_info(
# Get all partition keys (without filters) to populate partition_obj
partition_obj = [] # See `partition_info` description below
hive_categories = defaultdict(list)
file_frag = None
for file_frag in ds.get_fragments():
if partitioning_supported:
# Can avoid manual category discovery for pyarrow>=5.0.0
break
keys = pa_ds._get_partition_keys(file_frag.partition_expression)
if not (keys or hive_categories):
break # Bail - This is not a hive-partitioned dataset
for k, v in keys.items():
if v not in hive_categories[k]:
hive_categories[k].append(v)

try:
file_frag = next(ds.get_fragments())
except StopIteration:
file_frag = None
physical_schema = ds.schema
if file_frag is not None:
physical_schema = file_frag.physical_schema
@@ -1068,10 +1041,8 @@ def _collect_dataset_info(
k: hive_categories[k] for k in cat_keys if k in hive_categories
}

if (
partitioning_supported
and ds.partitioning.dictionaries
and all(arr is not None for arr in ds.partitioning.dictionaries)
if ds.partitioning.dictionaries and all(
arr is not None for arr in ds.partitioning.dictionaries
):
# Use ds.partitioning for pyarrow>=5.0.0
partition_names = list(ds.partitioning.schema.names)
10 changes: 1 addition & 9 deletions dask/dataframe/io/parquet/core.py
@@ -8,7 +8,6 @@
import tlz as toolz
from fsspec.core import get_fs_token_paths
from fsspec.utils import stringify_path
from packaging.version import parse as parse_version

import dask
from dask.base import tokenize
@@ -1239,19 +1238,12 @@ def get_engine(engine):
return eng

elif engine in ("pyarrow", "arrow", "pyarrow-dataset"):
pa = import_required("pyarrow", "`pyarrow` not installed")
pa_version = parse_version(pa.__version__)
import_required("pyarrow", "`pyarrow` not installed")

if engine in ("pyarrow-dataset", "arrow"):
engine = "pyarrow"

if engine == "pyarrow":
if pa_version.major < 1:
raise ImportError(
f"pyarrow-{pa_version.major} does not support the "
f"pyarrow.dataset API. Please install pyarrow>=1."
)

from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine

_ENGINES[engine] = eng = ArrowDatasetEngine
17 changes: 0 additions & 17 deletions dask/dataframe/io/tests/test_orc.py
@@ -6,7 +6,6 @@
import numpy as np
import pandas as pd
import pytest
from packaging.version import parse as parse_version

import dask.dataframe as dd
from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -78,10 +77,6 @@ def test_orc_multiple(orc_files):
assert_eq(d2[columns], dd.concat([d, d])[columns], check_index=False)


@pytest.mark.skipif(
parse_version(pa.__version__) < parse_version("4.0.0"),
reason=("PyArrow>=4.0.0 required for ORC write support."),
)
@pytest.mark.parametrize("index", [None, "i32"])
@pytest.mark.parametrize("columns", [None, ["i32", "i64", "f"]])
def test_orc_roundtrip(tmpdir, index, columns):
@@ -110,10 +105,6 @@ def test_orc_roundtrip(tmpdir, index, columns):
assert_eq(data, df2, check_index=bool(index))


@pytest.mark.skipif(
parse_version(pa.__version__) < parse_version("4.0.0"),
reason=("PyArrow>=4.0.0 required for ORC write support."),
)
@pytest.mark.parametrize("split_stripes", [True, False, 2, 4])
def test_orc_roundtrip_aggregate_files(tmpdir, split_stripes):
tmp = str(tmpdir)
@@ -147,10 +138,6 @@ def test_orc_aggregate_files_offset(orc_files):
assert len(df2.partitions[0].index) > len(df2.index) // 2


@pytest.mark.skipif(
parse_version(pa.__version__) < parse_version("4.0.0"),
reason=("PyArrow>=4.0.0 required for ORC write support."),
)
@pytest.mark.network
def test_orc_names(orc_files, tmp_path):
df = dd.read_orc(orc_files)
@@ -159,10 +146,6 @@ def test_orc_names(orc_files, tmp_path):
assert out._name.startswith("to-orc")


@pytest.mark.skipif(
parse_version(pa.__version__) < parse_version("4.0.0"),
reason=("PyArrow>=4.0.0 required for ORC write support."),
)
def test_to_orc_delayed(tmp_path):
# See: https://github.com/dask/dask/issues/8022
df = pd.DataFrame(np.random.randn(100, 4), columns=["a", "b", "c", "d"])