diff --git a/hamilton/graph.py b/hamilton/graph.py
index c4d9c6e32..34ba41a12 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -520,6 +520,31 @@ def create_networkx_graph(
     return digraph
 
 
+def _validate_all_nodes(adapter: LifecycleAdapterSet, nodes: Dict[str, node.Node]):
+    """Runs the adapter's ``do_validate_node`` hook, if it has one, against every node.
+
+    Every failure is collected before raising so that all invalid nodes are reported at once.
+
+    :param adapter: Lifecycle adapters that may implement ``do_validate_node``.
+    :param nodes: Nodes to validate. Only the values of the mapping are inspected.
+    :raises ValueError: If any node is reported invalid. The message lists each offending
+        node's name together with the error returned for it.
+    """
+    invalid_nodes_with_errors = []
+    if adapter.does_method("do_validate_node", is_async=False):
+        for n in nodes.values():
+            # The hook's parameter is keyword-only and is named ``created_node``.
+            is_valid, error = adapter.call_lifecycle_method_sync("do_validate_node", created_node=n)
+            if not is_valid:
+                invalid_nodes_with_errors.append((n, error))
+
+    if invalid_nodes_with_errors:
+        raise ValueError(
+            "The following nodes are invalid:\n"
+            + "\n".join([f"{n.name} : {error}" for n, error in invalid_nodes_with_errors])
+        )
+
+
 class FunctionGraph:
     """Note: this object should be considered private until stated otherwise.
 
@@ -542,6 +567,8 @@ def __init__(
         if adapter is None:
             adapter = LifecycleAdapterSet(base.SimplePythonDataFrameGraphAdapter())
 
+        _validate_all_nodes(adapter, nodes)
+
         self._config = config
         self.nodes = nodes
         self.adapter = adapter
diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py
index 8a202ec2f..208de5bbd 100644
--- a/hamilton/lifecycle/api.py
+++ b/hamilton/lifecycle/api.py
@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 from hamilton import node
 from hamilton.lifecycle.base import (
@@ -8,9 +8,11 @@
     BaseDoCheckEdgeTypesMatch,
     BaseDoNodeExecute,
     BaseDoValidateInput,
+    BaseDoValidateNode,
     BasePostNodeExecute,
     BasePreNodeExecute,
 )
+from hamilton.node import DependencyType
 
 try:
     from typing import override
@@ -307,3 +309,58 @@ def run_to_execute_node(
         :return: The result of the node execution -- up to you to return this.
         """
         pass
+
+
+class NodeValidationMethod(BaseDoValidateNode):
+    """User-facing hook for validating nodes. Subclass this and implement ``validate_node``."""
+
+    def do_validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
+        """Unpacks ``created_node`` into plain values and delegates to ``validate_node``."""
+        # input_types maps each dependency name to a (type, DependencyType) pair.
+        return self.validate_node(
+            node_name=created_node.name,
+            node_module=created_node.tags.get("module", None),
+            node_tags=created_node.tags,
+            required_dependencies=[
+                item
+                for item, (_, dep_type) in created_node.input_types.items()
+                if dep_type == DependencyType.REQUIRED
+            ],
+            optional_dependencies=[
+                item
+                for item, (_, dep_type) in created_node.input_types.items()
+                if dep_type == DependencyType.OPTIONAL
+            ],
+            node_type=created_node.type,
+        )
+
+    @abc.abstractmethod
+    def validate_node(
+        self,
+        *,
+        node_name: str,
+        node_module: Optional[str],
+        node_tags: Dict[str, str],
+        required_dependencies: List[str],
+        optional_dependencies: List[str],
+        node_type: Type,
+        **kwargs: Any,
+    ) -> Tuple[bool, Optional[str]]:
+        """Validate a node. You have access to tags, types, etc...
+        We also reserve the right to add future kwargs. This is after node creation,
+        during graph construction.
+
+        Return ``(False, <error message>)`` to mark a node as invalid. Invalid nodes are
+        collected and reported together in a ValueError, which stops the graph construction.
+        This is useful if you want to do some validation on tags, for instance.
+
+        :param node_name: Name of the node in question
+        :param node_module: Module of the function that defined the node, if we know it
+        :param node_tags: Tags of the node
+        :param required_dependencies: List of required dependencies for the node
+        :param optional_dependencies: List of optional dependencies for the node
+        :param node_type: Return type of the node
+        :param kwargs: Keyword arguments -- this is kept for future backwards compatibility.
+        :return: Whether or not the node is valid, and an optional error message
+        """
+        pass
diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py
index 0c900b77d..40c3fefea 100644
--- a/hamilton/lifecycle/base.py
+++ b/hamilton/lifecycle/base.py
@@ -199,15 +199,15 @@ def do_validate_input(self, *, node_type: type, input_value: Any) -> bool:
 @lifecycle.base_method("do_validate_node")
 class BaseDoValidateNode(abc.ABC):
     @abc.abstractmethod
-    def do_validate_node(self, *, created_node: node.Node) -> bool:
-        """Validates a node. Note this is *not* integrated yet, so adding this in will be a No-op.
-        In fact, we will likely be changing the API for this to have an optional error message.
-        This is OK, as this is internal facing.
+    def do_validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
+        """Validates a node. This runs for every node when the FunctionGraph is constructed;
+        nodes reported as invalid cause construction to fail. This is internal facing.
 
-        Furthermore, we'll be adding in a user-facing API that takes in the tags, name, module, etc...
+        The user-facing API, which takes in the tags, name, module, etc..., is
+        NodeValidationMethod (hamilton/lifecycle/api.py).
 
         :param created_node: Node that was created.
-        :return: Whether or not the node is valid.
+        :return: Whether or not the node is valid, and, if it is not, an error message
         """
         pass
 