Skip to content

Commit

Permalink
Allows nesting of subdags
Browse files Browse the repository at this point in the history
This was previously banned for no reason other than I didn't see a
reason to enable it. Now people have been asking for it, and its
well-worth the effort.
  • Loading branch information
elijahbenizzy committed Mar 21, 2023
1 parent 8554bbb commit 9449965
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
10 changes: 1 addition & 9 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,13 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No
:param nodes:
:return:
"""
already_namespaced_nodes = []
# already_namespaced_nodes = []
new_nodes = []
new_name_map = {}
# First pass we validate + collect names so we can alter dependencies
for node_ in nodes:
new_name = assign_namespace(node_.name, namespace)
new_name_map[node_.name] = new_name
current_node_namespaces = node_.namespace
if current_node_namespaces:
already_namespaced_nodes.append(node_)
for dep, value in self.inputs.items():
# We create nodes for both namespace assignment and source assignment
# Why? Cause we need unique parameter names, and with source() some can share params
Expand All @@ -274,11 +271,6 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No
for dep, value in self.config.items():
new_name_map[dep] = assign_namespace(dep, namespace)

if already_namespaced_nodes:
raise ValueError(
f"The following nodes are already namespaced: {already_namespaced_nodes}. "
f"We currently do not allow for multiple namespaces (E.G. layered subDAGs)."
)
# Reassign sources
for node_ in nodes:
new_name = new_name_map[node_.name]
Expand Down
40 changes: 39 additions & 1 deletion tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tests.resources.reuse_subdag
from hamilton import ad_hoc_utils, graph
from hamilton.function_modifiers import config, parameterized_subdag, recursive, value
from hamilton.function_modifiers import config, parameterized_subdag, recursive, subdag, value
from hamilton.function_modifiers.dependencies import source


Expand Down Expand Up @@ -269,3 +269,41 @@ def subdag_processor(foo: int, bar: int, baz: int) -> Tuple[int, int, int]:
assert nodes_by_name["v0.baz"].callable(**{"v0.foo": 1, "v0.bar": 2}) == 3
assert nodes_by_name["v1.baz"].callable(**{"v1.foo": 1, "v1.bar": 2}) == 3
assert nodes_by_name["v2.baz"].callable(**{"v2.foo": 1, "v2.bar": 2}) == 2


def test_nested_subdag():
def bar(input_1: int) -> int:
return input_1 + 1

def foo(input_2: int) -> int:
return input_2 + 1

@subdag(
foo,
bar,
)
def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
return foo, bar

@subdag(inner_subdag, inputs={"input_2": value(10)}, config={"plus_one": True})
def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

@subdag(inner_subdag, inputs={"input_2": value(3)}, config={"plus_one": False})
def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
return outer_subdag_1 + outer_subdag_2

# we only need to generate from the outer subdag
# as it refers to the inner one
full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
fg = graph.FunctionGraph(full_module, config={})
assert "outer_subdag_1" in fg.nodes
assert "outer_subdag_2" in fg.nodes
res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
# This is effectively the function graph
assert res["sum_all"] == sum_all(
outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
)

0 comments on commit 9449965

Please sign in to comment.