Skip to content

Commit

Permalink
Remove igraph in favor of networkx (#1244)
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga authored Oct 7, 2024
1 parent 5dcbc55 commit 952db42
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 27 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ torch >=2.3.0
looseversion ==1.3.0
lightning-utilities >=0.7.0
numpy >=1.23.0,<2 # not yet ready for numpy 2
igraph >=0.10.4
networkx >= 3.3
optree >=0.12.1
opt_einsum >= 3.3.0
mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined`
Expand Down
29 changes: 10 additions & 19 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
import time

from igraph import Graph
import networkx as nx

from thunder.core import prims, utils
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
Expand Down Expand Up @@ -317,12 +317,9 @@ def find_cut(

# Create a graph
edges = []
name_to_id = {}
capacities = []

def add_edge(src, dst, capacity):
edges.append((name_to_id.setdefault(src, len(name_to_id)), name_to_id.setdefault(dst, len(name_to_id))))
capacities.append(capacity)
edges.append((src, dst, {"capacity": capacity}))

utils.check(
len(required_consumer_vars) > 0,
Expand Down Expand Up @@ -374,23 +371,17 @@ def add_edges(var):
for var in symbol.flat_proxy_outs:
add_edges(var)

g = Graph(
n=len(name_to_id),
edges=edges,
directed=True,
edge_attrs={"capacity": capacities},
)
source = name_to_id["source"]
sink = name_to_id["sink"]
g = nx.DiGraph()
g.add_edges_from(edges)

_, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink")

id_to_name = dict(map(reversed, name_to_id.items()))
cut_edges = set()
for u, nbrs in ((n, g[n]) for n in reachable):
cut_edges.update((u, v) for v in nbrs if v in non_reachable)

g_edges = g.get_edgelist()
cut = g.mincut(source, sink, "capacity").cut
cut_nodes = set()
for cut_edge_id in cut:
u, v = g_edges[cut_edge_id]
node_in, node_out = id_to_name[u], id_to_name[v]
for node_in, node_out in cut_edges:
if node_out == "sink":
continue
assert node_in.endswith("_in"), node_in
Expand Down
11 changes: 4 additions & 7 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,13 @@ def func(x):

# There are two nvfuser fusion groups separated by the matmul operation.
assert len(fusion_bsyms) == 2
nvf_0, nvf_1 = fusion_bsyms

# CSE removes the redundant (t0 + 5) operation
assert len(nvf_0.subsymbols) == 5
# Return t0 and t1 from the first fusion
assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t1", "t4"]
nvf_0, nvf_1 = fusion_bsyms
assert len(nvf_0.subsymbols) + len(nvf_1.subsymbols) == 7

# CSE does not change the second fusion
assert len(nvf_1.subsymbols) == 2
assert [t.name for t in tree_flatten(nvf_1.output)[0]] == ["t10"]
outside_fusion_syms = ["unpack_trivial", "matmul", "python_return", "python_del"]
assert {el.sym.name for el in fw_trace.bound_symbols if not el.sym.is_fusion} == set(outside_fusion_syms)


@instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,))
Expand Down

0 comments on commit 952db42

Please sign in to comment.