diff --git a/hamilton/data_quality/base.py b/hamilton/data_quality/base.py
index 83f106f1e..c3015bd5b 100644
--- a/hamilton/data_quality/base.py
+++ b/hamilton/data_quality/base.py
@@ -2,7 +2,7 @@
 import dataclasses
 import enum
 import logging
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Tuple
 
 logger = logging.getLogger(__name__)
@@ -25,6 +25,13 @@ class ValidationResult:
     )  # Any extra diagnostics information needed, free-form
 
 
+def matches_any_type(datatype: type, applicable_types: List[type]) -> bool:
+    for type_ in applicable_types:
+        if type_ == Any or issubclass(datatype, type_):
+            return True
+    return False
+
+
 class DataValidator(abc.ABC):
     """Base class for a data quality operator. This will be used by the `data_quality` operator"""
@@ -35,13 +42,24 @@ def __init__(self, importance: str):
     def importance(self) -> DataValidationLevel:
         return self._importance
 
-    @abc.abstractmethod
-    def applies_to(self, datatype: Type[Type]) -> bool:
-        """Whether or not this data validator can apply to the specified dataset
+    @classmethod
+    def applies_to(cls, datatype: type) -> bool:
+        """Whether or not this data validator can apply to the specified dataset.
+        Note that overriding this is not the intended API (it was the old one),
+        but this will be a stable part of the API moving forward, at least until
+        Hamilton 2.0.
 
-        :param datatype:
+        :param datatype: Datatype to validate.
         :return: True if it can be run on the specified type, false otherwise
         """
+        return matches_any_type(datatype, cls.applicable_types())
+
+    @classmethod
+    def applicable_types(cls) -> List[type]:
+        """Returns the list of classes for which this is valid.
+
+        :return: List of classes
+        """
         pass
 
     @abc.abstractmethod
@@ -118,7 +136,7 @@ def __init__(self, importance: str):
 
     @classmethod
     @abc.abstractmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
+    def applicable_types(cls) -> List[type]:
         pass
 
     @abc.abstractmethod
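The new `matches_any_type` helper is the core of the reworked matching API: subclasses now declare `applicable_types()` and inherit a working `applies_to()` from the base class. A minimal sketch of the contract (the `MyRealValidator` class below is hypothetical, not part of this diff):

    import numbers
    from typing import Any, List

    from hamilton.data_quality import base


    class MyRealValidator(base.DataValidator):
        """Hypothetical validator that applies to any real-numbered output."""

        @classmethod
        def applicable_types(cls) -> List[type]:
            # applies_to() is now derived from this list via matches_any_type.
            return [numbers.Real]


    assert MyRealValidator.applies_to(int)      # int is a numbers.Real
    assert not MyRealValidator.applies_to(str)  # str is not
    assert base.matches_any_type(str, [Any])    # typing.Any acts as a wildcard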
diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py
index 23196a0d9..dfce8323d 100644
--- a/hamilton/data_quality/default_validators.py
+++ b/hamilton/data_quality/default_validators.py
@@ -24,8 +24,8 @@ def arg(cls) -> str:
         return "range"
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)  # TODO -- handle dataframes?
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that the datapoint falls within the range ({self.range[0]}, {self.range[1]})"
@@ -69,8 +69,8 @@ def arg(cls) -> str:
         return "values_in"
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)  # TODO -- handle dataframes?
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that all data points are from a fixed set of values: ({self.values}), ignoring NA values."
@@ -113,8 +113,8 @@ def __init__(self, range: Tuple[numbers.Real, numbers.Real], importance: str):
         self.range = range
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, numbers.Real)
+    def applicable_types(cls) -> List[type]:
+        return [numbers.Real]
 
     def description(self) -> str:
         return f"Validates that the datapoint falls within the range ({self.range[0]}, {self.range[1]})"
@@ -151,10 +151,8 @@ def arg(cls) -> str:
         return "values_in"
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, numbers.Real) or issubclass(
-            datatype, str
-        )  # TODO support list, dict and typing.* variants
+    def applicable_types(cls) -> List[type]:
+        return [numbers.Real, str]
 
     def description(self) -> str:
         return f"Validates that python values are from a fixed set of values: ({self.values})."
@@ -189,8 +187,8 @@ def _to_percent(fraction: float):
         return "{0:.2%}".format(fraction)
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that no more than {MaxFractionNansValidatorPandasSeries._to_percent(self.max_fraction_nans)} of the data is Nan."
@@ -251,8 +249,8 @@ def __init__(self, data_type: Type[Type], importance: str):
         self.datatype = data_type
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that the datatype of the pandas series is a subclass of: {self.datatype}"
@@ -282,8 +280,8 @@ def __init__(self, data_type: Type[Type], importance: str):
         self.datatype = data_type
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, numbers.Real) or datatype in (str, bool)
+    def applicable_types(cls) -> List[type]:
+        return [numbers.Real, str, bool, int, float, list, dict]
 
     def description(self) -> str:
         return f"Validates that the datatype of the pandas series is a subclass of: {self.datatype}"
@@ -312,8 +310,8 @@ def __init__(self, max_standard_dev: float, importance: str):
         self.max_standard_dev = max_standard_dev
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that the standard deviation of a pandas series is no greater than : {self.max_standard_dev}"
@@ -340,8 +338,8 @@ def __init__(self, mean_in_range: Tuple[float, float], importance: str):
         self.mean_in_range = mean_in_range
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(datatype, pd.Series)
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         return f"Validates that a pandas series has mean in range [{self.mean_in_range[0]}, {self.mean_in_range[1]}]"
@@ -368,8 +366,8 @@ def __init__(self, allow_none: bool, importance: str):
         self.allow_none = allow_none
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return True
+    def applicable_types(cls) -> List[type]:
+        return [Any]
 
     def description(self) -> str:
         if self.allow_none:
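In practice these defaults are wired up through `@check_output`: each validator's `arg()` names a decorator keyword, and `applicable_types()` decides whether the validator can attach to the decorated function's return type. A usage sketch, assuming the existing `check_output` decorator API (the function below is invented):

    import pandas as pd

    from hamilton.function_modifiers import check_output


    @check_output(range=(0.0, 1.0), importance="warn")
    def probability(raw_score: pd.Series) -> pd.Series:
        # pd.Series is in the series range validator's applicable_types(),
        # so the "range" keyword selects it for this function's return type.
        return raw_score.clip(0.0, 1.0)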
diff --git a/hamilton/data_quality/pandera_validators.py b/hamilton/data_quality/pandera_validators.py
index c7c83137d..7b08ff58c 100644
--- a/hamilton/data_quality/pandera_validators.py
+++ b/hamilton/data_quality/pandera_validators.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import List
 
 import pandas as pd
 import pandera as pa
@@ -14,10 +14,8 @@ def __init__(self, schema: pa.DataFrameSchema, importance: str):
         self.schema = schema
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(
-            datatype, pd.DataFrame
-        )  # TODO -- allow for modin, etc. as they come for free with pandera
+    def applicable_types(cls) -> List[type]:
+        return [pd.DataFrame]
 
     def description(self) -> str:
         return "Validates that the returned dataframe matches the pander"
@@ -54,10 +52,8 @@ def __init__(self, schema: pa.SeriesSchema, importance: str):
         self.schema = schema
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return issubclass(
-            datatype, pd.Series
-        )  # TODO -- allow for modin, etc. as they come for free with pandera
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     def description(self) -> str:
         pass
diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py
index cb18966ab..04bbe55e1 100644
--- a/hamilton/function_modifiers/base.py
+++ b/hamilton/function_modifiers/base.py
@@ -3,6 +3,7 @@
 import functools
 import itertools
 import logging
+import uuid
 from abc import ABC
 
 try:
@@ -704,6 +705,7 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
     which configuration they need.
     :return: A list of nodes into which this function transforms.
     """
+
     try:
         function_decorators = get_node_decorators(fn, config)
         node_resolvers = function_decorators[NodeResolver.get_lifecycle_name()]
@@ -734,3 +736,35 @@ class InvalidDecoratorException(Exception):
 
 class MissingConfigParametersException(Exception):
     pass
+
+
+def create_anonymous_node_name(original_node_name: str, *suffixes: str) -> str:
+    """Creates an anonymous node name. This is specifically for decorators that
+    rely on temporary/intermediate nodes. Note that these are not part of the contract, and might
+    change at any given point.
+
+    The algorithm that this follows is simple:
+
+    1. Start with the original node name
+    2. Append the suffixes, each preceded by an underscore
+    3. Append a short uuid fragment (8 hex characters) to make the name unique
+
+    Note that the "stability" of this depends on the nodes being processed in a specific order,
+    which should be respected in function_modifiers_base. The order is:
+
+    1. By node lifecycle
+    2. By decorator application order
+
+    The only likely conflicts come from multiple similar decorators (e.g. check_output) decorating
+    the same node.
+
+    :param original_node_name: Name of the original node that this is related to.
+    :param suffixes: Suffixes to append to the original node name.
+    :return: The new node name.
+    """
+    uid = str(uuid.uuid4())[0:8]
+    name = original_node_name
+    for suffix in suffixes:
+        name += f"_{suffix}"
+    name += f"_{uid}"
+    return name
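A quick illustration of the names `create_anonymous_node_name` produces (the uuid fragment differs on every call; the value in the comment is just an example):

    from hamilton.function_modifiers.base import create_anonymous_node_name

    name = create_anonymous_node_name("fn", "raw")
    # e.g. 'fn_raw_1a2b3c4d': the original name, each suffix, then 8 uuid hex chars.
    assert name.startswith("fn_raw_")
    assert len(name) == len("fn_raw_") + 8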
" ) diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index 5a6dd055b..efc6fba15 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -26,8 +26,7 @@ def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: raw_node = node.Node( - name=node_.name - + "_raw", # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want + name=base.create_anonymous_node_name(node_.name, "raw"), typ=node_.type, doc_string=node_.documentation, callabl=node_.callable, diff --git a/tests/function_modifiers/test_validation.py b/tests/function_modifiers/test_validation.py index 71fa26ee7..cb3c5d65e 100644 --- a/tests/function_modifiers/test_validation.py +++ b/tests/function_modifiers/test_validation.py @@ -33,23 +33,27 @@ def fn(input: pd.Series) -> pd.Series: subdag = decorator.transform_node(node_, config={}, fn=fn) assert 4 == len(subdag) subdag_as_dict = {node_.name: node_ for node_ in subdag} - assert sorted(subdag_as_dict.keys()) == [ + prefixes = [ "fn", "fn_dummy_data_validator_2", "fn_dummy_data_validator_3", "fn_raw", ] - # TODO -- change when we change the naming scheme - assert subdag_as_dict["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED + sorted_keys = sorted(subdag_as_dict) + assert all([node_name.startswith(prefix) for node_name, prefix in zip(sorted_keys, prefixes)]) + assert subdag_as_dict[sorted_keys[-1]].input_types["input"][1] == DependencyType.REQUIRED assert 3 == len( subdag_as_dict["fn"].input_types ) # Three dependencies -- the two with DQ + the original # The final function should take in everything but only use the raw results + raw_node_name = sorted_keys[-1] assert ( subdag_as_dict["fn"].callable( - fn_raw="test", - fn_dummy_data_validator_2=ValidationResult(True, "", {}), - fn_dummy_data_validator_3=ValidationResult(True, "", {}), + **{ + raw_node_name: "test", + "fn_dummy_data_validator_2": ValidationResult(True, "", {}), + "fn_dummy_data_validator_3": ValidationResult(True, "", {}), + } ) == "test" ) @@ -68,14 +72,17 @@ def fn(input: pd.Series) -> pd.Series: subdag = decorator.transform_node(node_, config={}, fn=fn) assert 4 == len(subdag) subdag_as_dict = {node_.name: node_ for node_ in subdag} - assert sorted(subdag_as_dict.keys()) == [ + prefixes = [ "fn", "fn_dummy_data_validator_2", "fn_dummy_data_validator_3", "fn_raw", ] + sorted_keys = sorted(subdag_as_dict) + assert all([node_name.startswith(prefix) for node_name, prefix in zip(sorted_keys, prefixes)]) + raw_node_name = sorted_keys[-1] # TODO -- change when we change the naming scheme - assert subdag_as_dict["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED + assert subdag_as_dict[raw_node_name].input_types["input"][1] == DependencyType.REQUIRED assert 3 == len( subdag_as_dict["fn"].input_types ) # Three dependencies -- the two with DQ + the original @@ -98,9 +105,11 @@ def fn(input: pd.Series) -> pd.Series: # The final function should take in everything but only use the raw results assert ( subdag_as_dict["fn"].callable( - fn_raw="test", - fn_dummy_data_validator_2=ValidationResult(True, "", {}), - fn_dummy_data_validator_3=ValidationResult(True, "", {}), + **{ + raw_node_name: "test", + "fn_dummy_data_validator_2": ValidationResult(True, "", {}), + "fn_dummy_data_validator_3": ValidationResult(True, "", {}), + } ) == "test" ) @@ -119,12 +128,15 @@ def fn(input: pd.Series) -> pd.Series: subdag = 
diff --git a/tests/function_modifiers/test_validation.py b/tests/function_modifiers/test_validation.py
index 71fa26ee7..cb3c5d65e 100644
--- a/tests/function_modifiers/test_validation.py
+++ b/tests/function_modifiers/test_validation.py
@@ -33,23 +33,27 @@ def fn(input: pd.Series) -> pd.Series:
     subdag = decorator.transform_node(node_, config={}, fn=fn)
     assert 4 == len(subdag)
     subdag_as_dict = {node_.name: node_ for node_ in subdag}
-    assert sorted(subdag_as_dict.keys()) == [
+    prefixes = [
         "fn",
         "fn_dummy_data_validator_2",
         "fn_dummy_data_validator_3",
         "fn_raw",
     ]
-    # TODO -- change when we change the naming scheme
-    assert subdag_as_dict["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED
+    sorted_keys = sorted(subdag_as_dict)
+    assert all([node_name.startswith(prefix) for node_name, prefix in zip(sorted_keys, prefixes)])
+    assert subdag_as_dict[sorted_keys[-1]].input_types["input"][1] == DependencyType.REQUIRED
     assert 3 == len(
         subdag_as_dict["fn"].input_types
     )  # Three dependencies -- the two with DQ + the original
     # The final function should take in everything but only use the raw results
+    raw_node_name = sorted_keys[-1]
     assert (
         subdag_as_dict["fn"].callable(
-            fn_raw="test",
-            fn_dummy_data_validator_2=ValidationResult(True, "", {}),
-            fn_dummy_data_validator_3=ValidationResult(True, "", {}),
+            **{
+                raw_node_name: "test",
+                "fn_dummy_data_validator_2": ValidationResult(True, "", {}),
+                "fn_dummy_data_validator_3": ValidationResult(True, "", {}),
+            }
         )
         == "test"
     )
@@ -68,14 +72,17 @@ def fn(input: pd.Series) -> pd.Series:
     subdag = decorator.transform_node(node_, config={}, fn=fn)
     assert 4 == len(subdag)
     subdag_as_dict = {node_.name: node_ for node_ in subdag}
-    assert sorted(subdag_as_dict.keys()) == [
+    prefixes = [
         "fn",
         "fn_dummy_data_validator_2",
         "fn_dummy_data_validator_3",
         "fn_raw",
     ]
+    sorted_keys = sorted(subdag_as_dict)
+    assert all([node_name.startswith(prefix) for node_name, prefix in zip(sorted_keys, prefixes)])
+    raw_node_name = sorted_keys[-1]
     # TODO -- change when we change the naming scheme
-    assert subdag_as_dict["fn_raw"].input_types["input"][1] == DependencyType.REQUIRED
+    assert subdag_as_dict[raw_node_name].input_types["input"][1] == DependencyType.REQUIRED
     assert 3 == len(
         subdag_as_dict["fn"].input_types
     )  # Three dependencies -- the two with DQ + the original
@@ -98,9 +105,11 @@ def fn(input: pd.Series) -> pd.Series:
     # The final function should take in everything but only use the raw results
     assert (
         subdag_as_dict["fn"].callable(
-            fn_raw="test",
-            fn_dummy_data_validator_2=ValidationResult(True, "", {}),
-            fn_dummy_data_validator_3=ValidationResult(True, "", {}),
+            **{
+                raw_node_name: "test",
+                "fn_dummy_data_validator_2": ValidationResult(True, "", {}),
+                "fn_dummy_data_validator_3": ValidationResult(True, "", {}),
+            }
         )
         == "test"
     )
@@ -119,12 +128,15 @@ def fn(input: pd.Series) -> pd.Series:
     subdag = decorator.transform_node(node_, config={}, fn=fn)
     assert 4 == len(subdag)
     subdag_as_dict = {node_.name: node_ for node_ in subdag}
+    (raw_node_name,) = [item for item in subdag_as_dict if item.startswith("fn_raw_")]
 
     with pytest.raises(DataValidationError):
         subdag_as_dict["fn"].callable(
-            fn_raw=pd.Series([1.0, 2.0, 3.0]),
-            fn_dummy_data_validator_2=ValidationResult(False, "", {}),
-            fn_dummy_data_validator_3=ValidationResult(False, "", {}),
+            **{
+                raw_node_name: pd.Series([1.0, 2.0, 3.0]),
+                "fn_dummy_data_validator_2": ValidationResult(False, "", {}),
+                "fn_dummy_data_validator_3": ValidationResult(False, "", {}),
+            }
         )
diff --git a/tests/resources/data_quality.py b/tests/resources/data_quality.py
index 4d8201ba2..a05ffdf48 100644
--- a/tests/resources/data_quality.py
+++ b/tests/resources/data_quality.py
@@ -1,3 +1,4 @@
+import numpy as np
 import pandas as pd
 
 from hamilton.function_modifiers import check_output
@@ -11,10 +12,9 @@ def data_might_be_in_range(data_quality_should_fail: bool) -> pd.Series:
     return pd.Series([0.5])
 
 
-# TODO -- enable this once we fix the double-data-quality decorators with the same name bug
-# @check_output(data_type=np.float)
-# @check_output(range=(0, 1))
-# def multi_layered_validator(data_quality_should_fail: bool) -> pd.Series:
-#     if data_quality_should_fail:
-#         return pd.Series([10.0])
-#     return pd.Series([0.5])
+@check_output(data_type=np.float64)
+@check_output(range=(0, 1))
+def multi_layered_validator(data_quality_should_fail: bool) -> pd.Series:
+    if data_quality_should_fail:
+        return pd.Series([10.0])
+    return pd.Series([0.5])
diff --git a/tests/resources/dq_dummy_examples.py b/tests/resources/dq_dummy_examples.py
index 223f8e614..adf2c5138 100644
--- a/tests/resources/dq_dummy_examples.py
+++ b/tests/resources/dq_dummy_examples.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import List
 
 import pandas as pd
 
@@ -11,8 +11,8 @@ def __init__(self, equal_to: int, importance: str):
         self.equal_to = equal_to
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return datatype == int
+    def applicable_types(cls) -> List[type]:
+        return [int]
 
     def description(self) -> str:
         return "Data must be equal to 10 to be valid"
@@ -60,8 +60,8 @@ def validate(self, dataset: pd.Series) -> ValidationResult:
         )
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return datatype == pd.Series
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     @classmethod
     def arg(cls) -> str:
@@ -92,8 +92,8 @@ def validate(self, dataset: pd.Series) -> ValidationResult:
         )
 
     @classmethod
-    def applies_to(cls, datatype: Type[Type]) -> bool:
-        return datatype == pd.Series
+    def applicable_types(cls) -> List[type]:
+        return [pd.Series]
 
     @classmethod
     def arg(cls) -> str:
diff --git a/tests/test_data_quality_base.py b/tests/test_data_quality_base.py
new file mode 100644
index 000000000..5c555a44e
--- /dev/null
+++ b/tests/test_data_quality_base.py
@@ -0,0 +1,21 @@
+import numbers
+
+import pandas as pd
+import pytest
+
+from hamilton.data_quality.base import matches_any_type
+
+
+@pytest.mark.parametrize(
+    "type_,applicable_types,matches",
+    [
+        (int, [int], True),
+        (str, [int], False),
+        (str, [int, str], True),
+        (str, [int, float], False),
+        (int, [numbers.Real], True),
+        (pd.Series, [pd.Series, pd.DataFrame], True),
+    ],
+)
+def test_matches_any_type(type_, applicable_types, matches):
+    assert matches_any_type(type_, applicable_types) == matches
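The re-enabled `multi_layered_validator` resource exercises exactly the bug this diff fixes: stacked `@check_output` decorators used to collide on the shared `_raw` node name, and each layer now gets a uniquely named intermediate. A sketch of running it end to end, assuming the standard Hamilton driver API:

    from hamilton import driver

    import tests.resources.data_quality

    dr = driver.Driver({"data_quality_should_fail": False}, tests.resources.data_quality)
    df = dr.execute(["multi_layered_validator"])
    assert df["multi_layered_validator"].iloc[0] == 0.5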
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
index 6fd2ca2fc..9d8d26822 100644
--- a/tests/test_end_to_end.py
+++ b/tests/test_end_to_end.py
@@ -29,10 +29,9 @@ def test_data_quality_workflow_passes():
         for var in all_vars
         if var.tags.get("hamilton.data_quality.contains_dq_results", False)
     ]
-    assert len(dq_nodes) == 1
-    dq_result = result[dq_nodes[0]]
-    assert isinstance(dq_result, ValidationResult)
-    assert dq_result.passes is True
+    assert len(dq_nodes) == 3
+    assert all([result[n].passes for n in dq_nodes])
+    assert all([isinstance(result[n], ValidationResult) for n in dq_nodes])
 
 
 def test_data_quality_workflow_fails():
@@ -44,8 +43,9 @@ def test_data_quality_workflow_fails():
     )
 
 
-# Adapted from https://stackoverflow.com/questions/41858147/how-to-modify-imported-source-code-on-the-fly
-# This is needed to decide whether to import annotations...
+# Adapted from
+# https://stackoverflow.com/questions/41858147/how-to-modify-imported-source-code-on-the-fly
+# This is needed to decide whether to import annotations...
 def modify_and_import(module_name, package, modification_func):
     spec = importlib.util.find_spec(module_name, package)
     source = spec.loader.get_source(module_name)
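For context, the hunk above truncates `modify_and_import`; the rest of the helper follows the linked Stack Overflow recipe, roughly like this (a reconstruction for reference, not part of this diff):

    import importlib.util
    import sys


    def modify_and_import(module_name, package, modification_func):
        spec = importlib.util.find_spec(module_name, package)
        source = spec.loader.get_source(module_name)
        new_source = modification_func(source)
        module = importlib.util.module_from_spec(spec)
        code = compile(new_source, module.__spec__.origin, "exec")
        exec(code, module.__dict__)  # execute the modified source in the fresh module's namespace
        sys.modules[module_name] = module
        return module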