Skip to content

Commit

Permalink
fuse_bound_symbols: perf improvements via faster DAG construction (…
Browse files Browse the repository at this point in the history
…#2091)
  • Loading branch information
nikitaved authored Feb 13, 2024
1 parent 41625e3 commit b4295cd
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
41 changes: 25 additions & 16 deletions thunder/executors/data_dependent_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import chain

import thunder.core.utils as utils
from thunder.core.trace import TraceCtx
from thunder.core.symbol import BoundSymbol
from thunder.core.proxies import variableify, Proxy
from thunder.core.prims import PrimIDs
Expand Down Expand Up @@ -73,18 +74,24 @@ def push_node(n: Node, index: int):


# assumes bound_symbol comes in as a DAG and in valid topo order
# NOTE: consolidate graph implementations, we have several almost identical
# implementations already
class Graph:
def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
def __init__(self, trace: TraceCtx):
self.roots: list[Node] = []
self.return_node: None | Node = None
self.counter = 0
self.counter = len(trace.bound_symbols)

# NOTE This part is dog slow. Folks have been suggesting that it's coming from the dict hash on BoundSymbolInterface
bsym_to_node_map: dict[BoundSymbolInterface, Node] = {}
for bsym in bound_symbols:
node = Node(self.counter, [bsym], [self.counter], self.counter, self.counter)
self.counter = self.counter + 1
bsym_to_node_map[bsym] = node
producers = utils.producers(trace, _map_to_numbers=True)
consumers = utils.consumers(trace, _map_to_numbers=True)

# NOTE: even though BoundSymbolInterface is hashable, its hash is very slow,
# as it appears to be far from universal.
# We use the bound symbols' indices as lookup keys instead.
bsym_id_to_node_map: list[int] = []
for bsym_id, bsym in enumerate(trace.bound_symbols):
node = Node(bsym_id, [bsym], [bsym_id], bsym_id, bsym_id)
bsym_id_to_node_map.append(node)

if bsym.sym.id is PrimIDs.RETURN:
utils.check(
Expand All @@ -93,14 +100,16 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
)
self.return_node = node

for bsym, node in bsym_to_node_map.items():
for bsym_id, node in enumerate(bsym_id_to_node_map):
has_parents: bool = False

bsym = node.group_bsyms[0]
for inp in bsym.flat_args:
if not isinstance(inp, Proxy):
continue

producer = producers[inp]
parent = bsym_to_node_map[producer]
producer_id = producers[inp]
parent = bsym_id_to_node_map[producer_id]
node.parents.add(parent)
has_parents = True

Expand All @@ -115,9 +124,9 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
if variableify(out) in (variableify(x) for x in bsym.flat_args):
continue

children = consumers.get(out, [])
for child in children:
child_node = bsym_to_node_map[child]
children_ids = consumers.get(out, [])
for child_id in children_ids:
child_node = bsym_id_to_node_map[child_id]
node.children.add(child_node)

def __repr__(self) -> str:
Expand Down Expand Up @@ -280,8 +289,8 @@ def update_candidate(schedule_op):
return topo_order_groups


def fuse_bound_symbols(producer, consumer, bound_symbols, merge_func: Callable):
graph = Graph(producer, consumer, bound_symbols)
def fuse_bound_symbols(trace: TraceCtx, merge_func: Callable):
graph = Graph(trace)
dataflow_merge(graph, merge_func)
ret = horizontal_merge(graph, merge_func)
return ret
2 changes: 1 addition & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def _can_fuse_node(n: Node):

return _can_fuse_node(a) and _can_fuse_node(b)

bound_symbol_groups = fuse_bound_symbols(producers, consumers, trace.bound_symbols, _should_fuse)
bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)

# Counts how many fusions (per executor) have been constructed
# (Used to name fusions like nvFusion0, nvFusion1, ...)
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _can_fuse_node(n: Node):

return _can_fuse_node(a) and _can_fuse_node(b)

bound_symbol_groups = fuse_bound_symbols(producers, consumers, trace.bound_symbols, _should_fuse)
bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)

# Counts how many fusions (per executor) have been constructed
fusion_counter: int = 0
Expand Down

0 comments on commit b4295cd

Please sign in to comment.