diff --git a/src/atc/delta/__init__.py b/src/atc/delta/__init__.py index 675d4d37..ff4e09ea 100644 --- a/src/atc/delta/__init__.py +++ b/src/atc/delta/__init__.py @@ -1,2 +1,4 @@ +from .autoloaderstream_handle import AutoloaderStreamHandle # noqa: F401 from .db_handle import DbHandle # noqa: F401 from .delta_handle import DeltaHandle # noqa: F401 +from .deltastream_handle import DeltaStreamHandle # noqa: F401 diff --git a/src/atc/delta/autoloaderstream_handle.py b/src/atc/delta/autoloaderstream_handle.py new file mode 100644 index 00000000..f722b20f --- /dev/null +++ b/src/atc/delta/autoloaderstream_handle.py @@ -0,0 +1,74 @@ +from pyspark.sql import DataFrame + +from atc.configurator.configurator import Configurator +from atc.spark import Spark +from atc.tables import TableHandle +from atc.tables.SparkHandle import DeltaHandleInvalidFormat + + +class AutoloaderStreamHandle(TableHandle): + def __init__( + self, + *, + location: str, + checkpoint_path: str, + data_format: str, + ): + """ + location: the location of the delta table + + checkpoint_path: The location of the checkpoints, /_checkpoints + The Delta Lake VACUUM function removes all files not managed by Delta Lake + but skips any directories that begin with _. You can safely store + checkpoints alongside other data and metadata for a Delta table + using a directory structure such as /_checkpoints + See: https://docs.databricks.com/structured-streaming/delta-lake.html + + data_format: the data format of the files that are read + + """ + + assert ( + Spark.version() >= Spark.DATABRICKS_RUNTIME_10_4 + ), f"AutoloaderStreamHandle not available for Spark version {Spark.version()}" + + self._location = location + self._data_format = data_format + self._checkpoint_path = checkpoint_path + + self._validate() + self._validate_checkpoint() + + @classmethod + def from_tc(cls, id: str) -> "AutoloaderStreamHandle": + tc = Configurator() + return cls( + location=tc.table_property(id, "path", None), + data_format=tc.table_property(id, "format", None), + checkpoint_path=tc.table_property(id, "checkpoint_path", None), + ) + + def _validate(self): + """Validates that the name is either db.table or just table.""" + if self._data_format == "delta": + raise DeltaHandleInvalidFormat("Use DeltaStreamHandle for delta.") + + def _validate_checkpoint(self): + if "/_" not in self._checkpoint_path: + print( + "RECOMMENDATION: You can safely store checkpoints alongside " + "other data and metadata for a Delta table using a directory " + "structure such as /_checkpoints" + ) + + def read(self) -> DataFrame: + + reader = ( + Spark.get() + .readStream.format("cloudFiles") + .option("cloudFiles.format", self._data_format) + .option("cloudFiles.schemaLocation", self._checkpoint_path) + .load(self._location) + ) + + return reader diff --git a/src/atc/delta/delta_handle.py b/src/atc/delta/delta_handle.py index bae5e2f1..fe8358ba 100644 --- a/src/atc/delta/delta_handle.py +++ b/src/atc/delta/delta_handle.py @@ -1,35 +1,19 @@ -from typing import List, Optional, Union +from typing import List, Union from pyspark.sql import DataFrame from atc.configurator.configurator import Configurator -from atc.exceptions import AtcException from atc.functions import get_unique_tempview_name, init_dbutils from atc.spark import Spark -from atc.tables.TableHandle import TableHandle +from atc.tables.SparkHandle import SparkHandle from atc.utils.CheckDfMerge import CheckDfMerge from atc.utils.GetMergeStatement import GetMergeStatement -class DeltaHandleException(AtcException): - pass - - 
-class DeltaHandleInvalidName(DeltaHandleException): - pass - - -class DeltaHandleInvalidFormat(DeltaHandleException): - pass - - -class DeltaHandle(TableHandle): +class DeltaHandle(SparkHandle): def __init__(self, name: str, location: str = None, data_format: str = "delta"): - self._name = name - self._location = location - self._data_format = data_format - self._partitioning: Optional[List[str]] = None + super().__init__(name, location, data_format) self._validate() @@ -42,29 +26,6 @@ def from_tc(cls, id: str) -> "DeltaHandle": data_format=tc.table_property(id, "format", "delta"), ) - def _validate(self): - """Validates that the name is either db.table or just table.""" - if not self._name: - if not self._location: - raise DeltaHandleInvalidName( - "Cannot create DeltaHandle without name or path" - ) - self._name = f"delta.`{self._location}`" - else: - name_parts = self._name.split(".") - if len(name_parts) == 1: - self._db = None - self._table_name = name_parts[0] - elif len(name_parts) == 2: - self._db = name_parts[0] - self._table_name = name_parts[1] - else: - raise DeltaHandleInvalidName(f"Could not parse name {self._name}") - - # only format delta is supported. - if self._data_format != "delta": - raise DeltaHandleInvalidFormat("Only format delta is supported.") - def read(self) -> DataFrame: """Read table by path if location is given, otherwise from name.""" if self._location: @@ -102,60 +63,6 @@ def drop_and_delete(self) -> None: if self._location: init_dbutils().fs.rm(self._location, True) - def create_hive_table(self) -> None: - sql = f"CREATE TABLE IF NOT EXISTS {self._name} " - if self._location: - sql += f" USING DELTA LOCATION '{self._location}'" - Spark.get().sql(sql) - - def recreate_hive_table(self): - self.drop() - self.create_hive_table() - - def get_partitioning(self): - """The result of DESCRIBE TABLE tablename is like this: - +-----------------+---------------+-------+ - | col_name| data_type|comment| - +-----------------+---------------+-------+ - | mycolA| string| | - | myColB| int| | - | | | | - | # Partitioning| | | - | Part 0| mycolA| | - +-----------------+---------------+-------+ - but this method return the partitioning in the form ['mycolA'], - if there is no partitioning, an empty list is returned. - """ - if self._partitioning is None: - # create an iterator object and use it in two steps - rows_iter = iter( - Spark.get().sql(f"DESCRIBE TABLE {self.get_tablename()}").collect() - ) - - # roll over the iterator until you see the title line - for row in rows_iter: - # discard rows until the important section header - if row.col_name.strip() == "# Partitioning": - break - # at this point, the iterator has moved past the section heading - # leaving only the rows with "Part 1" etc. - - # create a list from the rest of the iterator like [(0,colA), (1,colB)] - parts = [ - (int(row.col_name[5:]), row.data_type) - for row in rows_iter - if row.col_name.startswith("Part ") - ] - # sort, just in case the parts were out of order. - parts.sort() - - # discard the index and put into an ordered list. 
-            self._partitioning = [p[1] for p in parts]
-        return self._partitioning
-
-    def get_tablename(self) -> str:
-        return self._name
-
     def upsert(
         self,
         df: DataFrame,
diff --git a/src/atc/delta/deltastream_handle.py b/src/atc/delta/deltastream_handle.py
new file mode 100644
index 00000000..834d3fb4
--- /dev/null
+++ b/src/atc/delta/deltastream_handle.py
@@ -0,0 +1,239 @@
+from typing import List
+
+from pyspark.sql import DataFrame
+from pyspark.sql.streaming import DataStreamWriter
+
+from atc.configurator.configurator import Configurator
+from atc.functions import init_dbutils
+from atc.spark import Spark
+from atc.tables.SparkHandle import SparkHandle
+from atc.utils import GetMergeStatement
+from atc.utils.FileExists import file_exists
+
+
+class DeltaStreamHandle(SparkHandle):
+    def __init__(
+        self,
+        *,
+        checkpoint_path: str,
+        name: str = None,
+        data_format: str = None,
+        location: str = None,
+        trigger_type: str = "availablenow",
+        trigger_time: str = None,
+        await_termination=False,
+    ):
+        """
+        name: name of the delta table
+
+        checkpoint_path: The location of the checkpoints, /_checkpoints
+            The Delta Lake VACUUM function removes all files not managed by Delta Lake
+            but skips any directories that begin with _. You can safely store
+            checkpoints alongside other data and metadata for a Delta table
+            using a directory structure such as /_checkpoints
+            See: https://docs.databricks.com/structured-streaming/delta-lake.html
+
+        location: the location of the delta table (Optional)
+
+        data_format: the data format of the files that are read (Default delta)
+
+        trigger_type: the trigger type of the stream.
+            See: https://docs.databricks.com/structured-streaming/triggers.html
+
+        trigger_time: if the trigger type is "processingtime",
+            a trigger time must be provided
+
+        await_termination: if True, the ETL will wait for the termination of THIS query.
+ + """ + assert ( + Spark.version() >= Spark.DATABRICKS_RUNTIME_10_4 + ), f"DeltaStreamHandle not available for Spark version {Spark.version()}" + + super().__init__(name, location, data_format) + + self._checkpoint_path = checkpoint_path + self._trigger_type = trigger_type.lower() if trigger_type else "availablenow" + self._trigger_time = trigger_time.lower() if trigger_time else None + self._awaitTermination = await_termination if await_termination else True + self._validate() + self._validate_trigger_type() + self._validate_checkpoint() + + @classmethod + def from_tc(cls, id: str) -> "DeltaStreamHandle": + tc = Configurator() + return cls( + name=tc.table_property(id, "name", ""), + location=tc.table_property(id, "path", ""), + data_format=tc.table_property(id, "format", None), + checkpoint_path=tc.table_property(id, "checkpoint_path", None), + trigger_type=tc.table_property(id, "trigger_type", ""), + trigger_time=tc.table_property(id, "trigger_time", ""), + await_termination=tc.table_property(id, "await_termination", ""), + ) + + def _validate_trigger_type(self): + valid_trigger_types = {"availablenow", "once", "processingtime", "continuous"} + assert ( + self._trigger_type in valid_trigger_types + ), f"Triggertype should either be {valid_trigger_types}" + + # if trigger type is processingtime, then it should have a trigger time + assert (self._trigger_type == "processingtime") is ( + self._trigger_time is not None + ) + + def _validate_checkpoint(self): + if "/_" not in self._checkpoint_path: + print( + "RECOMMENDATION: You can safely store checkpoints alongside " + "other data and metadata for a Delta table using a directory " + "structure such as /_checkpoints" + ) + + def read(self) -> DataFrame: + + reader = Spark.get().readStream.format("delta") + + if self._location: + reader = reader.load(self._location) + else: + reader = reader.table(self._name) + + return reader + + def write_or_append( + self, df: DataFrame, mode: str, mergeSchema: bool = None + ) -> None: + + assert mode in {"append", "overwrite", "complete"} + assert df.isStreaming + + if mode == "overwrite": + print("WARNING: The term overwrite is called complete in streaming.") + mode = "complete" + + writer = df.writeStream.option( + "checkpointLocation", self._checkpoint_path + ).outputMode(mode) + + writer = self._add_trigger_type(writer) + + writer = self._add_write_options(writer, mergeSchema) + + if self._awaitTermination: + writer.awaitTermination() + + def overwrite(self, df: DataFrame, mergeSchema: bool = None) -> None: + return self.write_or_append(df, "complete", mergeSchema) + + def append(self, df: DataFrame, mergeSchema: bool = None) -> None: + return self.write_or_append(df, "append", mergeSchema) + + def truncate(self) -> None: + Spark.get().sql(f"TRUNCATE TABLE {self._name};") + + self.remove_checkpoint() + + def drop(self) -> None: + Spark.get().sql(f"DROP TABLE IF EXISTS {self._name};") + + self.remove_checkpoint() + + def drop_and_delete(self) -> None: + self.drop() + if self._location: + init_dbutils().fs.rm(self._location, True) + + def upsert(self, df: DataFrame, join_cols: List[str]) -> None: + assert df.isStreaming + + target_table_name = self.get_tablename() + non_join_cols = [col for col in df.columns if col not in join_cols] + + merge_sql_statement = GetMergeStatement( + merge_statement_type="delta", + target_table_name=target_table_name, + source_table_name="stream_updates", + join_cols=join_cols, + insert_cols=df.columns, + update_cols=non_join_cols, + special_update_set="", + ) + + 
streamingmerge = UpsertHelper(query=merge_sql_statement) + + writer = ( + df.writeStream.format("delta") + .foreachBatch(streamingmerge.upsert_to_delta) + .outputMode("update") + .option("checkpointLocation", self._checkpoint_path) + ) + + writer = self._add_trigger_type(writer) + + if self._awaitTermination: + writer.start().awaitTermination() # Consider removing awaitTermination + else: + writer.start() + + def _add_write_options(self, writer: DataStreamWriter, mergeSchema: bool): + + if self._partitioning: + writer = writer.partitionBy(self._partitioning) + + if mergeSchema is not None: + writer = writer.option("mergeSchema", "true" if mergeSchema else "false") + + if self._location: + writer = writer.start(self._location) + else: + writer = writer.toTable(self._name) + + return writer + + def _add_trigger_type(self, writer: DataStreamWriter): + + if self._trigger_type == "availablenow": + return writer.trigger(availableNow=True) + elif self._trigger_type == "once": + return writer.trigger(once=True) + elif self._trigger_type == "processingtime": + return writer.trigger(processingTime=self._trigger_time) + elif self._trigger_type == "continuous": + return writer.trigger(continuous=self._trigger_time) + else: + raise ValueError("Unknown trigger type.") + + def create_hive_table(self) -> None: + self.remove_checkpoint() + + sql = f"CREATE TABLE IF NOT EXISTS {self._name} " + if self._location: + sql += f" USING DELTA LOCATION '{self._location}'" + Spark.get().sql(sql) + + def remove_checkpoint(self): + if not file_exists(self._checkpoint_path): + init_dbutils().fs.mkdirs(self._checkpoint_path) + + +class UpsertHelper: + """ + In order to write upserts from a streaming query, this helper class can be used + in the foreachBatch method. + + The class helps upserting microbatches to the target table. 
+ + See: https://docs.databricks.com/structured-streaming/ + delta-lake.html#upsert-from-streaming-queries-using-foreachbatch + """ + + def __init__(self, query: str, update_temp: str = "stream_updates"): + self.query = query + self.update_temp = update_temp + + def upsert_to_delta(self, micro_batch_df, batch): + micro_batch_df.createOrReplaceTempView(self.update_temp) + micro_batch_df._jdf.sparkSession().sql(self.query) diff --git a/src/atc/orchestrators/streaming/__init__.py b/src/atc/orchestrators/streaming/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/atc/spark.py b/src/atc/spark.py index 538d616f..6e6efb10 100644 --- a/src/atc/spark.py +++ b/src/atc/spark.py @@ -71,3 +71,4 @@ def version(cls) -> Tuple: return tuple(int(p) for p in cls.get().version.split(".")) DATABRICKS_RUNTIME_9_1 = (3, 1, 2) + DATABRICKS_RUNTIME_10_4 = (3, 2, 1) diff --git a/src/atc/tables/SparkHandle.py b/src/atc/tables/SparkHandle.py new file mode 100644 index 00000000..a94b3c17 --- /dev/null +++ b/src/atc/tables/SparkHandle.py @@ -0,0 +1,105 @@ +from typing import List, Optional + +from atc.exceptions import AtcException +from atc.spark import Spark +from atc.tables import TableHandle + + +class DeltaHandleException(AtcException): + pass + + +class DeltaHandleInvalidName(DeltaHandleException): + pass + + +class DeltaHandleInvalidFormat(DeltaHandleException): + pass + + +class SparkHandle(TableHandle): + """Common handle class for both DeltaHandle and StreamingHandle""" + + def __init__(self, name: str, location: str = None, data_format: str = "delta"): + self._name = name + self._location = location + self._data_format = data_format + + self._partitioning: Optional[List[str]] = None + + def _validate(self): + """Validates that the name is either db.table or just table.""" + if not self._name: + if not self._location: + raise DeltaHandleInvalidName( + "Cannot create DeltaHandle without name or path" + ) + self._name = f"delta.`{self._location}`" + else: + name_parts = self._name.split(".") + if len(name_parts) == 1: + self._db = None + self._table_name = name_parts[0] + elif len(name_parts) == 2: + self._db = name_parts[0] + self._table_name = name_parts[1] + else: + raise DeltaHandleInvalidName(f"Could not parse name {self._name}") + + # only format delta is supported. + if self._data_format != "delta": + raise DeltaHandleInvalidFormat("Only format delta is supported.") + + def create_hive_table(self) -> None: + sql = f"CREATE TABLE IF NOT EXISTS {self._name} " + if self._location: + sql += f" USING DELTA LOCATION '{self._location}'" + Spark.get().sql(sql) + + def recreate_hive_table(self): + self.drop() + self.create_hive_table() + + def get_partitioning(self): + """The result of DESCRIBE TABLE tablename is like this: + +-----------------+---------------+-------+ + | col_name| data_type|comment| + +-----------------+---------------+-------+ + | mycolA| string| | + | myColB| int| | + | | | | + | # Partitioning| | | + | Part 0| mycolA| | + +-----------------+---------------+-------+ + but this method return the partitioning in the form ['mycolA'], + if there is no partitioning, an empty list is returned. 
+ """ + if self._partitioning is None: + # create an iterator object and use it in two steps + rows_iter = iter( + Spark.get().sql(f"DESCRIBE TABLE {self.get_tablename()}").collect() + ) + + # roll over the iterator until you see the title line + for row in rows_iter: + # discard rows until the important section header + if row.col_name.strip() == "# Partitioning": + break + # at this point, the iterator has moved past the section heading + # leaving only the rows with "Part 1" etc. + + # create a list from the rest of the iterator like [(0,colA), (1,colB)] + parts = [ + (int(row.col_name[5:]), row.data_type) + for row in rows_iter + if row.col_name.startswith("Part ") + ] + # sort, just in case the parts were out of order. + parts.sort() + + # discard the index and put into an ordered list. + self._partitioning = [p[1] for p in parts] + return self._partitioning + + def get_tablename(self) -> str: + return self._name diff --git a/src/atc/tables/TableHandle.py b/src/atc/tables/TableHandle.py index 15a35813..7ba6d37c 100644 --- a/src/atc/tables/TableHandle.py +++ b/src/atc/tables/TableHandle.py @@ -1,33 +1,32 @@ -from abc import ABC from typing import List, Union from pyspark.sql import DataFrame -class TableHandle(ABC): +class TableHandle: def read(self) -> DataFrame: - pass + raise NotImplementedError() def overwrite(self, df: DataFrame) -> None: - pass + raise NotImplementedError() def append(self, df: DataFrame) -> None: - pass + raise NotImplementedError() def truncate(self) -> None: - pass + raise NotImplementedError() def drop(self) -> None: - pass + raise NotImplementedError() def drop_and_delete(self) -> None: - pass + raise NotImplementedError() def write_or_append(self, df: DataFrame, mode: str) -> None: - pass + raise NotImplementedError() def upsert(self, df: DataFrame, join_cols: List[str]) -> Union[DataFrame, None]: - pass + raise NotImplementedError() def get_tablename(self) -> str: - pass + raise NotImplementedError() diff --git a/src/atc/utils/FileExists.py b/src/atc/utils/FileExists.py new file mode 100644 index 00000000..82b2463b --- /dev/null +++ b/src/atc/utils/FileExists.py @@ -0,0 +1,12 @@ +from atc.functions import init_dbutils + + +def file_exists(path: str): + """ + Helper function to check whether a file or folder exists. 
+ """ + try: + init_dbutils().fs.ls(path) + return True + except Exception: + return False diff --git a/src/atc/utils/stop_all_streams.py b/src/atc/utils/stop_all_streams.py new file mode 100644 index 00000000..b8c8f75d --- /dev/null +++ b/src/atc/utils/stop_all_streams.py @@ -0,0 +1,8 @@ +from atc.spark import Spark + + +def stop_all_streams(): + for stream in Spark.get().streams.active: + print(f'Stopping the stream "{stream.name}"') + stream.stop() + stream.awaitTermination() diff --git a/tests/cluster/delta/extras/tablenames.yml b/tests/cluster/delta/extras/tablenames.yml index ee3a8d95..5b2be772 100644 --- a/tests/cluster/delta/extras/tablenames.yml +++ b/tests/cluster/delta/extras/tablenames.yml @@ -24,3 +24,6 @@ UpsertLoaderDb: UpsertLoaderDummy: name: "{UpsertLoaderDb}.Dummy" path: "{UpsertLoaderDb_path}/dummy" + format: "delta" + checkpoint_path: /tmp/checkpoints/_upsertcheckpoints + await_termination: True diff --git a/tests/cluster/delta/test_autoloaderstream_handle.py b/tests/cluster/delta/test_autoloaderstream_handle.py new file mode 100644 index 00000000..4184e1ef --- /dev/null +++ b/tests/cluster/delta/test_autoloaderstream_handle.py @@ -0,0 +1,175 @@ +import unittest +import uuid as _uuid +from typing import List, Tuple + +from atc import Configurator +from atc.delta import AutoloaderStreamHandle, DbHandle, DeltaHandle, DeltaStreamHandle +from atc.etl import Orchestrator +from atc.etl.extractors import SimpleExtractor +from atc.etl.loaders import SimpleLoader +from atc.functions import init_dbutils +from atc.spark import Spark +from atc.utils.FileExists import file_exists +from atc.utils.stop_all_streams import stop_all_streams +from tests.cluster.values import resourceName + + +@unittest.skipUnless( + Spark.version() >= Spark.DATABRICKS_RUNTIME_10_4, + f"Autoloader not available for Spark version {Spark.version()}", +) +class AutoloaderTests(unittest.TestCase): + avrosource_checkpoint_path = ( + f"/mnt/{resourceName()}/silver/{resourceName()}" + f"/avrolocation/_checkpoint_path_avro" + ) + + avro_source_path = ( + f"/mnt/{resourceName()}/silver/{resourceName()}/avrolocation/AvroSource" + ) + + @classmethod + def setUpClass(cls) -> None: + Configurator().clear_all_configurations() + Configurator().set_debug() + + if not file_exists(cls.avrosource_checkpoint_path): + init_dbutils().fs.mkdirs(cls.avrosource_checkpoint_path) + + if not file_exists(cls.avro_source_path): + init_dbutils().fs.mkdirs(cls.avro_source_path) + + @classmethod + def tearDownClass(cls) -> None: + DbHandle.from_tc("MyDb").drop_cascade() + if file_exists(cls.avrosource_checkpoint_path): + init_dbutils().fs.rm(cls.avrosource_checkpoint_path, True) + + if file_exists(cls.avro_source_path): + init_dbutils().fs.rm(cls.avro_source_path, True) + stop_all_streams() + + def test_01_configure(self): + tc = Configurator() + tc.register( + "MyDb", {"name": "TestDb{ID}", "path": "/mnt/atc/silver/testdb{ID}"} + ) + # add avro source + tc.register( + "AvroSource", + { + "name": "AvroSource", + "path": self.avro_source_path, + "format": "avro", + "partitioning": "ymd", + "checkpoint_path": self.avrosource_checkpoint_path, + }, + ) + + # Add sink table + sink_checkpoint_path = "/mnt/atc/silver/testdb{ID}/_checkpoint_path_avrosink" + init_dbutils().fs.mkdirs(sink_checkpoint_path) + # add eventhub sink + tc.register( + "AvroSink", + { + "name": "{MyDb}.AvroSink", + # "path": "{MyDb_path}/AvroSink", + "format": "delta", + "checkpoint_path": sink_checkpoint_path, + "await_termination": True, + }, + ) + + # test instantiation 
without error + DbHandle.from_tc("MyDb") + AutoloaderStreamHandle.from_tc("AvroSource") + DeltaStreamHandle.from_tc("AvroSink") + + def test_01_read_avro(self): + DbHandle.from_tc("MyDb").create() + + dsh_sink = DeltaStreamHandle.from_tc("AvroSink") + Spark.get().sql( + f""" + CREATE TABLE {dsh_sink.get_tablename()} + ( + id int, + name string, + _rescued_data string + ) + """ + ) + + self._add_avro_data_to_source([(1, "a", "None"), (2, "b", "None")]) + + o = Orchestrator() + o.extract_from( + SimpleExtractor( + AutoloaderStreamHandle.from_tc("AvroSource"), dataset_key="AvroSource" + ) + ) + + o.load_into(SimpleLoader(dsh_sink, mode="append")) + o.execute() + + result = DeltaHandle.from_tc("AvroSink").read() + + self.assertEqual(2, result.count()) + + # Run again. Should not append more. + o.execute() + self.assertEqual(2, result.count()) + + self._add_avro_data_to_source([(3, "c", "None"), (4, "d", "None")]) + + # Run again. Should append. + o.execute() + self.assertEqual(4, result.count()) + + # Add specific data to source + self._add_specific_data_to_source() + o.execute() + self.assertEqual(5, result.count()) + + # If the same file is altered + # the new row is appended also + self._alter_specific_data() + o.execute() + self.assertEqual(6, result.count()) + + def _create_tbl_mirror(self): + dh = DeltaHandle.from_tc("MyTblMirror") + Spark.get().sql( + f""" + CREATE TABLE {dh.get_tablename()} + ( + id int, + name string, + _rescued_data string + ) + """ + ) + + def _add_avro_data_to_source(self, input_data: List[Tuple[int, str, str]]): + df = Spark.get().createDataFrame( + input_data, "id int, name string, _rescued_data string" + ) + + df.write.format("avro").save(self.avro_source_path + "/" + str(_uuid.uuid4())) + + def _add_specific_data_to_source(self): + df = Spark.get().createDataFrame( + [(10, "specific", "None")], "id int, name string, _rescued_data string" + ) + + df.write.format("avro").save(self.avro_source_path + "/specific") + + def _alter_specific_data(self): + df = Spark.get().createDataFrame( + [(11, "specific", "None")], "id int, name string, _rescued_data string" + ) + + df.write.format("avro").mode("overwrite").save( + self.avro_source_path + "/specific" + ) diff --git a/tests/cluster/delta/test_delta_class.py b/tests/cluster/delta/test_delta_class.py index 1fb80c84..126a9e9e 100644 --- a/tests/cluster/delta/test_delta_class.py +++ b/tests/cluster/delta/test_delta_class.py @@ -14,6 +14,7 @@ class DeltaTests(unittest.TestCase): @classmethod def setUpClass(cls) -> None: Configurator().clear_all_configurations() + Configurator().set_debug() def test_01_configure(self): tc = Configurator() @@ -39,7 +40,21 @@ def test_01_configure(self): tc.register( "MyTbl3", { - "path": "/mnt/atc/silver/testdb/testtbl3", + "path": "/mnt/atc/silver/testdb{ID}/testtbl3", + }, + ) + + tc.register( + "MyTbl4", + { + "name": "TestDb{ID}.TestTbl4", + }, + ) + + tc.register( + "MyTbl5", + { + "name": "TestDb{ID}.TestTbl5", }, ) @@ -47,6 +62,9 @@ def test_01_configure(self): DbHandle.from_tc("MyDb") DeltaHandle.from_tc("MyTbl") DeltaHandle.from_tc("MyTbl2") + DeltaHandle.from_tc("MyTbl3") + DeltaHandle.from_tc("MyTbl4") + DeltaHandle.from_tc("MyTbl5") def test_02_write(self): dh = DeltaHandle.from_tc("MyTbl") @@ -75,7 +93,7 @@ def test_03_create(self): dh.create_hive_table() # test hive access: - df = Spark.get().table("TestDb.TestTbl") + df = dh.read() self.assertTrue(6, df.count()) def test_04_read(self): @@ -115,7 +133,7 @@ def test_08_delete(self): dh.read() def test_09_partitioning(self): - 
dh = DeltaHandle.from_tc("MyTbl") + dh = DeltaHandle.from_tc("MyTbl4") Spark.get().sql( f""" CREATE TABLE {dh.get_tablename()} @@ -130,7 +148,7 @@ def test_09_partitioning(self): self.assertEqual(dh.get_partitioning(), ["colB", "colA"]) - dh2 = DeltaHandle.from_tc("MyTbl2") + dh2 = DeltaHandle.from_tc("MyTbl5") Spark.get().sql( f""" CREATE TABLE {dh2.get_tablename()} diff --git a/tests/cluster/delta/test_delta_stream_handle.py b/tests/cluster/delta/test_delta_stream_handle.py new file mode 100644 index 00000000..adb5557b --- /dev/null +++ b/tests/cluster/delta/test_delta_stream_handle.py @@ -0,0 +1,215 @@ +import unittest + +from pyspark.sql.utils import AnalysisException + +from atc import Configurator +from atc.delta import DbHandle, DeltaHandle, DeltaStreamHandle +from atc.etl import Orchestrator +from atc.etl.extractors import SimpleExtractor +from atc.etl.loaders import SimpleLoader +from atc.spark import Spark +from atc.utils.stop_all_streams import stop_all_streams + + +@unittest.skipUnless( + Spark.version() >= Spark.DATABRICKS_RUNTIME_10_4, + f"DeltaStreamHandle not available for Spark version {Spark.version()}", +) +class DeltaStreamHandleTests(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + Configurator().clear_all_configurations() + Configurator().set_debug() + + @classmethod + def tearDownClass(cls) -> None: + DbHandle.from_tc("MyDb").drop_cascade() + stop_all_streams() + + def test_01_configure(self): + tc = Configurator() + tc.register( + "MyDb", {"name": "TestDb{ID}", "path": "/mnt/atc/silver/testdb{ID}"} + ) + + tc.register( + "MyTbl", + { + "name": "TestDb{ID}.TestTbl", + "path": "/mnt/atc/silver/testdb{ID}/testtbl", + "format": "delta", + "checkpoint_path": "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tbl", + }, + ) + + mirror_cp_path = "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tblmirror" + tc.register( + "MyTblMirror", + { + "name": "TestDb{ID}.TestTblMirror", + "path": "/mnt/atc/silver/testdb{ID}/testtblmirror", + "format": "delta", + "checkpoint_path": mirror_cp_path, + "await_termination": True, + }, + ) + + tc.register( + "MyTbl2", + { + "name": "TestDb{ID}.TestTbl2", + "format": "delta", + "checkpoint_path": "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tbl2", + }, + ) + + tc.register( + "MyTbl3", + { + "path": "/mnt/atc/silver/testdb{ID}/testtbl3", + "format": "delta", + "checkpoint_path": "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tbl3", + "await_termination": True, + }, + ) + + tc.register( + "MyTbl4", + { + "name": "TestDb{ID}.TestTbl4", + "path": "/mnt/atc/silver/testdb{ID}/testtbl4", + "format": "delta", + "checkpoint_path": "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tbl4", + }, + ) + + tc.register( + "MyTbl5", + { + "name": "TestDb{ID}.TestTbl5", + "path": "/mnt/atc/silver/testdb{ID}/testtbl5", + "format": "delta", + "checkpoint_path": "/mnt/atc/silver/testdb{ID}/_checkpoint_path_tbl5", + }, + ) + + # test instantiation without error + DbHandle.from_tc("MyDb") + DeltaStreamHandle.from_tc("MyTbl") + DeltaStreamHandle.from_tc("MyTblMirror") + DeltaStreamHandle.from_tc("MyTbl2") + DeltaStreamHandle.from_tc("MyTbl3") + DeltaStreamHandle.from_tc("MyTbl4") + DeltaStreamHandle.from_tc("MyTbl5") + + def test_02_write_data_with_deltahandle(self): + self._overwrite_two_rows_to_table("MyTbl") + + def test_03_create(self): + db = DbHandle.from_tc("MyDb") + db.create() + + dsh = DeltaStreamHandle.from_tc("MyTbl") + dsh.create_hive_table() + + # test hive access: + df = DeltaHandle.from_tc("MyTbl").read() + self.assertEqual(2, df.count()) + + def 
test_04_read(self): + df = DeltaStreamHandle.from_tc("MyTbl").read() + self.assertTrue(df.isStreaming) + + def test_05_truncate(self): + dsh = DeltaStreamHandle.from_tc("MyTbl") + dsh.truncate() + + result = DeltaHandle.from_tc("MyTbl").read() + self.assertEqual(0, result.count()) + + def test_06_etl(self): + self._overwrite_two_rows_to_table("MyTbl") + self._create_tbl_mirror() + + o = Orchestrator() + o.extract_from( + SimpleExtractor(DeltaStreamHandle.from_tc("MyTbl"), dataset_key="MyTbl") + ) + o.load_into( + SimpleLoader(DeltaStreamHandle.from_tc("MyTblMirror"), mode="append") + ) + o.execute() + + result = DeltaHandle.from_tc("MyTblMirror").read() + self.assertEqual(2, result.count()) + + def test_07_write_path_only(self): + self._overwrite_two_rows_to_table("MyTbl") + # check that we can write to the table with no "name" property + ah = DeltaStreamHandle.from_tc("MyTbl").read() + + dsh3 = DeltaStreamHandle.from_tc("MyTbl3") + + dsh3.append(ah, mergeSchema=True) + + # Read data from mytbl3 + result = DeltaHandle.from_tc("MyTbl3").read() + self.assertEqual(2, result.count()) + + def test_08_delete(self): + dsh = DeltaStreamHandle.from_tc("MyTbl") + dsh.drop_and_delete() + + ah = DeltaStreamHandle.from_tc("MyTbl") + + with self.assertRaises(AnalysisException): + ah.read() + + def test_09_partitioning(self): + dsh = DeltaStreamHandle.from_tc("MyTbl4") + Spark.get().sql( + f""" + CREATE TABLE IF NOT EXISTS {dsh.get_tablename()} + ( + colA string, + colB int, + payload string + ) + PARTITIONED BY (colB,colA) + """ + ) + + self.assertEqual(dsh.get_partitioning(), ["colB", "colA"]) + + dsh2 = DeltaStreamHandle.from_tc("MyTbl5") + Spark.get().sql( + f""" + CREATE TABLE IF NOT EXISTS {dsh2.get_tablename()} + ( + colA string, + colB int, + payload string + ) + """ + ) + + self.assertEqual(dsh2.get_partitioning(), []) + + def _overwrite_two_rows_to_table(self, tblid: str): + dh = DeltaHandle.from_tc(tblid) + + df = Spark.get().createDataFrame([(1, "a"), (2, "b")], "id int, name string") + + dh.overwrite(df, mergeSchema=True) + + def _create_tbl_mirror(self): + dh = DeltaHandle.from_tc("MyTblMirror") + Spark.get().sql( + f""" + CREATE TABLE {dh.get_tablename()} + ( + id int, + name string + ) + """ + ) diff --git a/tests/cluster/etl/test_upsertloader_streaming.py b/tests/cluster/etl/test_upsertloader_streaming.py new file mode 100644 index 00000000..301d3777 --- /dev/null +++ b/tests/cluster/etl/test_upsertloader_streaming.py @@ -0,0 +1,176 @@ +import unittest +from typing import List, Tuple + +from atc_tools.testing import DataframeTestCase + +from atc import Configurator +from atc.delta import DbHandle, DeltaHandle, DeltaStreamHandle +from atc.etl.loaders.UpsertLoader import UpsertLoader +from atc.functions import init_dbutils +from atc.spark import Spark +from atc.utils import DataframeCreator +from atc.utils.FileExists import file_exists +from atc.utils.stop_all_streams import stop_all_streams +from tests.cluster.delta import extras +from tests.cluster.delta.SparkExecutor import SparkSqlExecutor + + +@unittest.skipUnless( + Spark.version() >= Spark.DATABRICKS_RUNTIME_10_4, + f"UpsertLoader for streaming not available for Spark version {Spark.version()}", +) +class UpsertLoaderTestsDeltaStream(DataframeTestCase): + + source_table_checkpoint_path = None + join_cols = ["col1", "col2"] + + data1 = [ + (5, 6, "foo"), + (7, 8, "bar"), + ] + data2 = [ + (1, 2, "baz"), + ] + data3 = [(5, 6, "boo"), (5, 7, "spam")] + # data5 is the merge result of data2 + data3 + data4 + data4 = [(1, 2, "baz"), (5, 
6, "boo"), (5, 7, "spam"), (7, 8, "bar")] + + dummy_columns: List[str] = ["col1", "col2", "col3"] + source_table_id: str = "Test1Table" + + dummy_schema = None + target_dh_dummy: DeltaHandle = None + target_ah_dummy: DeltaStreamHandle = None + + @classmethod + def setUpClass(cls) -> None: + tc = Configurator() + tc.add_resource_path(extras) + tc.set_debug() + + # Database for the source table + tc.register( + "AutoDbUpsert", + {"name": "TestUpsertAutoDb{ID}", "path": "/mnt/atc/silver/testdb{ID}"}, + ) + DbHandle.from_tc("AutoDbUpsert").create() + + # Register the source table + source_table_checkpoint_path = ( + "tmp/" + cls.source_table_id + "/_checkpoint_path" + ) + tc.register( + cls.source_table_id, + { + "name": "TestUpsertAutoDb{ID}." + cls.source_table_id, + "path": "/mnt/atc/silver/TestUpsertAutoDb{ID}/" + cls.source_table_id, + "format": "delta", + "checkpoint_path": source_table_checkpoint_path, + "await_termination": True, + }, + ) + + if not file_exists(source_table_checkpoint_path): + init_dbutils().fs.mkdirs(source_table_checkpoint_path) + + # Autoloader pointing at source table + cls.source_ah = DeltaStreamHandle.from_tc(cls.source_table_id) + + # Autoloader/Deltahandle pointing at target table + cls.target_ah_dummy = DeltaStreamHandle.from_tc("UpsertLoaderDummy") + cls.target_dh_dummy = DeltaHandle.from_tc("UpsertLoaderDummy") + + # Create target table + SparkSqlExecutor().execute_sql_file("upsertloader-test") + + cls.dummy_schema = cls.target_dh_dummy.read().schema + + # make sure target is empty and has a schema + df_empty = DataframeCreator.make_partial(cls.dummy_schema, [], []) + cls.target_dh_dummy.overwrite(df_empty) + + @classmethod + def tearDownClass(cls) -> None: + DbHandle.from_tc("UpsertLoaderDb").drop_cascade() + DbHandle.from_tc("AutoDbUpsert").drop_cascade() + + if file_exists(cls.source_table_checkpoint_path): + init_dbutils().fs.rm(cls.source_table_checkpoint_path) + stop_all_streams() + + def test_01_can_perform_incremental_on_empty(self): + """Stream two rows to the empty target table""" + + self._create_test_source_data(data=self.data1) + + loader = UpsertLoader(handle=self.target_ah_dummy, join_cols=self.join_cols) + + source_df = self.source_ah.read() + + loader.save(source_df) + + self.assertDataframeMatches(self.target_dh_dummy.read(), None, self.data1) + + def test_02_can_perform_incremental_append(self): + """The target table is already filled from before. + One new rows appear in the source table to be streamed + """ + + existing_rows = self.target_dh_dummy.read().collect() + self.assertEqual(2, len(existing_rows)) + + loader = UpsertLoader(handle=self.target_ah_dummy, join_cols=self.join_cols) + + self._create_test_source_data(data=self.data2) + + source_df = self.source_ah.read() + + loader.save(source_df) + + self.assertDataframeMatches( + self.target_dh_dummy.read(), None, self.data1 + self.data2 + ) + + def test_03_can_perform_merge(self): + """The target table is already filled from before. + Two new rows appear in the source table - one of them will be merged. 
+ """ + existing_rows = self.target_dh_dummy.read().collect() + self.assertEqual(3, len(existing_rows)) + + loader = UpsertLoader(handle=self.target_ah_dummy, join_cols=self.join_cols) + + self._create_test_source_data(data=self.data3) + + source_df = self.source_ah.read() + + loader.save(source_df) + + self.assertDataframeMatches(self.target_dh_dummy.read(), None, self.data4) + + def _create_test_source_data( + self, tableid: str = None, data: List[Tuple[int, int, str]] = None + ): + + if tableid is None: + tableid = self.source_table_id + if data is None: + raise ValueError("Testdata missing.") + + dh = DeltaHandle.from_tc(tableid) + + Spark.get().sql( + f""" + CREATE TABLE IF NOT EXISTS {dh.get_tablename()} + ( + id int, + name string + ) + """ + ) + + df_source = DataframeCreator.make_partial( + self.dummy_schema, self.dummy_columns, data + ) + + dh.append(df_source)