Test round-tripping dataframe parquet I/O including pyspark (dask#9156)
Ian Rose authored Jun 8, 2022
1 parent 6369cdb commit 22915dc
Showing 7 changed files with 134 additions and 8 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/tests.yml
@@ -36,7 +36,12 @@ jobs:
uses: actions/checkout@v2
with:
fetch-depth: 0 # Needed by codecov.io

- name: Setup Java
uses: actions/setup-java@v3
if: ${{ matrix.os == 'ubuntu-latest' }}
with:
distribution: "zulu"
java-version: "11"
- name: Setup Conda Environment
uses: conda-incubator/setup-miniconda@v2
with:
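pyspark needs a JVM on the runner, which is why the workflow step above installs Zulu Java 11 on ubuntu-latest before the conda environment is built. A quick smoke test that the JVM is visible to pyspark (a standalone sketch, not part of this commit; assumes pyspark is installed and java is on the PATH):

import pyspark

# Starting a local session fails fast if no compatible JVM can be found.
spark = (
    pyspark.sql.SparkSession.builder.master("local[1]")
    .appName("jvm-smoke-test")
    .getOrCreate()
)
print(spark.version)
spark.stop()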
1 change: 1 addition & 0 deletions continuous_integration/environment-3.10.yaml
@@ -23,6 +23,7 @@ dependencies:
- pytables
- zarr
- tiledb-py
- pyspark
# resolver was pulling in old versions, so hard-coding floor
# https://github.com/dask/dask/pull/8505
- tiledb>=2.5.0
1 change: 1 addition & 0 deletions continuous_integration/environment-3.8.yaml
@@ -23,6 +23,7 @@ dependencies:
- pytables
- zarr
- tiledb-py
- pyspark
# resolver was pulling in old versions, so hard-coding floor
# https://github.com/dask/dask/pull/8505
- tiledb>=2.5.0
1 change: 1 addition & 0 deletions continuous_integration/environment-3.9.yaml
@@ -23,6 +23,7 @@ dependencies:
- pytables
- zarr
- tiledb-py
- pyspark
# resolver was pulling in old versions, so hard-coding floor
# https://github.com/dask/dask/pull/8505
- tiledb>=2.5.0
3 changes: 1 addition & 2 deletions dask/dataframe/io/parquet/fastparquet.py
@@ -429,6 +429,7 @@ def _collect_dataset_info(
paths = [
path for path in paths if path.endswith(parquet_file_extension)
]
fns = [fn for fn in fns if fn.endswith(parquet_file_extension)]
if len0 and paths == []:
raise ValueError(
"No files satisfy the `parquet_file_extension` criteria "
@@ -535,13 +536,11 @@ def _create_dd_meta(cls, dataset_info):
) = _parse_pandas_metadata(pandas_md)
# auto-ranges should not be created by fastparquet
column_names.extend(pf.cats)

else:
index_names = []
column_names = pf.columns + list(pf.cats)
storage_name_mapping = {k: k for k in column_names}
column_index_names = [None]

if index is None and len(index_names) > 0:
if len(index_names) == 1 and index_names[0] is not None:
index = index_names[0]
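The added line applies the same parquet_file_extension filter to the bare file names (fns) that is already applied to the full paths, so non-data files that Spark writes alongside the parquet data (_SUCCESS, *.crc) are dropped from both lists. A standalone illustration of the str.endswith filtering pattern used above (not dask internals):

# Hypothetical listing of a Spark output directory.
paths = [
    "/data/part.0.parquet",
    "/data/part.0.parquet.crc",  # Spark checksum file
    "/data/_SUCCESS",            # Spark success marker
]
fns = [p.rsplit("/", 1)[-1] for p in paths]

parquet_file_extension = (".parquet",)  # str.endswith also accepts a tuple
paths = [path for path in paths if path.endswith(parquet_file_extension)]
fns = [fn for fn in fns if fn.endswith(parquet_file_extension)]

assert paths == ["/data/part.0.parquet"]
assert fns == ["part.0.parquet"]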
20 changes: 15 additions & 5 deletions dask/dataframe/io/tests/test_parquet.py
@@ -3851,18 +3851,27 @@ def test_metadata_task_size(tmpdir, engine, write_metadata_file, metadata_task_s
assert_eq(ddf2b, ddf2c)


def test_extra_file(tmpdir, engine):
@pytest.mark.parametrize("partition_on", ("b", None))
def test_extra_file(tmpdir, engine, partition_on):
# Check that read_parquet can handle spark output
# See: https://github.com/dask/dask/issues/8087
tmpdir = str(tmpdir)
df = pd.DataFrame({"a": range(100), "b": ["dog", "cat"] * 50})
df = df.assign(b=df.b.astype("category"))
ddf = dd.from_pandas(df, npartitions=2)
ddf.to_parquet(tmpdir, engine=engine, write_metadata_file=True)
ddf.to_parquet(
tmpdir,
engine=engine,
write_metadata_file=True,
partition_on=partition_on,
)
open(os.path.join(tmpdir, "_SUCCESS"), "w").close()
open(os.path.join(tmpdir, "part.0.parquet.crc"), "w").close()
os.remove(os.path.join(tmpdir, "_metadata"))
out = dd.read_parquet(tmpdir, engine=engine, calculate_divisions=True)
assert_eq(out, df)
# Weird two-step since we don't care if category ordering changes
assert_eq(out, df, check_categorical=False)
assert_eq(out.b, df.b, check_category_order=False)

# For "fastparquet" and "pyarrow", we can pass the
# expected file extension, or avoid checking file extensions
@@ -3884,7 +3893,9 @@ def _parquet_file_extension(val, legacy=False):
**_parquet_file_extension(".parquet"),
calculate_divisions=True,
)
assert_eq(out, df)
# Weird two-step since we don't care if category ordering changes
assert_eq(out, df, check_categorical=False)
assert_eq(out.b, df.b, check_category_order=False)

# Should Work (with FutureWarning)
with pytest.warns(FutureWarning, match="require_extension is deprecated"):
@@ -3894,7 +3905,6 @@ def _parquet_file_extension(val, legacy=False):
**_parquet_file_extension(".parquet", legacy=True),
calculate_divisions=True,
)
assert_eq(out, df)

# Should Fail (for not capturing the _SUCCESS and crc files)
with pytest.raises((OSError, pa.lib.ArrowInvalid)):
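For context, the directory layout that test_extra_file mimics is typical Spark output: parquet data files next to a _SUCCESS marker and *.crc checksum files. From the user side, the reading pattern the test exercises looks roughly like this (a sketch using a hypothetical spark-output/ directory, not code from this commit):

import dask.dataframe as dd

# Only files ending in ".parquet" are treated as data; the _SUCCESS marker
# and *.crc checksum files are ignored during dataset discovery.
ddf = dd.read_parquet(
    "spark-output/",
    parquet_file_extension=".parquet",  # also the default
    calculate_divisions=True,
)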
109 changes: 109 additions & 0 deletions dask/tests/test_spark_compat.py
@@ -0,0 +1,109 @@
import signal
import sys
import threading

import pytest

from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

from dask.dataframe.utils import assert_eq

if not sys.platform.startswith("linux"):
pytest.skip(
"Unnecessary, and hard to get spark working on non-linux platforms",
allow_module_level=True,
)

# pyspark auto-converts timezones -- round-tripping timestamps is easier if
# we set everything to UTC.
pdf = timeseries(freq="1H").compute()
pdf.index = pdf.index.tz_localize("UTC")
pdf = pdf.reset_index()


@pytest.fixture(scope="module")
def spark_session():
# Spark registers a global signal handler that can cause problems elsewhere
# in the test suite. In particular, the handler fails if the spark session
# is stopped (a bug in pyspark).
prev = signal.getsignal(signal.SIGINT)
# Create a spark session. Note that we set the timezone to UTC to avoid
# conversion to local time when reading parquet files.
spark = (
pyspark.sql.SparkSession.builder.master("local")
.appName("Dask Testing")
.config("spark.sql.session.timeZone", "UTC")
.getOrCreate()
)
yield spark

spark.stop()
# Make sure we get rid of the signal handler once we stop the session.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, prev)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_spark_to_dask(spark_session, npartitions, tmpdir, engine):
tmpdir = str(tmpdir)

sdf = spark_session.createDataFrame(pdf)
# We are not overwriting any data, but spark complains if the directory
# already exists (as tmpdir does) and we don't set overwrite
sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")

ddf = dd.read_parquet(tmpdir, engine=engine)
# Papercut: pandas TZ localization doesn't survive roundtrip
ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))
assert ddf.npartitions == npartitions

assert_eq(ddf, pdf, check_index=False)


@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_hive_parquet_spark_to_dask(spark_session, tmpdir, engine):
tmpdir = str(tmpdir)

sdf = spark_session.createDataFrame(pdf)
# not overwriting any data, but spark complains if the directory
# already exists and we don't set overwrite
sdf.write.parquet(tmpdir, mode="overwrite", partitionBy="name")

ddf = dd.read_parquet(tmpdir, engine=engine)
# Papercut: pandas TZ localization doesn't survive roundtrip
ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))

# Partitioning can change the column order. This is mostly okay,
# but we sort them here to ease comparison
ddf = ddf.compute().sort_index(axis=1)
# Dask automatically converts hive-partitioned columns to categories.
# This is fine, but convert back to strings for comparison.
ddf = ddf.assign(name=ddf.name.astype("str"))

assert_eq(ddf, pdf.sort_index(axis=1), check_index=False)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_dask_to_spark(spark_session, npartitions, tmpdir, engine):
tmpdir = str(tmpdir)
ddf = dd.from_pandas(pdf, npartitions=npartitions)

# Papercut: https://github.com/dask/fastparquet/issues/646#issuecomment-885614324
kwargs = {"times": "int96"} if engine == "fastparquet" else {}

ddf.to_parquet(tmpdir, engine=engine, write_index=False, **kwargs)

sdf = spark_session.read.parquet(tmpdir)
sdf = sdf.toPandas()

# Papercut: pandas TZ localization doesn't survive roundtrip
sdf = sdf.assign(timestamp=sdf.timestamp.dt.tz_localize("UTC"))

assert_eq(sdf, ddf, check_index=False)
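Outside of pytest, the dask-to-spark direction covered by test_roundtrip_parquet_dask_to_spark boils down to the pattern below (a minimal sketch, assuming a local Java/Spark install and a writable working directory; the times="int96" workaround is only needed with the fastparquet engine, per the Papercut comment above):

import os

import pyspark
from dask.datasets import timeseries

spark = (
    pyspark.sql.SparkSession.builder.master("local")
    .appName("dask-roundtrip-sketch")
    .config("spark.sql.session.timeZone", "UTC")
    .getOrCreate()
)

# Make the timestamp index a regular column, as the tests above do.
ddf = timeseries(freq="1H").reset_index()

out = os.path.abspath("roundtrip-out")
# With engine="fastparquet", also pass times="int96" so Spark reads the
# timestamps correctly.
ddf.to_parquet(out, engine="pyarrow", write_index=False)

sdf = spark.read.parquet(out)
sdf.show(5)
spark.stop()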
