diff --git a/docs/connection/db_connection/kafka/format_handling.rst b/docs/connection/db_connection/kafka/format_handling.rst index 35b4c8b22..5f2d00864 100644 --- a/docs/connection/db_connection/kafka/format_handling.rst +++ b/docs/connection/db_connection/kafka/format_handling.rst @@ -276,9 +276,9 @@ To process XML formatted data from Kafka, use the :obj:`XML.parse_column Alice20"|topicXML |0 |0 |2024-04-24 13:02:25.911|0 | - # |[32]|"Bob25" |topicXML |0 |1 |2024-04-24 13:02:25.922|0 | + # +----+--------------------------------------------------------------------------------------------+----------+---------+------+-----------------------+-------------+ + # |[31]|"Alice20" |topicXML |0 |0 |2024-04-24 13:02:25.911|0 | + # |[32]|"Bob25" |topicXML |0 |1 |2024-04-24 13:02:25.922|0 | # +----+--------------------------------------------------------------------------------------------+----------+---------+------+-----------------------+-------------+ xml_schema = StructType( diff --git a/onetl/file/format/xml.py b/onetl/file/format/xml.py index a4d0573ed..8137b1025 100644 --- a/onetl/file/format/xml.py +++ b/onetl/file/format/xml.py @@ -240,7 +240,7 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: .. note:: - This method parses each XML string in the DataFrame individually; therefore, each string must contain exactly one occurrence of the ``rowTag`` without any surrounding root tags. If your XML data includes a root tag that encapsulates the row tags, you must preprocess the XML to remove or ignore this root tag before parsing. + This method parses each DataFrame row individually; therefore, for the specified column, each row must contain exactly one occurrence of the ``rowTag`` without any surrounding root tags. If your XML data includes a root tag that encapsulates the row tags, you must preprocess the XML to remove or ignore this root tag before parsing. .. 
code-block:: xml diff --git a/tests/tests_integration/test_file_format_integration/test_xml_integration.py b/tests/tests_integration/test_file_format_integration/test_xml_integration.py index 916a3ae01..9d14e90e7 100644 --- a/tests/tests_integration/test_file_format_integration/test_xml_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_xml_integration.py @@ -4,6 +4,8 @@ Do not test all the possible options and combinations, we are not testing Spark here. """ +import datetime + import pytest from onetl._util.spark import get_spark_version @@ -11,6 +13,7 @@ from onetl.file.format import XML try: + from pyspark.sql import Row from pyspark.sql.functions import col from tests.util.assert_df import assert_equal_df @@ -170,58 +173,39 @@ def test_xml_reader_with_attributes( assert_equal_df(read_df, expected_xml_attributes_df, order_by="id") +@pytest.mark.parametrize( + "xml_input, expected_row", + [ + ( + """ + 1 + Alice + 123 + 2021-01-01 + 2021-01-01T07:01:01Z + 1.23 + """, + Row( + xml_string=Row( + id=1, + str_value="Alice", + int_value=123, + date_value=datetime.date(2021, 1, 1), + datetime_value=datetime.datetime(2021, 1, 1, 7, 1, 1), + float_value=1.23, + ), + ), + ), + ], + ids=["basic-case"], +) @pytest.mark.parametrize("column_type", [str, col]) -def test_xml_parse_column( - spark, - local_fs_file_df_connection_with_path_and_files, - expected_xml_attributes_df, - file_df_dataframe, - file_df_schema, - column_type, -): - from pyspark.sql.types import StringType - +def test_xml_parse_column(spark, xml_input: str, expected_row: Row, column_type, file_df_schema): from onetl.file.format import XML - spark_version = get_spark_version(spark) - if spark_version.major < 3: - pytest.skip("XML files are supported on Spark 3.x only") - - def to_xml(row): - # convert datetime to UTC - import pytz - - utc_datetime = row.datetime_value.astimezone(pytz.utc) - utc_datetime_str = utc_datetime.isoformat() - - return f""" - {row.id} - {row.str_value} - 
{row.int_value} - {row.date_value} - {utc_datetime_str} - {row.float_value} - """ - - xml_rdd = spark.sparkContext.parallelize(expected_xml_attributes_df.rdd.map(to_xml).collect()) - df = spark.createDataFrame(xml_rdd, StringType()).toDF("xml_string") - xml = XML(row_tag="item") + df = spark.createDataFrame([(xml_input,)], ["xml_string"]) parsed_df = df.select(xml.parse_column(column_type("xml_string"), schema=file_df_schema)) - transformed_df = parsed_df.select( - "xml_string.id", - "xml_string.str_value", - "xml_string.int_value", - "xml_string.date_value", - "xml_string.datetime_value", - "xml_string.float_value", - ) - expected_df_selected = expected_xml_attributes_df.select( - "id", - "str_value", - "int_value", - "date_value", - "datetime_value", - "float_value", - ) - assert_equal_df(transformed_df, expected_df_selected) + result_row = parsed_df.first() + + assert result_row == expected_row