-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add CallGraph struct, and dead-function-removal pass (#1796)
Closes #1753. * `remove_polyfuncs` preserved but deprecated, some uses in tests replaced to give coverage here. * Future (breaking) release to remove the automatic-`remove_polyfuncs` that currently follows monomorphization.
- Loading branch information
Showing
4 changed files
with
329 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#![warn(missing_docs)] | ||
//! Data structure for call graphs of a Hugr | ||
use std::collections::HashMap; | ||
|
||
use hugr_core::{ops::OpType, HugrView, Node}; | ||
use petgraph::{graph::NodeIndex, Graph}; | ||
|
||
/// Weight for an edge in a [CallGraph] | ||
pub enum CallGraphEdge { | ||
/// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr | ||
Call(Node), | ||
/// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr | ||
LoadFunction(Node), | ||
} | ||
|
||
/// Weight for a petgraph-node in a [CallGraph] | ||
pub enum CallGraphNode { | ||
/// petgraph-node corresponds to a [FuncDecl](OpType::FuncDecl) node (specified) in the Hugr | ||
FuncDecl(Node), | ||
/// petgraph-node corresponds to a [FuncDefn](OpType::FuncDefn) node (specified) in the Hugr | ||
FuncDefn(Node), | ||
/// petgraph-node corresponds to the root node of the hugr, that is not | ||
/// a [FuncDefn](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module) | ||
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph. | ||
NonFuncRoot, | ||
} | ||
|
||
/// Details the [Call]s and [LoadFunction]s in a Hugr. | ||
/// Each node in the `CallGraph` corresponds to a [FuncDefn] in the Hugr; each edge corresponds | ||
/// to a [Call]/[LoadFunction] of the edge's target, contained in the edge's source. | ||
/// | ||
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph | ||
/// will have an additional [CallGraphNode::NonFuncRoot] corresponding to the Hugr's root, with no incoming edges. | ||
/// | ||
/// [Call]: OpType::Call | ||
/// [FuncDefn]: OpType::FuncDefn | ||
/// [LoadFunction]: OpType::LoadFunction | ||
pub struct CallGraph { | ||
g: Graph<CallGraphNode, CallGraphEdge>, | ||
node_to_g: HashMap<Node, NodeIndex<u32>>, | ||
} | ||
|
||
impl CallGraph { | ||
/// Makes a new CallGraph for a specified (subview) of a Hugr. | ||
/// Calls to functions outside the view will be dropped. | ||
pub fn new(hugr: &impl HugrView) -> Self { | ||
let mut g = Graph::default(); | ||
let non_func_root = (!hugr.get_optype(hugr.root()).is_module()).then_some(hugr.root()); | ||
let node_to_g = hugr | ||
.nodes() | ||
.filter_map(|n| { | ||
let weight = match hugr.get_optype(n) { | ||
OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n), | ||
OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n), | ||
_ => (Some(n) == non_func_root).then_some(CallGraphNode::NonFuncRoot)?, | ||
}; | ||
Some((n, g.add_node(weight))) | ||
}) | ||
.collect::<HashMap<_, _>>(); | ||
for (func, cg_node) in node_to_g.iter() { | ||
traverse(hugr, *cg_node, *func, &mut g, &node_to_g) | ||
} | ||
fn traverse( | ||
h: &impl HugrView, | ||
enclosing_func: NodeIndex<u32>, | ||
node: Node, // Nonstrict-descendant of `enclosing_func`` | ||
g: &mut Graph<CallGraphNode, CallGraphEdge>, | ||
node_to_g: &HashMap<Node, NodeIndex<u32>>, | ||
) { | ||
for ch in h.children(node) { | ||
if h.get_optype(ch).is_func_defn() { | ||
continue; | ||
}; | ||
traverse(h, enclosing_func, ch, g, node_to_g); | ||
let weight = match h.get_optype(ch) { | ||
OpType::Call(_) => CallGraphEdge::Call(ch), | ||
OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(ch), | ||
_ => continue, | ||
}; | ||
if let Some(target) = h.static_source(ch) { | ||
g.add_edge(enclosing_func, *node_to_g.get(&target).unwrap(), weight); | ||
} | ||
} | ||
} | ||
CallGraph { g, node_to_g } | ||
} | ||
|
||
/// Allows access to the petgraph | ||
pub fn graph(&self) -> &Graph<CallGraphNode, CallGraphEdge> { | ||
&self.g | ||
} | ||
|
||
/// Convert a Hugr [Node] into a petgraph node index. | ||
/// Result will be `None` if `n` is not a [FuncDefn](OpType::FuncDefn), | ||
/// [FuncDecl](OpType::FuncDecl) or the hugr root. | ||
pub fn node_index(&self, n: Node) -> Option<NodeIndex<u32>> { | ||
self.node_to_g.get(&n).copied() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
#![warn(missing_docs)] | ||
//! Pass for removing statically-unreachable functions from a Hugr | ||
use std::collections::HashSet; | ||
|
||
use hugr_core::{ | ||
hugr::hugrmut::HugrMut, | ||
ops::{OpTag, OpTrait}, | ||
HugrView, Node, | ||
}; | ||
use petgraph::visit::{Dfs, Walker}; | ||
|
||
use crate::validation::{ValidatePassError, ValidationLevel}; | ||
|
||
use super::call_graph::{CallGraph, CallGraphNode}; | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[non_exhaustive] | ||
/// Errors produced by [ConstantFoldPass]. | ||
pub enum RemoveDeadFuncsError { | ||
#[error("Node {0} was not a FuncDefn child of the Module root")] | ||
InvalidEntryPoint(Node), | ||
#[error(transparent)] | ||
#[allow(missing_docs)] | ||
ValidationError(#[from] ValidatePassError), | ||
} | ||
|
||
fn reachable_funcs<'a>( | ||
cg: &'a CallGraph, | ||
h: &'a impl HugrView, | ||
entry_points: impl IntoIterator<Item = Node>, | ||
) -> Result<impl Iterator<Item = Node> + 'a, RemoveDeadFuncsError> { | ||
let g = cg.graph(); | ||
let mut entry_points = entry_points.into_iter(); | ||
let searcher = if h.get_optype(h.root()).is_module() { | ||
let mut d = Dfs::new(g, 0.into()); | ||
d.stack.clear(); | ||
for n in entry_points { | ||
if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.root()) { | ||
return Err(RemoveDeadFuncsError::InvalidEntryPoint(n)); | ||
} | ||
d.stack.push(cg.node_index(n).unwrap()) | ||
} | ||
d | ||
} else { | ||
if let Some(n) = entry_points.next() { | ||
// Can't be a child of the module root as there isn't a module root! | ||
return Err(RemoveDeadFuncsError::InvalidEntryPoint(n)); | ||
} | ||
Dfs::new(g, cg.node_index(h.root()).unwrap()) | ||
}; | ||
Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() { | ||
CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n, | ||
CallGraphNode::NonFuncRoot => h.root(), | ||
})) | ||
} | ||
|
||
#[derive(Debug, Clone, Default)] | ||
/// A configuration for the Dead Function Removal pass. | ||
pub struct RemoveDeadFuncsPass { | ||
validation: ValidationLevel, | ||
entry_points: Vec<Node>, | ||
} | ||
|
||
impl RemoveDeadFuncsPass { | ||
/// Sets the validation level used before and after the pass is run | ||
pub fn validation_level(mut self, level: ValidationLevel) -> Self { | ||
self.validation = level; | ||
self | ||
} | ||
|
||
/// Adds new entry points - these must be [FuncDefn] nodes | ||
/// that are children of the [Module] at the root of the Hugr. | ||
/// | ||
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn | ||
/// [Module]: hugr_core::ops::OpType::Module | ||
pub fn with_module_entry_points( | ||
mut self, | ||
entry_points: impl IntoIterator<Item = Node>, | ||
) -> Self { | ||
self.entry_points.extend(entry_points); | ||
self | ||
} | ||
|
||
/// Runs the pass (see [remove_dead_funcs]) with this configuration | ||
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { | ||
self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { | ||
remove_dead_funcs(hugr, self.entry_points.iter().cloned()) | ||
}) | ||
} | ||
} | ||
|
||
/// Delete from the Hugr any functions that are not used by either [Call] or | ||
/// [LoadFunction] nodes in reachable parts. | ||
/// | ||
/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, | ||
/// which must be children of the root. Note that if `entry_points` is empty, this will | ||
/// result in all functions in the module being removed. | ||
/// | ||
/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used. | ||
/// | ||
/// # Errors | ||
/// * If there are any `entry_points` but the root of the hugr is not a [Module] | ||
/// * If any node in `entry_points` is | ||
/// * not a [FuncDefn], or | ||
/// * not a child of the root | ||
/// | ||
/// [Call]: hugr_core::ops::OpType::Call | ||
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn | ||
/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction | ||
/// [Module]: hugr_core::ops::OpType::Module | ||
pub fn remove_dead_funcs( | ||
h: &mut impl HugrMut, | ||
entry_points: impl IntoIterator<Item = Node>, | ||
) -> Result<(), RemoveDeadFuncsError> { | ||
let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::<HashSet<_>>(); | ||
let unreachable = h | ||
.nodes() | ||
.filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) | ||
.collect::<Vec<_>>(); | ||
for n in unreachable { | ||
h.remove_subtree(n); | ||
} | ||
Ok(()) | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use std::collections::HashMap; | ||
|
||
use itertools::Itertools; | ||
use rstest::rstest; | ||
|
||
use hugr_core::builder::{ | ||
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, | ||
}; | ||
use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; | ||
|
||
use super::RemoveDeadFuncsPass; | ||
|
||
#[rstest] | ||
#[case([], vec![])] // No entry_points removes everything! | ||
#[case(["main"], vec!["from_main", "main"])] | ||
#[case(["from_main"], vec!["from_main"])] | ||
#[case(["other1"], vec!["other1", "other2"])] | ||
#[case(["other2"], vec!["other2"])] | ||
#[case(["other1", "other2"], vec!["other1", "other2"])] | ||
fn remove_dead_funcs_entry_points( | ||
#[case] entry_points: impl IntoIterator<Item = &'static str>, | ||
#[case] retained_funcs: Vec<&'static str>, | ||
) -> Result<(), Box<dyn std::error::Error>> { | ||
let mut hb = ModuleBuilder::new(); | ||
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?; | ||
let o2inp = o2.input_wires(); | ||
let o2 = o2.finish_with_outputs(o2inp)?; | ||
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?; | ||
|
||
let o1c = o1.call(o2.handle(), &[], o1.input_wires())?; | ||
o1.finish_with_outputs(o1c.outputs())?; | ||
|
||
let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?; | ||
let f_inp = fm.input_wires(); | ||
let fm = fm.finish_with_outputs(f_inp)?; | ||
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?; | ||
let mc = m.call(fm.handle(), &[], m.input_wires())?; | ||
m.finish_with_outputs(mc.outputs())?; | ||
|
||
let mut hugr = hb.finish_hugr()?; | ||
|
||
let avail_funcs = hugr | ||
.nodes() | ||
.filter_map(|n| { | ||
hugr.get_optype(n) | ||
.as_func_defn() | ||
.map(|fd| (fd.name.clone(), n)) | ||
}) | ||
.collect::<HashMap<_, _>>(); | ||
|
||
RemoveDeadFuncsPass::default() | ||
.with_module_entry_points( | ||
entry_points | ||
.into_iter() | ||
.map(|name| *avail_funcs.get(name).unwrap()) | ||
.collect::<Vec<_>>(), | ||
) | ||
.run(&mut hugr) | ||
.unwrap(); | ||
|
||
let remaining_funcs = hugr | ||
.nodes() | ||
.filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str())) | ||
.sorted() | ||
.collect_vec(); | ||
assert_eq!(remaining_funcs, retained_funcs); | ||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.