diff --git a/hugr-passes/src/call_graph.rs b/hugr-passes/src/call_graph.rs index 47f56a695..13a83f4e1 100644 --- a/hugr-passes/src/call_graph.rs +++ b/hugr-passes/src/call_graph.rs @@ -77,31 +77,24 @@ impl CallGraph { fn reachable_funcs<'a>( cg: &'a CallGraph, h: &impl HugrView, - roots: impl IntoIterator, + entry_points: impl IntoIterator, ) -> impl Iterator + 'a { - let mut entry_points = roots.into_iter().collect_vec(); + let mut roots = entry_points.into_iter().collect_vec(); let mut b = if h.get_optype(h.root()).is_module() { - if entry_points.is_empty() { - entry_points.push( - h.children(h.root()) - .filter(|n| { - h.get_optype(*n) - .as_func_defn() - .is_some_and(|fd| fd.name == "main") - }) - .exactly_one() - .ok() - .expect("No entry_points provided, Module must contain `main`"), - ); + if roots.is_empty() { + roots.extend(h.children(h.root()).filter(|n| { + h.get_optype(*n) + .as_func_defn() + .is_some_and(|fd| fd.name == "main") + })); + assert!(!roots.is_empty(), "No entry_points for Module and no `main`"); } - let mut entry_points = entry_points - .into_iter() - .map(|i| cg.node_to_g.get(&i).unwrap()); - let mut b = Bfs::new(&cg.g, *entry_points.next().unwrap()); - b.stack.extend(entry_points); + let mut roots = roots.into_iter().map(|i| cg.node_to_g.get(&i).unwrap()); + let mut b = Bfs::new(&cg.g, *roots.next().unwrap()); + b.stack.extend(roots); b } else { - assert!(entry_points.is_empty()); + assert!(roots.is_empty()); Bfs::new(&cg.g, *cg.node_to_g.get(&h.root()).unwrap()) }; std::iter::from_fn(move || b.next(&cg.g)).map(|i| *cg.g.node_weight(i).unwrap())