Skip to content

Commit

Permalink
Fixes #575 input type mismatch erroring
Browse files Browse the repository at this point in the history
The problem here was that when we pushed materializers we introduced
a regression where somehow input node types got treated differently.

The bug here was that "input" nodes are created at add dependency time,
thus if two functions request the same input, but provide different types
things would correctly error, but in this case in #575 the types were compatible.

So the design choice here was to:

1. Check that if one is the subset of the other, then we allow it.
2. We then make the subset type the type of the node.
3. We attach originating functions to ensure we can create a good error message.
This has the side-effect of propagating all the way through to Variables and such.
4. The mutating node functions are scoped to only work if the Node is deemed External, that
way we don't use that code path inadvertently in the future.
5. We fix an assumption in visualization that assumed inputs didn't have functions.

Note: the input/external/user-defined node creation should really be pulled out into a separate
step, and not created dynamically in the add_dependency part.
  • Loading branch information
skrawcz committed Dec 4, 2023
1 parent b6024db commit 7a8439e
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 17 deletions.
49 changes: 37 additions & 12 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,45 @@ def add_dependency(
"""
if param_name in nodes:
# validate types match
required_node = nodes[param_name]
if not types_match(adapter, param_type, required_node.type):
required_node: Node = nodes[param_name]
types_do_match = types_match(adapter, param_type, required_node.type)
if not types_do_match and required_node.user_defined:
# check the case that two input type expectations are compatible, e.g. is one a subset of the other
# this could be the case when one is a union and the other is a subset of that union
# which is fine for inputs. If they are not compatible, we raise an error.
types_are_compatible = types_match(adapter, required_node.type, param_type)
if not types_are_compatible:
raise ValueError(
f"Error: Two or more functions are requesting {param_name}, but have incompatible types. "
f"{func_name} requires {param_name} to be {param_type}, but found another function(s) "
f"{[f.__name__ for f in required_node.originating_functions]} that minimally require {param_name} "
f"as {required_node.type}. Please fix this by ensuring that all functions requesting {param_name} "
f"have compatible types. If you believe they are equivalent, please reach out to the developers. "
f"Note that, if you have types that are equivalent for your purposes, you can create a "
f"graph adapter that checks the types against each other in a more lenient manner."
)
else:
# replace the current type with this "tighter" type.
required_node.set_type(param_type)
# add to the originating functions
for og_func in func_node.originating_functions:
required_node.add_originating_function(og_func)
elif not types_do_match:
raise ValueError(
f"Error: {func_name} is expecting {param_name}:{param_type}, but found "
f"{param_name}:{required_node.type}. \nHamilton does not consider these types to be "
f"equivalent. If you believe they are equivalent, please reach out to the developers."
f"equivalent. If you believe they are equivalent, please reach out to the developers. "
f"Note that, if you have types that are equivalent for your purposes, you can create a "
f"graph adapter that checks the types against each other in a more lenient manner."
)
else:
# this is a user defined var
required_node = node.Node(param_name, param_type, node_source=node.NodeType.EXTERNAL)
# this is a user defined var, i.e. an input to the graph.
required_node = node.Node(
param_name,
param_type,
node_source=node.NodeType.EXTERNAL,
originating_functions=func_node.originating_functions,
)
nodes[param_name] = required_node
# add edges
func_node.dependencies.append(required_node)
Expand All @@ -78,6 +105,8 @@ def update_dependencies(
it will deepcopy the dict + nodes and return that. Otherwise it will
mutate + return the passed-in dict + nodes.
Note: this will add in "input" nodes if they are not already present.
:param in_place: Whether or not to modify in-place, or copy/return
:param nodes: Nodes that form the DAG we're updating
:param adapter: Adapter to use for type checking
Expand Down Expand Up @@ -114,7 +143,7 @@ def create_function_graph(
nodes = fg.nodes
functions = sum([find_functions(module) for module in modules], [])

# create nodes -- easier to just create this in one loop
# create non-input nodes -- easier to just create this in one loop
for func_name, f in functions:
for n in fm_base.resolve_nodes(f, config):
if n.name in config:
Expand All @@ -125,7 +154,7 @@ def create_function_graph(
f" Already defined by function {f}"
)
nodes[n.name] = n
# add dependencies -- now that all nodes exist, we just run through edges & validate graph.
# add dependencies -- now that all nodes except input nodes, we just run through edges & validate graph.
nodes = update_dependencies(nodes, adapter, reset_dependencies=False) # no dependencies
# present yet
for key in config.keys():
Expand Down Expand Up @@ -204,11 +233,7 @@ def _get_node_type(n: node.Node) -> str:
Config: is external, doesn't originate from a function, no function depedends on it
Function: others
"""
if (
n._node_source == node.NodeType.EXTERNAL
and n._originating_functions is None
and n._depended_on_by
):
if n._node_source == node.NodeType.EXTERNAL and n._depended_on_by:
return "input"
elif (
n._node_source == node.NodeType.EXTERNAL
Expand Down
17 changes: 17 additions & 0 deletions hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def name(self) -> str:
def type(self) -> Any:
return self._type

def set_type(self, typ: Any):
"""Sets the type of the node"""
assert self.user_defined is True, "Cannot reset type of non-user-defined node"
self._type = typ

@property
def callable(self):
return self._callable
Expand Down Expand Up @@ -196,6 +201,18 @@ def originating_functions(self) -> Optional[Tuple[Callable, ...]]:
"""
return self._originating_functions

def add_originating_function(self, fn: Callable):
"""Adds a function to the list of originating functions.
This is used in the case to attach originating functions to user-defined (i.e. external/input nodes).
:param fn: Function to add
"""
assert self.user_defined is True, "Cannot add originating function to non-user-defined node"
if self._originating_functions is None:
self._originating_functions = (fn,)
else:
self._originating_functions += (fn,)

def add_tag(self, tag_name: str, tag_value: str):
self._tags[tag_name] = tag_value

Expand Down
13 changes: 13 additions & 0 deletions tests/resources/compatible_input_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Union


def b(a: Union[int, str]) -> int:
return a


def c(a: str) -> str:
return a


def d(a: str) -> str:
return a
17 changes: 17 additions & 0 deletions tests/resources/incompatible_input_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Union


def b(a: int) -> int:
return a


def c(a: str) -> str:
return a


def e(d: Union[int, str]) -> int:
return d


def f(d: Union[float, int]) -> float:
return d
187 changes: 184 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

import pandas as pd
import pytest

import hamilton.graph_utils
import hamilton.htypes
import tests.resources.bad_functions
import tests.resources.compatible_input_types
import tests.resources.config_modifier
import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
import tests.resources.incompatible_input_types
import tests.resources.layered_decorators
import tests.resources.multiple_decorators_together
import tests.resources.optional_dependencies
import tests.resources.parametrized_inputs
import tests.resources.parametrized_nodes
import tests.resources.test_default_args
import tests.resources.typing_vs_not_typing

import hamilton.graph_utils
import hamilton.htypes
from hamilton import ad_hoc_utils, base, graph, node
from hamilton.execution import graph_functions
from hamilton.node import NodeType
Expand Down Expand Up @@ -108,6 +110,185 @@ def test_add_dependency_strict_node_dependencies():
assert func_node.depended_on_by == []


def test_add_dependency_input_nodes_mismatch_on_types():
"""Tests that if two functions request an input that has incompatible types, we error out."""
b_sig = inspect.signature(tests.resources.incompatible_input_types.b)
c_sig = inspect.signature(tests.resources.incompatible_input_types.c)

nodes = {
"b": node.Node.from_fn(tests.resources.incompatible_input_types.b),
"c": node.Node.from_fn(tests.resources.incompatible_input_types.c),
}
nodes["b"]._originating_functions = (tests.resources.incompatible_input_types.b,)
nodes["c"]._originating_functions = (tests.resources.incompatible_input_types.c,)
param_name = "a"

# this adds 'a' to nodes
graph.add_dependency(
nodes["b"],
"b",
nodes,
param_name,
b_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

assert "a" in nodes

# adding dependency of c on a should fail because the types are incompatible
with pytest.raises(ValueError):
graph.add_dependency(
nodes["c"],
"c",
nodes,
param_name,
c_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)


def test_add_dependency_input_nodes_mismatch_on_types_complex():
"""Tests a more complex scenario we don't support right now with input types."""
e_sig = inspect.signature(tests.resources.incompatible_input_types.e)
f_sig = inspect.signature(tests.resources.incompatible_input_types.f)

nodes = {
"e": node.Node.from_fn(tests.resources.incompatible_input_types.e),
"f": node.Node.from_fn(tests.resources.incompatible_input_types.f),
}
nodes["e"]._originating_functions = (tests.resources.incompatible_input_types.e,)
nodes["f"]._originating_functions = (tests.resources.incompatible_input_types.f,)
param_name = "d"

# this adds 'a' to nodes
graph.add_dependency(
nodes["e"],
"e",
nodes,
param_name,
e_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

assert "d" in nodes

# adding dependency of c on a should fail because the types are incompatible
with pytest.raises(ValueError):
graph.add_dependency(
nodes["e"],
"e",
nodes,
param_name,
f_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)


def test_add_dependency_input_nodes_compatible_types():
"""Tests that if functions request an input that we correctly accept compatible types."""
b_sig = inspect.signature(tests.resources.compatible_input_types.b)
c_sig = inspect.signature(tests.resources.compatible_input_types.c)
d_sig = inspect.signature(tests.resources.compatible_input_types.d)

nodes = {
"b": node.Node.from_fn(tests.resources.compatible_input_types.b),
"c": node.Node.from_fn(tests.resources.compatible_input_types.c),
"d": node.Node.from_fn(tests.resources.compatible_input_types.d),
}
nodes["b"]._originating_functions = (tests.resources.compatible_input_types.b,)
nodes["c"]._originating_functions = (tests.resources.compatible_input_types.c,)
nodes["d"]._originating_functions = (tests.resources.compatible_input_types.d,)
# what we want to add
param_name = "a"

# this adds 'a' to nodes
graph.add_dependency(
nodes["b"],
"b",
nodes,
param_name,
b_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

assert "a" in nodes

# this adds 'a' to 'c' as well.
graph.add_dependency(
nodes["c"],
"c",
nodes,
param_name,
c_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

# test that we shrink the type to the tighter type
assert nodes["a"].type == str

graph.add_dependency(
nodes["d"],
"d",
nodes,
param_name,
d_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)


def test_add_dependency_input_nodes_compatible_types_order_check():
"""Tests that if functions request an input that we correctly accept compatible types independent of order."""
b_sig = inspect.signature(tests.resources.compatible_input_types.b)
c_sig = inspect.signature(tests.resources.compatible_input_types.c)
d_sig = inspect.signature(tests.resources.compatible_input_types.d)

nodes = {
"b": node.Node.from_fn(tests.resources.compatible_input_types.b),
"c": node.Node.from_fn(tests.resources.compatible_input_types.c),
"d": node.Node.from_fn(tests.resources.compatible_input_types.d),
}
nodes["b"]._originating_functions = (tests.resources.compatible_input_types.b,)
nodes["c"]._originating_functions = (tests.resources.compatible_input_types.c,)
nodes["d"]._originating_functions = (tests.resources.compatible_input_types.d,)
# what we want to add
param_name = "a"

# this adds 'a' to nodes
graph.add_dependency(
nodes["c"],
"c",
nodes,
param_name,
c_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

assert "a" in nodes
assert nodes["a"].type == str

# this adds 'a' to 'c' as well.
graph.add_dependency(
nodes["b"],
"b",
nodes,
param_name,
b_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)

# test that type didn't change
assert nodes["a"].type == str

graph.add_dependency(
nodes["d"],
"d",
nodes,
param_name,
d_sig.parameters[param_name].annotation,
base.SimplePythonDataFrameGraphAdapter(),
)


def test_typing_to_primitive_conversion():
"""Tests that we can mix function output being typing type, and dependent function using primitive type."""
b_sig = inspect.signature(tests.resources.typing_vs_not_typing.B)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pandas as pd
import pytest

import tests.resources.cyclic_functions
import tests.resources.dummy_functions
import tests.resources.dynamic_parallelism.parallel_linear_basic
import tests.resources.tagging
import tests.resources.test_default_args
import tests.resources.very_simple_dag

from hamilton import base, node
from hamilton.driver import (
Builder,
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_driver_variables_exposes_original_function():
var.name: var.originating_functions for var in dr.list_available_variables()
}
assert originating_functions["b"] == (tests.resources.very_simple_dag.b,)
assert originating_functions["a"] is None
assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input


@mock.patch("hamilton.telemetry.send_event_json")
Expand Down

0 comments on commit 7a8439e

Please sign in to comment.