Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added validate_df task to CustomerGaugeToADLS flow #784

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/integration/flows/test_customer_gauge_to_adls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from viadot.flows import CustomerGaugeToADLS
from viadot.exceptions import ValidationError

DATA = {
"user_name": ["Jane", "Bob"],
Expand Down Expand Up @@ -38,5 +39,56 @@ def test_customer_gauge_to_adls_run_flow(mocked_class):
)
result = flow.run()
assert result.is_successful()
assert len(flow.tasks) == 10
os.remove("test_customer_gauge_to_adls_flow_run.parquet")
os.remove("test_customer_gauge_to_adls_flow_run.json")

@mock.patch(
    "viadot.tasks.CustomerGaugeToDF.run",
    return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_customer_gauge_to_adls_run_flow_validation_success(mocked_class):
    """Run the flow with a `validate_df_dict` and expect it to succeed.

    `CustomerGaugeToDF.run` is patched to return a DataFrame built from
    `DATA`, so no request is made to the Customer Gauge API. The test then
    checks that the flow finishes successfully and that it holds 11 tasks,
    i.e. one more than the 10 asserted by the test that runs the flow
    without `validate_df_dict`.

    NOTE(review): presumably the `column_size` check requires every value
    in `user_address_state` to be 2 characters long -- confirm against the
    `validate_df` task.
    """
    flow = CustomerGaugeToADLS(
        "test_customer_gauge_to_adls_flow_run",
        endpoint="responses",
        total_load=False,
        anonymize=True,
        columns_to_anonymize=COLUMNS,
        adls_dir_path=ADLS_DIR_PATH,
        adls_file_name=ADLS_FILE_NAME,
        overwrite_adls=True,
        validate_df_dict={"column_size": {"user_address_state": 2}},
    )
    try:
        result = flow.run()
        assert result.is_successful()
        assert len(flow.tasks) == 11
    finally:
        # Clean up even when an assertion fails, so leftover artifacts do not
        # leak into other tests. A file may be missing if the run stopped
        # early, hence the existence check.
        for artifact in (
            "test_customer_gauge_to_adls_flow_run.parquet",
            "test_customer_gauge_to_adls_flow_run.json",
        ):
            if os.path.exists(artifact):
                os.remove(artifact)

@mock.patch(
    "viadot.tasks.CustomerGaugeToDF.run",
    return_value=pd.DataFrame(data=DATA),
)
@pytest.mark.run
def test_customer_gauge_to_adls_run_flow_validation_failure(mocked_class):
    """Run the flow with a `validate_df_dict` the data violates and expect failure.

    `CustomerGaugeToDF.run` is patched to return a DataFrame built from
    `DATA`, whose `user_name` values are "Jane" and "Bob".

    The failure is accepted in either form: a `ValidationError` propagating
    out of `flow.run()`, or a returned state that reports `is_failed()`.
    Swallowing the exception alone would let the test pass even when
    validation never failed.

    NOTE(review): presumably the `column_size` check requires every value
    in `user_name` to be 5 characters long, which neither name satisfies --
    confirm against the `validate_df` task.
    """
    flow = CustomerGaugeToADLS(
        "test_customer_gauge_to_adls_flow_run",
        endpoint="responses",
        total_load=False,
        anonymize=True,
        columns_to_anonymize=COLUMNS,
        adls_dir_path=ADLS_DIR_PATH,
        adls_file_name=ADLS_FILE_NAME,
        overwrite_adls=True,
        validate_df_dict={"column_size": {"user_name": 5}},
    )
    try:
        try:
            result = flow.run()
        except ValidationError:
            # The validation error surfaced directly: the expected failure.
            pass
        else:
            # No exception propagated, so the failure must be visible in the
            # returned flow state.
            assert result.is_failed()
    finally:
        # Clean up regardless of the outcome. A file may be missing if the
        # run stopped early, hence the existence check.
        for artifact in (
            "test_customer_gauge_to_adls_flow_run.parquet",
            "test_customer_gauge_to_adls_flow_run.json",
        ):
            if os.path.exists(artifact):
                os.remove(artifact)

12 changes: 12 additions & 0 deletions viadot/flows/customer_gauge_to_adls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
df_to_parquet,
dtypes_to_json_task,
update_dtypes_dict,
validate_df,
)
from viadot.tasks import AzureDataLakeUpload, CustomerGaugeToDF

Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
adls_sp_credentials_secret: str = None,
overwrite_adls: bool = False,
if_exists: str = "replace",
validate_df_dict: dict = None,
timeout: int = 3600,
*args: List[Any],
**kwargs: Dict[str, Any]
Expand Down Expand Up @@ -91,6 +93,7 @@ def __init__(
ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake.
Defaults to None.
overwrite_adls (bool, optional): Whether to overwrite files in the lake. Defaults to False.
validate_df_dict (Dict[str], optional): A dictionary with optional list of tests to verify the output dataframe. If defined, triggers the `validate_df` task from task_utils. Defaults to None.
if_exists (str, optional): What to do if the file exists. Defaults to "replace".
timeout (int, optional): The time (in seconds) to wait while running this task before a timeout occurs. Defaults to 3600.
"""
Expand All @@ -105,6 +108,9 @@ def __init__(
self.end_date = end_date
self.customer_gauge_credentials_secret = customer_gauge_credentials_secret

# validate_df
self.validate_df_dict = validate_df_dict

# anonymize_df
self.anonymize = anonymize
self.columns_to_anonymize = columns_to_anonymize
Expand Down Expand Up @@ -169,6 +175,12 @@ def gen_flow(self) -> Flow:
flow=self,
)

if self.validate_df_dict:
validation_task = validate_df.bind(
customerg_df, tests=self.validate_df_dict, flow=self
)
validation_task.set_upstream(customerg_df, flow=self)

if self.anonymize == True:
anonymized_df = anonymize_df.bind(
customerg_df,
Expand Down
Loading