Skip to content

Commit

Permalink
feat: Add CallGraph struct, and dead-function-removal pass (#1796)
Browse files Browse the repository at this point in the history
Closes #1753.

* `remove_polyfuncs` preserved but deprecated, some uses in tests
replaced to give coverage here.
* Future (breaking) release to remove the automatic-`remove_polyfuncs`
that currently follows monomorphization.
  • Loading branch information
acl-cqc authored Dec 24, 2024
1 parent 33dd8fd commit cad7484
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 12 deletions.
99 changes: 99 additions & 0 deletions hugr-passes/src/call_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#![warn(missing_docs)]
//! Data structure for call graphs of a Hugr
use std::collections::HashMap;

use hugr_core::{ops::OpType, HugrView, Node};
use petgraph::{graph::NodeIndex, Graph};

/// Weight for an edge in a [CallGraph].
///
/// Each edge records which Hugr node (a [Call](OpType::Call) or a
/// [LoadFunction](OpType::LoadFunction)) gave rise to it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CallGraphEdge {
    /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
    Call(Node),
    /// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr
    LoadFunction(Node),
}

/// Weight for a petgraph-node in a [CallGraph].
///
/// Identifies which Hugr node the petgraph-node stands for.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CallGraphNode {
    /// petgraph-node corresponds to a [FuncDecl](OpType::FuncDecl) node (specified) in the Hugr
    FuncDecl(Node),
    /// petgraph-node corresponds to a [FuncDefn](OpType::FuncDefn) node (specified) in the Hugr
    FuncDefn(Node),
    /// petgraph-node corresponds to the root node of the hugr, that is not
    /// a [FuncDefn](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
    /// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
    NonFuncRoot,
}

/// Details the [Call]s and [LoadFunction]s in a Hugr.
/// Each node in the `CallGraph` corresponds to a [FuncDefn] or [FuncDecl] in the Hugr;
/// each edge corresponds to a [Call]/[LoadFunction] of the edge's target, contained in
/// the edge's source.
///
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph
/// will have an additional [CallGraphNode::NonFuncRoot] corresponding to the Hugr's root, with no incoming edges.
///
/// [Call]: OpType::Call
/// [FuncDecl]: OpType::FuncDecl
/// [FuncDefn]: OpType::FuncDefn
/// [LoadFunction]: OpType::LoadFunction
pub struct CallGraph {
// The underlying petgraph: one petgraph-node per function (or non-function root),
// one edge per Call/LoadFunction.
g: Graph<CallGraphNode, CallGraphEdge>,
// Maps each Hugr node that is represented in `g` to its petgraph index.
node_to_g: HashMap<Node, NodeIndex<u32>>,
}

impl CallGraph {
    /// Makes a new CallGraph for a specified (subview) of a Hugr.
    /// Calls to functions outside the view will be dropped.
    ///
    /// petgraph-nodes are created in the order of `hugr.nodes()`, and the bodies of
    /// the functions are scanned in that same order, so that the node and edge
    /// indices of the resulting graph are the same on every run.
    ///
    /// # Panics
    ///
    /// If a [Call](OpType::Call) or [LoadFunction](OpType::LoadFunction) has a static
    /// source within the view that is neither a [FuncDefn](OpType::FuncDefn), a
    /// [FuncDecl](OpType::FuncDecl), nor the (non-module) root.
    pub fn new(hugr: &impl HugrView) -> Self {
        // Adds to `g` an edge for every Call/LoadFunction found beneath `node`,
        // without descending into nested FuncDefns (those are call-graph nodes
        // in their own right and are scanned separately).
        fn traverse(
            h: &impl HugrView,
            enclosing_func: NodeIndex<u32>,
            node: Node, // Nonstrict-descendant of `enclosing_func`
            g: &mut Graph<CallGraphNode, CallGraphEdge>,
            node_to_g: &HashMap<Node, NodeIndex<u32>>,
        ) {
            for ch in h.children(node) {
                if h.get_optype(ch).is_func_defn() {
                    continue;
                }
                traverse(h, enclosing_func, ch, g, node_to_g);
                let weight = match h.get_optype(ch) {
                    OpType::Call(_) => CallGraphEdge::Call(ch),
                    OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(ch),
                    _ => continue,
                };
                // No static source means the callee is not visible here: drop the edge.
                if let Some(target) = h.static_source(ch) {
                    let target_idx = *node_to_g
                        .get(&target)
                        .expect("static source of a Call/LoadFunction should be a function in the call graph");
                    g.add_edge(enclosing_func, target_idx, weight);
                }
            }
        }

        let mut g: Graph<CallGraphNode, CallGraphEdge> = Graph::default();
        let root = hugr.root();
        let root_is_module = hugr.get_optype(root).is_module();
        // Keep the (Hugr node, petgraph index) pairs in `hugr.nodes()` order: iterating
        // the HashMap instead would add edges in a different order on every run.
        let funcs: Vec<(Node, NodeIndex<u32>)> = hugr
            .nodes()
            .filter_map(|n| {
                let weight = match hugr.get_optype(n) {
                    OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
                    OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n),
                    _ if n == root && !root_is_module => CallGraphNode::NonFuncRoot,
                    _ => return None,
                };
                Some((n, g.add_node(weight)))
            })
            .collect();
        let node_to_g: HashMap<Node, NodeIndex<u32>> = funcs.iter().copied().collect();
        for &(func, cg_node) in &funcs {
            traverse(hugr, cg_node, func, &mut g, &node_to_g);
        }
        CallGraph { g, node_to_g }
    }

    /// Allows access to the petgraph
    pub fn graph(&self) -> &Graph<CallGraphNode, CallGraphEdge> {
        &self.g
    }

    /// Convert a Hugr [Node] into a petgraph node index.
    /// Result will be `None` if `n` is not a [FuncDefn](OpType::FuncDefn),
    /// [FuncDecl](OpType::FuncDecl) or the hugr root.
    pub fn node_index(&self, n: Node) -> Option<NodeIndex<u32>> {
        self.node_to_g.get(&n).copied()
    }
}
197 changes: 197 additions & 0 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#![warn(missing_docs)]
//! Pass for removing statically-unreachable functions from a Hugr
use std::collections::HashSet;

