From 9449965484f814cc7abdfa6412c46603aa93db6f Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sun, 19 Mar 2023 13:45:25 -0700 Subject: [PATCH] Allows nesting of subdags This was previously banned for no reason other than I didn't see a reason to enable it. Now people have been asking for it, and its well-worth the effort. --- hamilton/function_modifiers/recursive.py | 10 +----- tests/function_modifiers/test_recursive.py | 40 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index 286f7d825..566e4a70b 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -256,16 +256,13 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No :param nodes: :return: """ - already_namespaced_nodes = [] + # already_namespaced_nodes = [] new_nodes = [] new_name_map = {} # First pass we validate + collect names so we can alter dependencies for node_ in nodes: new_name = assign_namespace(node_.name, namespace) new_name_map[node_.name] = new_name - current_node_namespaces = node_.namespace - if current_node_namespaces: - already_namespaced_nodes.append(node_) for dep, value in self.inputs.items(): # We create nodes for both namespace assignment and source assignment # Why? Cause we need unique parameter names, and with source() some can share params @@ -274,11 +271,6 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No for dep, value in self.config.items(): new_name_map[dep] = assign_namespace(dep, namespace) - if already_namespaced_nodes: - raise ValueError( - f"The following nodes are already namespaced: {already_namespaced_nodes}. " - f"We currently do not allow for multiple namespaces (E.G. layered subDAGs)." - ) # Reassign sources for node_ in nodes: new_name = new_name_map[node_.name] diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index e534adce5..8d61508d5 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -4,7 +4,7 @@ import tests.resources.reuse_subdag from hamilton import ad_hoc_utils, graph -from hamilton.function_modifiers import config, parameterized_subdag, recursive, value +from hamilton.function_modifiers import config, parameterized_subdag, recursive, subdag, value from hamilton.function_modifiers.dependencies import source @@ -269,3 +269,41 @@ def subdag_processor(foo: int, bar: int, baz: int) -> Tuple[int, int, int]: assert nodes_by_name["v0.baz"].callable(**{"v0.foo": 1, "v0.bar": 2}) == 3 assert nodes_by_name["v1.baz"].callable(**{"v1.foo": 1, "v1.bar": 2}) == 3 assert nodes_by_name["v2.baz"].callable(**{"v2.foo": 1, "v2.bar": 2}) == 2 + + +def test_nested_subdag(): + def bar(input_1: int) -> int: + return input_1 + 1 + + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag( + foo, + bar, + ) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"plus_one": True}) + def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"plus_one": False}) + def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all) + fg = graph.FunctionGraph(full_module, config={}) + assert "outer_subdag_1" in fg.nodes + assert "outer_subdag_2" in fg.nodes + res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2}) + # This is effectively the function graph + assert res["sum_all"] == sum_all( + outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) + )