diff --git a/thunder/executors/data_dependent_partition.py b/thunder/executors/data_dependent_partition.py index 6eab16e926..44dc741179 100644 --- a/thunder/executors/data_dependent_partition.py +++ b/thunder/executors/data_dependent_partition.py @@ -6,6 +6,7 @@ from itertools import chain import thunder.core.utils as utils +from thunder.core.trace import TraceCtx from thunder.core.symbol import BoundSymbol from thunder.core.proxies import variableify, Proxy from thunder.core.prims import PrimIDs @@ -73,18 +74,24 @@ def push_node(n: Node, index: int): # assumes bound_symbol comes in as a DAG and in valid topo order +# NOTE: consolidate graph implementations, we have several almost identical +# implementations already class Graph: - def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]): + def __init__(self, trace: TraceCtx): self.roots: list[Node] = [] self.return_node: None | Node = None - self.counter = 0 + self.counter = len(trace.bound_symbols) - # NOTE This part is dog slow. Folks have been suggesting that it's coming from the dict hash on BoundSymbolInterface - bsym_to_node_map: dict[BoundSymbolInterface, Node] = {} - for bsym in bound_symbols: - node = Node(self.counter, [bsym], [self.counter], self.counter, self.counter) - self.counter = self.counter + 1 - bsym_to_node_map[bsym] = node + producers = utils.producers(trace, _map_to_numbers=True) + consumers = utils.consumers(trace, _map_to_numbers=True) + + # Note, even though BoundSymbolInterface is hashable, its hash is very slow + # as it appears to be far off from being universal. + # We use indices as hash values instead. 
+ bsym_id_to_node_map: list[int] = [] + for bsym_id, bsym in enumerate(trace.bound_symbols): + node = Node(bsym_id, [bsym], [bsym_id], bsym_id, bsym_id) + bsym_id_to_node_map.append(node) if bsym.sym.id is PrimIDs.RETURN: utils.check( @@ -93,14 +100,16 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]): ) self.return_node = node - for bsym, node in bsym_to_node_map.items(): + for bsym_id, node in enumerate(bsym_id_to_node_map): has_parents: bool = False + + bsym = node.group_bsyms[0] for inp in bsym.flat_args: if not isinstance(inp, Proxy): continue - producer = producers[inp] - parent = bsym_to_node_map[producer] + producer_id = producers[inp] + parent = bsym_id_to_node_map[producer_id] node.parents.add(parent) has_parents = True @@ -115,9 +124,9 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]): if variableify(out) in (variableify(x) for x in bsym.flat_args): continue - children = consumers.get(out, []) - for child in children: - child_node = bsym_to_node_map[child] + children_ids = consumers.get(out, []) + for child_id in children_ids: + child_node = bsym_id_to_node_map[child_id] node.children.add(child_node) def __repr__(self) -> str: @@ -280,8 +289,8 @@ def update_candidate(schedule_op): return topo_order_groups -def fuse_bound_symbols(producer, consumer, bound_symbols, merge_func: Callable): - graph = Graph(producer, consumer, bound_symbols) +def fuse_bound_symbols(trace: TraceCtx, merge_func: Callable): + graph = Graph(trace) dataflow_merge(graph, merge_func) ret = horizontal_merge(graph, merge_func) return ret diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 05dbfa3efe..2b586e521f 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -710,7 +710,7 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - bound_symbol_groups = fuse_bound_symbols(producers, consumers, trace.bound_symbols, 
_should_fuse) + bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) # Counts how many fusions (per executor) have been constructed # (Used to name fusions like nvFusion0, nvFusion1, ...) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 04956274dc..9d82e12526 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -151,7 +151,7 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - bound_symbol_groups = fuse_bound_symbols(producers, consumers, trace.bound_symbols, _should_fuse) + bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) # Counts how many fusions (per executor) have been constructed fusion_counter: int = 0