use hugr_core::{
hugr::hugrmut::HugrMut,
ops::{OpTag, OpTrait},
HugrView, Node,
};
use petgraph::visit::{Dfs, Walker};

use crate::validation::{ValidatePassError, ValidationLevel};

use super::call_graph::{CallGraph, CallGraphNode};

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
/// Errors produced by [RemoveDeadFuncsPass] and [remove_dead_funcs].
pub enum RemoveDeadFuncsError {
/// The specified node was given as an entry point, but is not a `FuncDefn`
/// that is a child of a `Module` root. (This is also reported for any entry
/// point supplied when the root of the Hugr is not a `Module`.)
#[error("Node {0} was not a FuncDefn child of the Module root")]
InvalidEntryPoint(Node),
/// An error from the validation wrapped around the pass.
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
}

/// Finds the Hugr nodes of all functions reachable in the call graph `cg` of `h`.
///
/// * If the root of `h` is a `Module`, the search starts from `entry_points`, each
///   of which must be a `FuncDefn` child of the root.
/// * Otherwise the search starts from the root of `h`, and `entry_points` must be empty.
///
/// The returned iterator yields each reachable node once, the starting points included.
///
/// # Errors
///
/// [RemoveDeadFuncsError::InvalidEntryPoint] for the first entry point that breaks
/// the rules above.
fn reachable_funcs<'a>(
    cg: &'a CallGraph,
    h: &'a impl HugrView,
    entry_points: impl IntoIterator<Item = Node>,
) -> Result<impl Iterator<Item = Node> + 'a, RemoveDeadFuncsError> {
    let g = cg.graph();
    let root = h.root();
    let searcher = if h.get_optype(root).is_module() {
        // Begin with nothing to visit, then seed the search with each entry point.
        let mut dfs = Dfs::empty(g);
        for n in entry_points {
            if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(root) {
                return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
            }
            let start = cg
                .node_index(n)
                .expect("every FuncDefn in the Hugr is a node of its call graph");
            dfs.stack.push(start);
        }
        dfs
    } else {
        if let Some(n) = entry_points.into_iter().next() {
            // Can't be a child of the module root as there isn't a module root!
            return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
        }
        let start = cg
            .node_index(root)
            .expect("a non-Module root is a node of the call graph");
        Dfs::new(g, start)
    };
    Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
        CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
        CallGraphNode::NonFuncRoot => h.root(),
    }))
}

#[derive(Debug, Clone, Default)]
/// A configuration for the Dead Function Removal pass.
///
/// The default configuration has no entry points.
pub struct RemoveDeadFuncsPass {
// Validation level applied around the pass by `run`.
validation: ValidationLevel,
// Entry points from which reachability is computed; passed to `remove_dead_funcs`.
entry_points: Vec<Node>,
}

impl RemoveDeadFuncsPass {
    /// Returns this configuration with the validation level (used before and
    /// after the pass is run) replaced by `level`.
    pub fn validation_level(self, level: ValidationLevel) -> Self {
        Self {
            validation: level,
            ..self
        }
    }

    /// Returns this configuration with further entry points added to any already
    /// present. These must be [FuncDefn] nodes that are children of the [Module]
    /// at the root of the Hugr.
    ///
    /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
    /// [Module]: hugr_core::ops::OpType::Module
    pub fn with_module_entry_points(
        self,
        entry_points: impl IntoIterator<Item = Node>,
    ) -> Self {
        let Self {
            validation,
            entry_points: mut all_entry_points,
        } = self;
        all_entry_points.extend(entry_points);
        Self {
            validation,
            entry_points: all_entry_points,
        }
    }

