diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index 1253f251f..93dc20de1 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -1,4 +1,5 @@ import abc +import dataclasses from typing import Any, Callable, Collection, Dict, List, Type from hamilton import node @@ -12,6 +13,18 @@ DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG = "hamilton.data_quality.source_node" +@dataclasses.dataclass +class ValidatorConfig: + should_run: bool + importance: dq_base.DataValidationLevel + + @staticmethod + def from_validator( + validator: dq_base.DataValidator, config: Dict[str, Any] + ) -> "ValidatorConfig": + return ValidatorConfig(should_run=True, importance=validator.importance) + + class BaseDataValidationDecorator(base.NodeTransformer): @abc.abstractmethod def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: @@ -25,6 +38,7 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: + raw_node = node.Node( name=node_.name + "_raw", # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want @@ -37,8 +51,11 @@ def transform_node( ) validators = self.get_validators(node_) validator_nodes = [] - validator_name_map = {} + validator_name_config_map = {} for validator in validators: + validator_config = ValidatorConfig.from_validator(validator, config) + if not validator_config.should_run: + continue def validation_function(validator_to_call: dq_base.DataValidator = validator, **kwargs): result = list(kwargs.values())[0] # This should just have one kwarg @@ -60,11 +77,11 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, ** }, }, ) - validator_name_map[validator_node_name] = validator + validator_name_config_map[validator_node_name] = (validator, validator_config) validator_nodes.append(validator_node) def final_node_callable( - validator_nodes=validator_nodes, validator_name_map=validator_name_map, **kwargs + validator_nodes=validator_nodes, validator_name_map=validator_name_config_map, **kwargs ): """Callable for the final node. First calls the action on every node, then @@ -75,11 +92,11 @@ def final_node_callable( """ failures = [] for validator_node in validator_nodes: - validator: dq_base.DataValidator = validator_name_map[validator_node.name] + validator, config = validator_name_map[validator_node.name] validation_result: dq_base.ValidationResult = kwargs[validator_node.name] - if validator.importance == dq_base.DataValidationLevel.WARN: + if config.importance == dq_base.DataValidationLevel.WARN: dq_base.act_warn(node_.name, validation_result, validator) - else: + elif config.importance == dq_base.DataValidationLevel.FAIL: failures.append((validation_result, validator)) dq_base.act_fail_bulk(node_.name, failures) return kwargs[raw_node.name]