From d6f4e6c11bda43339bab019e46f9f006fab3294f Mon Sep 17 00:00:00 2001
From: Angelika Tarnawa
Date: Wed, 25 Oct 2023 09:58:54 +0200
Subject: [PATCH] ✨ Added validate_df_dict parameter
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../flows/test_mediatool_to_adls.py | 47 +++++++++++++++++++
 viadot/flows/mediatool_to_adls.py   | 10 ++++
 2 files changed, 57 insertions(+)

diff --git a/tests/integration/flows/test_mediatool_to_adls.py b/tests/integration/flows/test_mediatool_to_adls.py
index 88746d16b..d7b5b2658 100644
--- a/tests/integration/flows/test_mediatool_to_adls.py
+++ b/tests/integration/flows/test_mediatool_to_adls.py
@@ -5,6 +5,7 @@
 import pytest
 
 from viadot.flows import MediatoolToADLS
+from viadot.exceptions import ValidationError
 
 DATA = {"country": ["DK", "DE"], "sales": [3, 4]}
 ADLS_FILE_NAME = "test_mediatool.parquet"
@@ -28,5 +29,51 @@ def test_mediatool_to_adls_run_flow(mocked_class):
     )
     result = flow.run()
     assert result.is_successful()
+    assert len(flow.tasks) == 10
+    os.remove("test_mediatool_to_adls_flow_run.parquet")
+    os.remove("test_mediatool_to_adls_flow_run.json")
+
+
+@mock.patch(
+    "viadot.tasks.MediatoolToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_mediatool_to_adls_run_flow_validate_fail(mocked_class):
+    flow = MediatoolToADLS(
+        "test_mediatool_to_adls_flow_run",
+        organization_ids=["1000000001", "200000001"],
+        media_entries_columns=["id", "name", "num"],
+        mediatool_credentials_key="MEDIATOOL-TESTS",
+        overwrite_adls=True,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+        validate_df_dict={"column_size": {"country": 10}},
+    )
+    try:
+        flow.run()
+    except ValidationError:
+        pass
+
+
+@mock.patch(
+    "viadot.tasks.MediatoolToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_mediatool_to_adls_run_flow_validate_success(mocked_class):
+    flow = MediatoolToADLS(
+        "test_mediatool_to_adls_flow_run",
+        organization_ids=["1000000001", "200000001"],
+        media_entries_columns=["id", "name", "num"],
+        mediatool_credentials_key="MEDIATOOL-TESTS",
+        overwrite_adls=True,
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+        validate_df_dict={"column_size": {"country": 2}},
+    )
+    result = flow.run()
+    assert result.is_successful()
+    assert len(flow.tasks) == 11
     os.remove("test_mediatool_to_adls_flow_run.parquet")
     os.remove("test_mediatool_to_adls_flow_run.json")
diff --git a/viadot/flows/mediatool_to_adls.py b/viadot/flows/mediatool_to_adls.py
index f87a6293b..c4e92432b 100644
--- a/viadot/flows/mediatool_to_adls.py
+++ b/viadot/flows/mediatool_to_adls.py
@@ -16,6 +16,7 @@
     df_to_parquet,
     dtypes_to_json_task,
     update_dtypes_dict,
+    validate_df,
 )
 from ..tasks import AzureDataLakeUpload, MediatoolToDF
 
@@ -41,6 +42,7 @@ def __init__(
         adls_sp_credentials_secret: str = None,
         overwrite_adls: bool = False,
         if_exists: str = "replace",
+        validate_df_dict: Dict[str, Any] = None,
         *args: List[Any],
         **kwargs: Dict[str, Any],
     ):
@@ -66,6 +68,8 @@
                 Defaults to None.
             overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
             if_exists (str, optional): What to do if the file exists. Defaults to "replace".
+            validate_df_dict (Dict[str, Any], optional): A dictionary with an optional list of tests to verify
+                the output DataFrame. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
""" # MediatoolToDF self.organization_ids = organization_ids @@ -73,6 +77,7 @@ def __init__( self.mediatool_credentials_key = mediatool_credentials_key self.media_entries_columns = media_entries_columns self.vault_name = vault_name + self.validate_df_dict = validate_df_dict # AzureDataLakeUpload self.overwrite = overwrite_adls @@ -119,6 +124,11 @@ def gen_flow(self) -> Flow: media_entries_columns=self.media_entries_columns, flow=self, ) + if self.validate_df_dict: + validation_task = validate_df.bind( + df, tests=self.validate_df_dict, flow=self + ) + validation_task.set_upstream(df, flow=self) df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self) df_casted_to_str = cast_df_to_str(df_with_metadata, flow=self)