Skip to content

Commit

Permalink
Adds scaffolding for configurable validators
Browse files Browse the repository at this point in the history
This allows us to configure validation at the per-node level.
Because this is done at config time, it sets us up to determine how to add a
validator from the passed-in configuration.
  • Loading branch information
elijahbenizzy committed Feb 26, 2023
1 parent d994046 commit 2944393
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions hamilton/function_modifiers/validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import dataclasses
from typing import Any, Callable, Collection, Dict, List, Type

from hamilton import node
Expand All @@ -12,6 +13,18 @@
DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG = "hamilton.data_quality.source_node"


@dataclasses.dataclass
class ValidatorConfig:
    """Settings that control how a single data validator is applied to a node.

    Instances are normally built with :meth:`from_validator` rather than
    constructed directly.
    """

    # Whether the validator should be applied at all.
    should_run: bool
    # Severity level associated with the validator's result.
    importance: dq_base.DataValidationLevel

    @classmethod
    def from_validator(
        cls, validator: dq_base.DataValidator, config: Dict[str, Any]
    ) -> "ValidatorConfig":
        """Build the configuration for ``validator``.

        A classmethod (rather than a staticmethod) so that subclasses calling
        ``from_validator`` get an instance of their own type.

        :param validator: Validator whose ``importance`` seeds the result.
        :param config: Configuration dictionary. NOTE(review): currently
            unused -- every validator is enabled and keeps its own
            importance; presumably a hook for future overrides, confirm.
        :return: A config with ``should_run=True`` and the validator's
            ``importance``.
        """
        return cls(should_run=True, importance=validator.importance)


class BaseDataValidationDecorator(base.NodeTransformer):
@abc.abstractmethod
def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]:
Expand All @@ -25,6 +38,7 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida
def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:

raw_node = node.Node(
name=node_.name
+ "_raw", # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want
Expand All @@ -37,8 +51,11 @@ def transform_node(
)
validators = self.get_validators(node_)
validator_nodes = []
validator_name_map = {}
validator_name_config_map = {}
for validator in validators:
validator_config = ValidatorConfig.from_validator(validator, config)
if not validator_config.should_run:
continue

def validation_function(validator_to_call: dq_base.DataValidator = validator, **kwargs):
result = list(kwargs.values())[0] # This should just have one kwarg
Expand All @@ -60,11 +77,11 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, **
},
},
)
validator_name_map[validator_node_name] = validator
validator_name_config_map[validator_node_name] = (validator, validator_config)
validator_nodes.append(validator_node)

def final_node_callable(
validator_nodes=validator_nodes, validator_name_map=validator_name_map, **kwargs
validator_nodes=validator_nodes, validator_name_map=validator_name_config_map, **kwargs
):
"""Callable for the final node. First calls the action on every node, then
Expand All @@ -75,11 +92,11 @@ def final_node_callable(
"""
failures = []
for validator_node in validator_nodes:
validator: dq_base.DataValidator = validator_name_map[validator_node.name]
validator, config = validator_name_map[validator_node.name]
validation_result: dq_base.ValidationResult = kwargs[validator_node.name]
if validator.importance == dq_base.DataValidationLevel.WARN:
if config.importance == dq_base.DataValidationLevel.WARN:
dq_base.act_warn(node_.name, validation_result, validator)
else:
elif config.importance == dq_base.DataValidationLevel.FAIL:
failures.append((validation_result, validator))
dq_base.act_fail_bulk(node_.name, failures)
return kwargs[raw_node.name]
Expand Down

0 comments on commit 2944393

Please sign in to comment.