From 5fdf710c150abf14b9d74c4c4d9eac552b2fe458 Mon Sep 17 00:00:00 2001
From: maxim-lixakov
Date: Fri, 26 Apr 2024 15:27:46 +0300
Subject: [PATCH] [DOP-13845] - add _get_schema_json

---
 onetl/file/format/avro.py                     | 94 ++++++++++++++-----
 .../test_avro_integration.py                  | 43 +++++----
 2 files changed, 94 insertions(+), 43 deletions(-)

diff --git a/onetl/file/format/avro.py b/onetl/file/format/avro.py
index 29e39da07..ff7059537 100644
--- a/onetl/file/format/avro.py
+++ b/onetl/file/format/avro.py
@@ -6,8 +6,6 @@
 import logging
 from typing import TYPE_CHECKING, ClassVar, Dict, Optional
 
-import requests
-
 try:
     from pydantic.v1 import Field, validator
 except (ImportError, AttributeError):
@@ -201,6 +199,14 @@ def parse_column(self, column: str | Column) -> Column:
 
         Can be used only with Spark 3.x+
 
+        .. warning::
+
+            If ``schema_url`` is provided, the ``requests`` library is used to fetch the schema from the URL. It should be installed manually, like this:
+
+            .. code:: bash
+
+                pip install requests
+
         Parameters
         ----------
         column : str | Column
@@ -211,6 +217,14 @@ def parse_column(self, column: str | Column) -> Column:
         Column
             A new Column object with data parsed from Avro binary to the specified structured format.
 
+        Raises
+        ------
+        ValueError
+            If the Spark version is less than 3.x, or if neither ``schema_dict`` nor ``schema_url`` is defined.
+        ImportError
+            If ``schema_url`` is used and the ``requests`` library is not installed.
+
+
         Examples
         --------
         .. code:: python
@@ -233,25 +247,24 @@ def parse_column(self, column: str | Column) -> Column:
         """
         from pyspark.sql import Column, SparkSession  # noqa: WPS442
-        from pyspark.sql.avro.functions import from_avro
         from pyspark.sql.functions import col
 
-        self.check_if_supported(SparkSession.getActiveSession())
+        spark = SparkSession.getActiveSession()
+        self.check_if_supported(spark)
+        self._check_spark_version_for_serialization(spark)
+
+        from pyspark.sql.avro.functions import from_avro
 
         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
         else:
             column_name, column = column, col(column).cast("binary")
 
-        if self.schema_dict:
-            schema_json = json.dumps(self.schema_dict)
-        elif self.schema_url:
-            response = requests.get(self.schema_url)  # noqa: S113
-            schema_json = response.text
-        else:
-            raise ValueError("No schema defined in Avro class instance.")
+        schema = self._get_schema_json()
+        if not schema:
+            raise ValueError("Avro.parse_column can be used only with defined `schema_dict` or `schema_url`")
 
-        return from_avro(column, schema_json).alias(column_name)
+        return from_avro(column, schema).alias(column_name)
 
     def serialize_column(self, column: str | Column) -> Column:
         """
@@ -262,6 +275,14 @@ def serialize_column(self, column: str | Column) -> Column:
 
         Can be used only with Spark 3.x+
 
+        .. warning::
+
+            If ``schema_url`` is provided, the ``requests`` library is used to fetch the schema from the URL. It should be installed manually, like this:
+
+            .. code:: bash
+
+                pip install requests
+
         Parameters
         ----------
         column : str | Column
@@ -272,6 +293,13 @@ def serialize_column(self, column: str | Column) -> Column:
         Column
             A new Column object with data serialized from Spark SQL structures to Avro binary.
 
+        Raises
+        ------
+        ValueError
+            If the Spark version is less than 3.x.
+        ImportError
+            If ``schema_url`` is used and the ``requests`` library is not installed.
+
         Examples
         --------
         .. code:: python
