
Commit

✨ Added validate_df_dict parameter
angelika233 committed Oct 25, 2023
1 parent 5379468 commit d6f4e6c
Showing 2 changed files with 57 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/integration/flows/test_mediatool_to_adls.py
@@ -5,6 +5,7 @@
import pytest

from viadot.flows import MediatoolToADLS
from viadot.exceptions import ValidationError

DATA = {"country": ["DK", "DE"], "sales": [3, 4]}
ADLS_FILE_NAME = "test_mediatool.parquet"
@@ -28,5 +29,51 @@ def test_mediatool_to_adls_run_flow(mocked_class):
)
result = flow.run()
assert result.is_successful()
assert len(flow.tasks) == 10
os.remove("test_mediatool_to_adls_flow_run.parquet")
os.remove("test_mediatool_to_adls_flow_run.json")


@mock.patch(
"viadot.tasks.MediatoolToDF.run",
return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_mediatool_to_adls_run_flow_validate_fail(mocked_class):
flow = MediatoolToADLS(
"test_mediatool_to_adls_flow_run",
organization_ids=["1000000001", "200000001"],
media_entries_columns=["id", "name", "num"],
mediatool_credentials_key="MEDIATOOL-TESTS",
overwrite_adls=True,
adls_dir_path=ADLS_DIR_PATH,
adls_file_name=ADLS_FILE_NAME,
validate_df_dict={"column_size": {"country": 10}},
)
try:
flow.run()
except ValidationError:
pass


@mock.patch(
"viadot.tasks.MediatoolToDF.run",
return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_mediatool_to_adls_run_flow_validate_success(mocked_class):
flow = MediatoolToADLS(
"test_mediatool_to_adls_flow_run",
organization_ids=["1000000001", "200000001"],
media_entries_columns=["id", "name", "num"],
mediatool_credentials_key="MEDIATOOL-TESTS",
overwrite_adls=True,
adls_dir_path=ADLS_DIR_PATH,
adls_file_name=ADLS_FILE_NAME,
validate_df_dict={"column_size": {"country": 2}},
)
result = flow.run()
assert result.is_successful()
assert len(flow.tasks) == 11
os.remove("test_mediatool_to_adls_flow_run.parquet")
os.remove("test_mediatool_to_adls_flow_run.json")
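The two new tests exercise the validation path from both sides: the test data values "DK" and "DE" are two characters long, so a `column_size` rule of 10 on `country` is expected to raise a `ValidationError`, while a rule of 2 passes and adds one extra task to the flow (11 instead of 10). Purely as an illustration of a check consistent with those two outcomes (not the actual `validate_df` implementation shipped in viadot), a column-size rule could compare each column's maximum string length against the declared size:

import pandas as pd

DATA = {"country": ["DK", "DE"], "sales": [3, 4]}


def check_column_size(df: pd.DataFrame, sizes: dict) -> None:
    # Illustrative sketch: fail when a column's longest value differs from the declared size.
    for column, expected in sizes.items():
        actual = int(df[column].astype(str).str.len().max())
        if actual != expected:
            raise ValueError(f"Column '{column}': max length {actual}, expected {expected}.")


df = pd.DataFrame(data=DATA)
check_column_size(df, {"country": 2})    # passes, mirroring the *_validate_success test
# check_column_size(df, {"country": 10})  # would raise, mirroring the *_validate_fail test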
10 changes: 10 additions & 0 deletions viadot/flows/mediatool_to_adls.py
@@ -16,6 +16,7 @@
df_to_parquet,
dtypes_to_json_task,
update_dtypes_dict,
validate_df,
)
from ..tasks import AzureDataLakeUpload, MediatoolToDF

@@ -41,6 +42,7 @@ def __init__(
adls_sp_credentials_secret: str = None,
overwrite_adls: bool = False,
if_exists: str = "replace",
validate_df_dict: Dict[str, Any] = None,
*args: List[Any],
**kwargs: Dict[str, Any],
):
@@ -66,13 +68,16 @@ def __init__(
Defaults to None.
overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
if_exists (str, optional): What to do if the file exists. Defaults to "replace".
validate_df_dict (Dict[str, Any], optional): A dictionary of tests to run against the output
dataframe. If provided, the `validate_df` task from task_utils is added to the flow. Defaults to None.
"""
# MediatoolToDF
self.organization_ids = organization_ids
self.mediatool_credentials = mediatool_credentials
self.mediatool_credentials_key = mediatool_credentials_key
self.media_entries_columns = media_entries_columns
self.vault_name = vault_name
self.validate_df_dict = validate_df_dict

# AzureDataLakeUpload
self.overwrite = overwrite_adls
@@ -119,6 +124,11 @@ def gen_flow(self) -> Flow:
media_entries_columns=self.media_entries_columns,
flow=self,
)
if self.validate_df_dict:
validation_task = validate_df.bind(
df, tests=self.validate_df_dict, flow=self
)
validation_task.set_upstream(df, flow=self)

df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self)
df_casted_to_str = cast_df_to_str(df_with_metadata, flow=self)
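In `gen_flow` the validation step is conditional: `validate_df` is bound only when `validate_df_dict` is set, with the extracted dataframe declared as its upstream dependency, which is why the success test above counts 11 tasks instead of 10. The task body itself is not part of this diff; as a hedged sketch of the Prefect 1.x pattern used here (conditionally adding a task and wiring it after the extraction step), with a placeholder check rather than viadot's real `validate_df` logic:

import pandas as pd
from prefect import Flow, task


@task
def extract() -> pd.DataFrame:
    # Stand-in for MediatoolToDF.run; returns the same shape of data as the tests.
    return pd.DataFrame({"country": ["DK", "DE"], "sales": [3, 4]})


@task
def validate(df: pd.DataFrame, tests: dict = None) -> None:
    # Placeholder check only; the real validate_df task lives in viadot's task_utils.
    assert tests is not None


validate_df_dict = {"column_size": {"country": 2}}

with Flow("conditional-validation-sketch") as flow:
    df = extract()
    if validate_df_dict:  # the task is only added to the flow when tests were requested
        validation_task = validate(df, tests=validate_df_dict)
        validation_task.set_upstream(df)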
