Skip to content

Commit

Permalink
Move reachable_funcs outside of CallGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Dec 17, 2024
1 parent 1e95bc6 commit 73f65ee
Showing 1 changed file with 34 additions and 35 deletions.
69 changes: 34 additions & 35 deletions hugr-passes/src/call_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::{HashMap, HashSet};

use hugr_core::{
hugr::hugrmut::HugrMut,
ops::{OpTag, OpTrait, OpType},
ops::{Call, OpTag, OpTrait, OpType},
HugrView, Node,
};
use itertools::Itertools;
Expand Down Expand Up @@ -72,39 +72,39 @@ impl CallGraph {
}
CallGraph { g, node_to_g }
}
}

fn reachable_funcs(
&self,
h: &impl HugrView,
roots: impl IntoIterator<Item = Node>,
) -> impl Iterator<Item = Node> + '_ {
let mut entry_points = roots.into_iter().collect_vec();
let mut b = if h.get_optype(h.root()).is_module() {
if entry_points.is_empty() {
entry_points.push(
h.children(h.root())
.filter(|n| {
h.get_optype(*n)
.as_func_defn()
.is_some_and(|fd| fd.name == "main")
})
.exactly_one()
.ok()
.expect("No entry_points provided, Module must contain `main`"),
);
}
let mut entry_points = entry_points
.into_iter()
.map(|i| self.node_to_g.get(&i).unwrap());
let mut b = Bfs::new(&self.g, *entry_points.next().unwrap());
b.stack.extend(entry_points);
b
} else {
assert!(entry_points.is_empty());
Bfs::new(&self.g, *self.node_to_g.get(&h.root()).unwrap())
};
std::iter::from_fn(move || b.next(&self.g)).map(|i| *self.g.node_weight(i).unwrap())
}
fn reachable_funcs<'a>(
cg: &'a CallGraph,
h: &impl HugrView,
roots: impl IntoIterator<Item = Node>,
) -> impl Iterator<Item = Node> + 'a {
let mut entry_points = roots.into_iter().collect_vec();
let mut b = if h.get_optype(h.root()).is_module() {
if entry_points.is_empty() {
entry_points.push(
h.children(h.root())
.filter(|n| {
h.get_optype(*n)
.as_func_defn()
.is_some_and(|fd| fd.name == "main")
})
.exactly_one()
.ok()
.expect("No entry_points provided, Module must contain `main`"),
);
}
let mut entry_points = entry_points
.into_iter()
.map(|i| cg.node_to_g.get(&i).unwrap());
let mut b = Bfs::new(&cg.g, *entry_points.next().unwrap());
b.stack.extend(entry_points);
b
} else {
assert!(entry_points.is_empty());
Bfs::new(&cg.g, *cg.node_to_g.get(&h.root()).unwrap())
};
std::iter::from_fn(move || b.next(&cg.g)).map(|i| *cg.g.node_weight(i).unwrap())
}

/// Delete from the Hugr any functions that are not used by either [Call](OpType::Call) or
Expand All @@ -123,8 +123,7 @@ impl CallGraph {
/// * If the Hugr is Module-rooted, and `entry_points` is non-empty but contains nodes that
/// are not [FuncDefn](OpType::FuncDefn)s
pub fn remove_dead_funcs(h: &mut impl HugrMut, entry_points: impl IntoIterator<Item = Node>) {
let cg = CallGraph::new(h);
let reachable = cg.reachable_funcs(h, entry_points).collect::<HashSet<_>>();
let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points).collect::<HashSet<_>>();
let unreachable = h
.nodes()
.filter(|n| h.get_optype(*n).is_func_defn() && !reachable.contains(n))
Expand Down

0 comments on commit 73f65ee

Please sign in to comment.