@@ -295,23 +323,20 @@ def serialize_column(self, column: str | Column) -> Column:
         """
         from pyspark.sql import Column, SparkSession  # noqa: WPS442
-        from pyspark.sql.avro.functions import to_avro
         from pyspark.sql.functions import col
 
-        self.check_if_supported(SparkSession._instantiatedSession)  # noqa: WPS437
+        spark = SparkSession.getActiveSession()
+        self.check_if_supported(spark)
+        self._check_spark_version_for_serialization(spark)
+
+        from pyspark.sql.avro.functions import to_avro
 
         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
         else:
             column_name, column = column, col(column)
 
-        if self.schema_dict:
-            schema = json.dumps(self.schema_dict)
-        elif self.self.schema_url:
-            schema = requests.get(self.self.schema_url)  # noqa: S113
-        else:
-            schema = ""
-
+        schema = self._get_schema_json()
         return to_avro(column, schema).alias(column_name)
 
     @validator("schema_dict", pre=True)
@@ -319,3 +344,30 @@ def _parse_schema_from_json(cls, value):
         if isinstance(value, (str, bytes)):
             return json.loads(value)
         return value
+
+    def _check_spark_version_for_serialization(self, spark: SparkSession):
+        spark_version = get_spark_version(spark)
+        if spark_version.major < 3:
+            class_name = self.__class__.__name__
+            error_msg = (
+                f"`{class_name}.parse_column` or `{class_name}.serialize_column` are available "
+                f"only since Spark 3.x, but got {spark_version}."
+            )
+            raise ValueError(error_msg)
+
+    def _get_schema_json(self) -> str:
+        if self.schema_dict:
+            return json.dumps(self.schema_dict)
+        elif self.schema_url:
+            try:
+                import requests
+
+                response = requests.get(self.schema_url)  # noqa: S113
+                return response.text
+            except ImportError as e:
+                raise ImportError(
+                    "The 'requests' library is required to use 'schema_url' but is not installed. "
+                    "Install it with 'pip install requests' or avoid using 'schema_url'.",
+                ) from e
+        else:
+            return ""
diff --git a/tests/tests_integration/test_file_format_integration/test_avro_integration.py b/tests/tests_integration/test_file_format_integration/test_avro_integration.py
index a8e9a241a..63a50ca40 100644
--- a/tests/tests_integration/test_file_format_integration/test_avro_integration.py
+++ b/tests/tests_integration/test_file_format_integration/test_avro_integration.py
@@ -4,6 +4,8 @@
 Do not test all the possible options and combinations, we are not testing Spark here.
""" +import contextlib + import pytest from onetl._util.spark import get_spark_version @@ -61,7 +63,7 @@ def test_avro_reader( """Reading Avro files working as expected on any Spark, Python and Java versions""" spark_version = get_spark_version(spark) if spark_version < Version("2.4"): - pytest.skip("Avro files are supported on Spark 3.2+ only") + pytest.skip("Avro files are supported on Spark 2.4+ only") local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files df = file_df_dataframe @@ -98,7 +100,7 @@ def test_avro_writer( """Written files can be read by Spark""" spark_version = get_spark_version(spark) if spark_version < Version("2.4"): - pytest.skip("Avro files are supported on Spark 3.2+ only") + pytest.skip("Avro files are supported on Spark 2.4+ only") file_df_connection, source_path = local_fs_file_df_connection_with_path df = file_df_dataframe @@ -124,39 +126,36 @@ def test_avro_writer( assert_equal_df(read_df, df, order_by="id") -@pytest.mark.parametrize( - "path, options", - [ - ("without_compression", {}), - ("with_compression", {"compression": "snappy"}), - ], - ids=["without_compression", "with_compression"], -) @pytest.mark.parametrize("column_type", [str, col]) def test_avro_serialize_and_parse_column( spark, local_fs_file_df_connection_with_path, file_df_dataframe, - path, avro_schema, - options, column_type, ): from pyspark.sql.functions import struct from pyspark.sql.types import BinaryType spark_version = get_spark_version(spark) - if spark_version < Version("2.4"): - pytest.skip("Avro from_avro, to_avro are supported on Spark 3.x+ only") - + if spark_version <= Version("2.4"): + pytest.skip("Avro files are supported on Spark 2.4+ only") + + if spark_version.major < 3: + msg = ( + f"`Avro.parse_column` or `Avro.serialize_column` are available " + f"only since Spark 3.x, but got {spark_version}" + ) + context_manager = pytest.raises(ValueError, match=msg) + else: + context_manager = contextlib.nullcontext() df = file_df_dataframe - avro = Avro(schema_dict=avro_schema, **options) + avro = Avro(schema_dict=avro_schema) combined_df = df.withColumn("combined", struct([col(c) for c in df.columns])) - serialized_df = combined_df.select(avro.serialize_column(column_type("combined"))) - - assert isinstance(serialized_df.schema["combined"].dataType, BinaryType) - - parsed_df = serialized_df.select(avro.parse_column(column_type("combined"))) - assert combined_df.select(column_type("combined")).collect() == parsed_df.collect() + with context_manager: + serialized_df = combined_df.select(avro.serialize_column(column_type("combined"))) + assert isinstance(serialized_df.schema["combined"].dataType, BinaryType) + parsed_df = serialized_df.select(avro.parse_column(column_type("combined"))) + assert combined_df.select("combined").collect() == parsed_df.collect()