Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP, Node Validation #608

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,21 @@ def create_networkx_graph(
return digraph


def _validate_all_nodes(adapter: LifecycleAdapterSet, nodes: Dict[str, node.Node]):
    """Runs the ``do_validate_node`` lifecycle method on every node, failing if any is invalid.

    Validation is skipped entirely when the adapter set does not implement a
    synchronous ``do_validate_node`` method. Failures are collected across all
    nodes first, so a single error reports every invalid node at once.

    :param adapter: Lifecycle adapter set that may implement ``do_validate_node``.
    :param nodes: Nodes to validate; only the values of the mapping are inspected.
    :raises ValueError: If at least one node is reported as invalid. The message
        lists each invalid node's name alongside the error returned for it.
    """
    if not adapter.does_method("do_validate_node", is_async=False):
        return

    invalid_nodes_with_errors = []
    for n in nodes.values():
        # The hook is declared keyword-only as ``created_node``; the keyword used
        # here has to match that name (assuming kwargs are forwarded as-is).
        is_valid, error = adapter.call_lifecycle_method_sync(
            "do_validate_node", created_node=n
        )
        if not is_valid:
            invalid_nodes_with_errors.append((n, error))

    if invalid_nodes_with_errors:
        raise ValueError(
            "The following nodes are invalid:\n"
            + "\n".join([f"{n.name} : {error}" for n, error in invalid_nodes_with_errors])
        )


class FunctionGraph:
"""Note: this object should be considered private until stated otherwise.

Expand All @@ -542,6 +557,8 @@ def __init__(
if adapter is None:
adapter = LifecycleAdapterSet(base.SimplePythonDataFrameGraphAdapter())

_validate_all_nodes(adapter, nodes)

self._config = config
self.nodes = nodes
self.adapter = adapter
Expand Down
54 changes: 53 additions & 1 deletion hamilton/lifecycle/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import abc
from abc import ABC
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type

from hamilton import node
from hamilton.lifecycle.base import (
BaseDoBuildResult,
BaseDoCheckEdgeTypesMatch,
BaseDoNodeExecute,
BaseDoValidateInput,
BaseDoValidateNode,
BasePostNodeExecute,
BasePreNodeExecute,
)
from hamilton.node import DependencyType

try:
from typing import override
Expand Down Expand Up @@ -307,3 +309,53 @@ def run_to_execute_node(
:return: The result of the node execution -- up to you to return this.
"""
pass


class NodeValidationMethod(BaseDoValidateNode):
    """User-facing node validator.

    Implements the internal ``do_validate_node`` hook by unpacking the created
    node into plain values (name, module, tags, dependencies, type) and handing
    them to :meth:`validate_node`, which is the method subclasses override.
    """

    def do_validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
        """Unpacks ``created_node`` and delegates to :meth:`validate_node`.

        :param created_node: Node that was created.
        :return: Whatever :meth:`validate_node` returns -- validity flag and optional error message.
        """
        required_dependencies = []
        optional_dependencies = []
        # Each input_types value is a (type, DependencyType) pair (per
        # hamilton.node.Node.input_types), so the dependency type has to be
        # unpacked before comparing -- comparing the pair itself never matches.
        for dependency_name, (_, dep_type) in created_node.input_types.items():
            if dep_type == DependencyType.REQUIRED:
                required_dependencies.append(dependency_name)
            elif dep_type == DependencyType.OPTIONAL:
                optional_dependencies.append(dependency_name)

        return self.validate_node(
            node_name=created_node.name,
            node_module=created_node.tags.get("module", None),
            node_tags=created_node.tags,
            required_dependencies=required_dependencies,
            optional_dependencies=optional_dependencies,
            node_type=created_node.type,
        )

    def validate_node(
        self,
        *,
        node_name: str,
        node_module: Optional[str],
        node_tags: Dict[str, str],
        required_dependencies: List[str],
        optional_dependencies: List[str],
        node_type: Type,
        **kwargs: Any,
    ) -> Tuple[bool, Optional[str]]:
        """Validate a node. You have access to tags, types, etc...
        We also reserve the right to add future kwargs. This is after node creation,
        during graph construction.

        Override this to implement custom validation (on tags, for instance).
        Return ``(False, "<reason>")`` to flag the node as invalid. The default
        implementation accepts every node.

        NOTE(review): an earlier version of this docstring mentioned raising an
        ``InvalidNodeException`` to stop graph construction; no such exception is
        visible here -- confirm before documenting it.

        :param node_name: Name of the node in question
        :param node_module: Module of the function that defined the node, if we know it
        :param node_tags: Tags of the node
        :param required_dependencies: List of required dependencies for the node
        :param optional_dependencies: List of optional dependencies for the node
        :param node_type: Return type of the node
        :param kwargs: Keyword arguments -- this is kept for future backwards compatibility.
        :return: Whether or not the node is valid, and an optional error message
        """
        # Returning None (a bare ``pass``) would break callers that unpack the
        # (is_valid, error) pair, so the default is an explicit "valid" result.
        return True, None
4 changes: 2 additions & 2 deletions hamilton/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,15 @@ def do_validate_input(self, *, node_type: type, input_value: Any) -> bool:
@lifecycle.base_method("do_validate_node")
class BaseDoValidateNode(abc.ABC):
    @abc.abstractmethod
    def do_validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[str]]:
        """Validates a node after it has been created.

        This is the internal-facing hook: it receives the node object itself and
        reports validity together with an optional error message. A user-facing
        API that takes in the tags, name, module, etc... is expected to wrap it.

        :param created_node: Node that was created.
        :return: Whether or not the node is valid, and, if it is not, an error message
        """
        pass

Expand Down