Commit 22915dc: Test round-tripping dataframe parquet I/O including pyspark (dask#9156)
Author: Ian Rose
Date: Jun 8, 2022
Parent: 6369cdb
Showing 7 changed files with 134 additions and 8 deletions.
@@ -0,0 +1,109 @@
import signal
import sys
import threading

import pytest

from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

from dask.dataframe.utils import assert_eq

if not sys.platform.startswith("linux"):
    pytest.skip(
        "Unnecessary, and hard to get spark working on non-linux platforms",
        allow_module_level=True,
    )

# pyspark auto-converts timezones -- round-tripping timestamps is easier if
# we set everything to UTC.
pdf = timeseries(freq="1H").compute()
pdf.index = pdf.index.tz_localize("UTC")
pdf = pdf.reset_index()
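# `pdf` now has a tz-aware "timestamp" column (the former index) plus the
# default timeseries columns: id (int), name (str), and x/y (float).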


@pytest.fixture(scope="module")
def spark_session():
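    # The fixture is module-scoped so that a single, relatively expensive
    # Spark session is shared by every test in this module.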
    # Spark registers a global signal handler that can cause problems elsewhere
    # in the test suite. In particular, the handler fails if the spark session
    # is stopped (a bug in pyspark).
    prev = signal.getsignal(signal.SIGINT)
    # Create a spark session. Note that we set the timezone to UTC to avoid
    # conversion to local time when reading parquet files.
    spark = (
        pyspark.sql.SparkSession.builder.master("local")
        .appName("Dask Testing")
        .config("spark.sql.session.timeZone", "UTC")
        .getOrCreate()
    )
    yield spark

    spark.stop()
    # Restore the previous SIGINT handler once the session is stopped.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, prev)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_spark_to_dask(spark_session, npartitions, tmpdir, engine):
    tmpdir = str(tmpdir)

    sdf = spark_session.createDataFrame(pdf)
    # We are not overwriting any data, but spark complains if the directory
    # already exists (as tmpdir does) and we don't set overwrite
    sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")

    ddf = dd.read_parquet(tmpdir, engine=engine)
    # Papercut: pandas TZ localization doesn't survive the roundtrip
    ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))
    assert ddf.npartitions == npartitions
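    # check_index=False in the comparison below: the frame read back from
    # Spark carries a fresh default index rather than the original RangeIndex.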
    assert_eq(ddf, pdf, check_index=False)


@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_hive_parquet_spark_to_dask(spark_session, tmpdir, engine):
    tmpdir = str(tmpdir)

    sdf = spark_session.createDataFrame(pdf)
    # Not overwriting any data, but spark complains if the directory
    # already exists and we don't set overwrite
    sdf.write.parquet(tmpdir, mode="overwrite", partitionBy="name")

    ddf = dd.read_parquet(tmpdir, engine=engine)
    # Papercut: pandas TZ localization doesn't survive the roundtrip
    ddf = ddf.assign(timestamp=ddf.timestamp.dt.tz_localize("UTC"))

    # Partitioning can change the column order. This is mostly okay,
    # but we sort the columns here to ease the comparison.
    ddf = ddf.compute().sort_index(axis=1)
    # Dask automatically converts hive-partitioned columns to categories.
    # This is fine, but convert back to strings for comparison.
    ddf = ddf.assign(name=ddf.name.astype("str"))

    assert_eq(ddf, pdf.sort_index(axis=1), check_index=False)


@pytest.mark.parametrize("npartitions", (1, 5, 10))
@pytest.mark.parametrize("engine", ("pyarrow", "fastparquet"))
def test_roundtrip_parquet_dask_to_spark(spark_session, npartitions, tmpdir, engine):
    tmpdir = str(tmpdir)
    ddf = dd.from_pandas(pdf, npartitions=npartitions)

    # Papercut: https://github.com/dask/fastparquet/issues/646#issuecomment-885614324
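    # (Per that issue: Spark misreads the int64 nanosecond timestamps that
    # fastparquet writes by default, so we ask for the legacy int96
    # representation instead.)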
    kwargs = {"times": "int96"} if engine == "fastparquet" else {}

    ddf.to_parquet(tmpdir, engine=engine, write_index=False, **kwargs)

    sdf = spark_session.read.parquet(tmpdir)
    sdf = sdf.toPandas()

    # Papercut: pandas TZ localization doesn't survive the roundtrip
    sdf = sdf.assign(timestamp=sdf.timestamp.dt.tz_localize("UTC"))

    assert_eq(sdf, ddf, check_index=False)
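
Running the new tests locally requires pyspark, pyarrow, and fastparquet on a Linux machine (the module skips itself on other platforms). A minimal sketch of invoking the module through pytest's Python API, assuming it was saved as test_parquet_pyspark.py (a hypothetical name; this diff view does not show the file's actual path):

import pytest

# Hypothetical module name -- substitute the real path of the new test file.
pytest.main(["-v", "test_parquet_pyspark.py"])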