    /// Runs the pass (see [remove_dead_funcs]) on `hugr` with this configuration,
    /// wrapped in the configured level of validation.
    pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
        self.validation.run_validated_pass(hugr, |hugr: &mut H, _| {
            remove_dead_funcs(hugr, self.entry_points.iter().copied())
        })
    }
}

/// Delete from the Hugr any functions that are not used by either [Call] or
/// [LoadFunction] nodes in reachable parts.
///
/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
/// which must be children of the root. Note that if `entry_points` is empty, this will
/// result in all functions in the module being removed.
///
/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
///
/// Unreachable functions nested inside other unreachable functions are removed
/// along with the function that contains them.
///
/// # Errors
/// * If there are any `entry_points` but the root of the hugr is not a [Module]
/// * If any node in `entry_points` is
///     * not a [FuncDefn], or
///     * not a child of the root
///
/// [Call]: hugr_core::ops::OpType::Call
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
/// [Module]: hugr_core::ops::OpType::Module
pub fn remove_dead_funcs(
    h: &mut impl HugrMut,
    entry_points: impl IntoIterator<Item = Node>,
) -> Result<(), RemoveDeadFuncsError> {
    let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::<HashSet<_>>();
    // Collect first: the Hugr cannot be mutated while its nodes are being iterated.
    let unreachable = h
        .nodes()
        .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n))
        .collect::<Vec<_>>();
    for n in unreachable {
        // A function nested within an unreachable function has already gone if
        // its container was removed earlier in this loop.
        if h.contains_node(n) {
            h.remove_subtree(n);
        }
    }
    Ok(())
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use itertools::Itertools;
use rstest::rstest;

use hugr_core::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};

use super::RemoveDeadFuncsPass;

// Each case gives the names of the entry points, and the (sorted) names of the
// functions expected to survive the pass.
#[rstest]
#[case([], vec![])] // No entry_points removes everything!
#[case(["main"], vec!["from_main", "main"])]
#[case(["from_main"], vec!["from_main"])]
#[case(["other1"], vec!["other1", "other2"])]
#[case(["other2"], vec!["other2"])]
#[case(["other1", "other2"], vec!["other1", "other2"])]
fn remove_dead_funcs_entry_points(
#[case] entry_points: impl IntoIterator<Item = &'static str>,
#[case] retained_funcs: Vec<&'static str>,
) -> Result<(), Box<dyn std::error::Error>> {
// Build a module of four functions: "other1" calls "other2", "main" calls "from_main".
let mut hb = ModuleBuilder::new();
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
let o2inp = o2.input_wires();
let o2 = o2.finish_with_outputs(o2inp)?;
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;

let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
o1.finish_with_outputs(o1c.outputs())?;

let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
let f_inp = fm.input_wires();
let fm = fm.finish_with_outputs(f_inp)?;
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
let mc = m.call(fm.handle(), &[], m.input_wires())?;
m.finish_with_outputs(mc.outputs())?;

let mut hugr = hb.finish_hugr()?;

// Look up the node of each FuncDefn by name, to translate the entry point names.
let avail_funcs = hugr
.nodes()
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
.map(|fd| (fd.name.clone(), n))
})
.collect::<HashMap<_, _>>();

RemoveDeadFuncsPass::default()
.with_module_entry_points(
entry_points
.into_iter()
.map(|name| *avail_funcs.get(name).unwrap())
.collect::<Vec<_>>(),
)
.run(&mut hugr)
.unwrap();

// Sorted names of the FuncDefns left in the Hugr after the pass.
let remaining_funcs = hugr
.nodes()
.filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str()))
.sorted()
.collect_vec();
assert_eq!(remaining_funcs, retained_funcs);
Ok(())
}
}
12 changes: 11 additions & 1 deletion hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
//! Compilation passes acting on the HUGR program representation.
pub mod call_graph;
pub mod const_fold;
pub mod dataflow;
mod dead_funcs;
pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsPass};
pub mod force_order;
mod half_node;
pub mod lower;
pub mod merge_bbs;
mod monomorphize;
// TODO: Deprecated re-export. Remove on a breaking release.
// The replacement pass is re-exported from the root of this crate (not from
// the `call_graph` module), so that is the path the note must give.
#[deprecated(
    since = "0.14.1",
    note = "Use `hugr::algorithms::RemoveDeadFuncsPass` instead."
)]
#[allow(deprecated)]
pub use monomorphize::remove_polyfuncs;
// TODO: Deprecated re-export. Remove on a breaking release.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
)]
#[allow(deprecated)]
pub use monomorphize::monomorphize;
pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
pub use monomorphize::{MonomorphizeError, MonomorphizePass};
pub mod nest_cfgs;
pub mod non_local;
pub mod validation;
Expand Down
Loading

0 comments on commit cad7484

Please sign in to comment.