From 6f6581b1e1f767d35004482e79fa3ea1c7fc5964 Mon Sep 17 00:00:00 2001
From: gwieloch
Date: Tue, 24 Oct 2023 16:16:26 +0200
Subject: [PATCH] added validate_df parameter to EurostatToADLS

---
 .../flows/test_eurostat_to_adls.py | 26 ++++++++++++++++++-
 viadot/flows/eurostat_to_adls.py   | 13 ++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/tests/integration/flows/test_eurostat_to_adls.py b/tests/integration/flows/test_eurostat_to_adls.py
index 9225655e9..e15f10a7d 100644
--- a/tests/integration/flows/test_eurostat_to_adls.py
+++ b/tests/integration/flows/test_eurostat_to_adls.py
@@ -6,7 +6,11 @@
 
 from viadot.flows import EurostatToADLS
 
-DATA = {"geo": ["PL", "DE", "NL"], "indicator": [35, 55, 77]}
+DATA = {
+    "geo": ["PL", "DE", "NL"],
+    "indicator": [35, 55, 77],
+    "time": ["2023-01", "2023-51", "2023-07"],
+}
 
 ADLS_FILE_NAME = "test_eurostat.parquet"
 ADLS_DIR_PATH = "raw/tests/"
@@ -28,3 +32,23 @@ def test_eurostat_to_adls_run_flow(mocked_class):
     assert result.is_successful()
     os.remove("test_eurostat_to_adls_flow_run.parquet")
     os.remove("test_eurostat_to_adls_flow_run.json")
+
+
+@mock.patch(
+    "viadot.tasks.EurostatToDF.run",
+    return_value=pd.DataFrame(data=DATA),
+)
+@pytest.mark.run
+def test_validate_df(mocked_class):
+    flow = EurostatToADLS(
+        "test_validate_df",
+        dataset_code="ILC_DI04",
+        overwrite_adls=True,
+        validate_df_dict={"column_size": {"time": 7}},
+        adls_dir_path=ADLS_DIR_PATH,
+        adls_file_name=ADLS_FILE_NAME,
+    )
+    result = flow.run()
+    assert result.is_successful()
+    os.remove("test_validate_df.parquet")
+    os.remove("test_validate_df.json")
diff --git a/viadot/flows/eurostat_to_adls.py b/viadot/flows/eurostat_to_adls.py
index 0348d7de4..e6c76a084 100644
--- a/viadot/flows/eurostat_to_adls.py
+++ b/viadot/flows/eurostat_to_adls.py
@@ -15,6 +15,7 @@
     df_to_parquet,
     dtypes_to_json_task,
     update_dtypes_dict,
+    validate_df,
 )
 from ..tasks import AzureDataLakeUpload, EurostatToDF
 
@@ -40,6 +41,7 @@ def __init__(
         adls_file_name: str = None,
         adls_sp_credentials_secret: str = None,
         overwrite_adls: bool = False,
+        validate_df_dict: dict = None,
         if_exists: str = "replace",
         *args: List[Any],
         **kwargs: Dict[str, Any],
@@ -70,6 +72,8 @@
                 ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake.
                 Defaults to None.
             overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
+            validate_df_dict (dict, optional): A dictionary with an optional list of tests to verify the output dataframe.
+                If defined, triggers the `validate_df` task from task_utils. Defaults to None.
             if_exists (str, optional): What to do if the file exists. Defaults to "replace".
         """
 
@@ -79,6 +83,9 @@
         self.base_url = base_url
         self.requested_columns = requested_columns
 
+        # validate df
+        self.validate_df_dict = validate_df_dict
+
         # AzureDataLakeUpload
         self.overwrite = overwrite_adls
         self.adls_sp_credentials_secret = adls_sp_credentials_secret
@@ -123,6 +130,12 @@ def gen_flow(self) -> Flow:
 
         df = df.bind(flow=self)
 
+        if self.validate_df_dict:
+            validation_task = validate_df.bind(
+                df, tests=self.validate_df_dict, flow=self
+            )
+            validation_task.set_upstream(df, flow=self)
+
         df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self)
 
         dtypes_dict = df_get_data_types_task.bind(df_with_metadata, flow=self)
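
For reviewers, a minimal usage sketch of the new validate_df_dict parameter, mirroring the test added above. The flow name, dataset code, ADLS paths, and test values below are illustrative only, not prescribed by the patch:

    from viadot.flows import EurostatToADLS

    # Illustrative values; any Eurostat dataset code and validate_df test dict can be used.
    flow = EurostatToADLS(
        "eurostat_to_adls_with_validation",             # hypothetical flow name
        dataset_code="ILC_DI04",
        validate_df_dict={"column_size": {"time": 7}},  # runs the validate_df task on the extracted dataframe
        adls_dir_path="raw/tests/",
        adls_file_name="eurostat.parquet",
        overwrite_adls=True,
    )
    flow.run()  # expected to fail the run if a configured test fails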