From bfd8ae9721d6c29dc9d6b4ecb973608a6e00a108 Mon Sep 17 00:00:00 2001
From: maxim-lixakov
Date: Wed, 24 Apr 2024 17:48:13 +0300
Subject: [PATCH] [DOP-13844] - add csv serialization tests

---
 onetl/file/format/csv.py     | 14 ++++++++
 .../test_csv_integration.py  | 35 ++++++++++++++++---
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/onetl/file/format/csv.py b/onetl/file/format/csv.py
index 310626237..153c8923d 100644
--- a/onetl/file/format/csv.py
+++ b/onetl/file/format/csv.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations

+import warnings
 from typing import TYPE_CHECKING, ClassVar

 try:
@@ -152,6 +153,7 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column:
         from pyspark.sql.functions import col, from_csv

         self.check_if_supported(SparkSession._instantiatedSession)  # noqa: WPS437
+        self._check_unsupported_serialization_options()

         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
@@ -197,6 +199,7 @@ def serialize_column(self, column: str | Column) -> Column:
         from pyspark.sql.functions import col, to_csv

         self.check_if_supported(SparkSession._instantiatedSession)  # noqa: WPS437
+        self._check_unsupported_serialization_options()

         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
@@ -204,3 +207,14 @@ def serialize_column(self, column: str | Column) -> Column:
             column_name, column = column, col(column)

         return to_csv(column, self.dict()).alias(column_name)
+
+    def _check_unsupported_serialization_options(self):
+        unsupported_options = ["header", "compression", "inferSchema"]
+        for option in unsupported_options:
+            if self.dict().get(option):
+                warnings.warn(
+                    f"Option `{option}` is set but not supported in `CSV.parse_column` or `CSV.serialize_column`. "
+                    "This may lead to unexpected behavior.",
+                    UserWarning,
+                    stacklevel=2,
+                )
diff --git a/tests/tests_integration/test_file_format_integration/test_csv_integration.py b/tests/tests_integration/test_file_format_integration/test_csv_integration.py
index 0944e76c0..760f4b114 100644
--- a/tests/tests_integration/test_file_format_integration/test_csv_integration.py
+++ b/tests/tests_integration/test_file_format_integration/test_csv_integration.py
@@ -161,10 +161,11 @@ def test_csv_writer_with_options(
     ],
     ids=["comma-delimited", "semicolon-delimited", "quoted-comma-delimited"],
 )
-def test_csv_parse_column(spark, csv_string, schema, options, expected):
+@pytest.mark.parametrize("column_type", [str, col])
+def test_csv_parse_column(spark, csv_string, schema, options, expected, column_type):
     csv_handler = CSV(**options)
     df = spark.createDataFrame([(csv_string,)], ["csv_string"])
-    parsed_df = df.select(csv_handler.parse_column("csv_string", schema))
+    parsed_df = df.select(csv_handler.parse_column(column_type("csv_string"), schema))
     assert parsed_df.columns == ["csv_string"]
     assert parsed_df.first()["csv_string"] == expected

@@ -187,10 +188,36 @@ def test_csv_parse_column(spark, csv_string, schema, options, expected):
     ],
     ids=["comma-delimited", "semicolon-delimited"],
 )
-def test_csv_serialize_column(spark, data, schema, options, expected_csv):
+@pytest.mark.parametrize("column_type", [str, col])
+def test_csv_serialize_column(spark, data, schema, options, expected_csv, column_type):
     csv_handler = CSV(**options)
     df = spark.createDataFrame([data], schema)
     df = df.withColumn("csv_column", struct("id", "name"))
-    serialized_df = df.select(csv_handler.serialize_column("csv_column"))
+    serialized_df = df.select(csv_handler.serialize_column(column_type("csv_column")))
     assert serialized_df.columns == ["csv_column"]
     assert serialized_df.first()["csv_column"] == expected_csv
+
+
+@pytest.mark.parametrize(
+    "options",
+    [
+        ({"header": True}),
+        ({"compression": "gzip"}),
+        ({"inferSchema": True}),
+    ],
+    ids=["with-header", "with-compression", "with-inferSchema"],
+)
+def test_csv_unsupported_options_warning(spark, options):
+    schema = StructType([StructField("id", IntegerType()), StructField("name", StringType())])
+    df = spark.createDataFrame([Row(id=1, name="Alice")], schema)
+    df = df.withColumn("csv_column", struct("id", "name"))
+
+    csv_handler = CSV(**options)
+    msg = (
+        f"Option `{list(options.keys())[0]}` is set but not supported in `CSV.parse_column` or `CSV.serialize_column`."
+    )
+
+    with pytest.warns(UserWarning) as record:
+        df.select(csv_handler.serialize_column("csv_column")).collect()
+    assert record
+    assert msg in str(record[0].message)