diff --git a/hugr-passes/src/call_graph.rs b/hugr-passes/src/call_graph.rs
new file mode 100644
index 000000000..a4ef11516
--- /dev/null
+++ b/hugr-passes/src/call_graph.rs
@@ -0,0 +1,114 @@
+#![warn(missing_docs)]
+//! Data structure for call graphs of a Hugr
+use std::collections::HashMap;
+
+use hugr_core::{ops::OpType, HugrView, Node};
+use petgraph::{graph::NodeIndex, Graph};
+
+/// Weight for an edge in a [CallGraph]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum CallGraphEdge {
+    /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
+    Call(Node),
+    /// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr
+    LoadFunction(Node),
+}
+
+/// Weight for a petgraph-node in a [CallGraph]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum CallGraphNode {
+    /// petgraph-node corresponds to a [FuncDecl](OpType::FuncDecl) node (specified) in the Hugr
+    FuncDecl(Node),
+    /// petgraph-node corresponds to a [FuncDefn](OpType::FuncDefn) node (specified) in the Hugr
+    FuncDefn(Node),
+    /// petgraph-node corresponds to the root node of the hugr, that is not
+    /// a [FuncDefn](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
+    /// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
+    NonFuncRoot,
+}
+
+/// Details the [Call]s and [LoadFunction]s in a Hugr.
+/// Each node in the `CallGraph` corresponds to a [FuncDefn] or [FuncDecl] in the Hugr;
+/// each edge corresponds to a [Call]/[LoadFunction] of the edge's target, contained in
+/// the edge's source.
+///
+/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph
+/// will have an additional [CallGraphNode::NonFuncRoot] corresponding to the Hugr's root, with no incoming edges.
+///
+/// [Call]: OpType::Call
+/// [FuncDecl]: OpType::FuncDecl
+/// [FuncDefn]: OpType::FuncDefn
+/// [LoadFunction]: OpType::LoadFunction
+pub struct CallGraph {
+    g: Graph<CallGraphNode, CallGraphEdge>,
+    node_to_g: HashMap<Node, NodeIndex<u32>>,
+}
+
+impl CallGraph {
+    /// Makes a new CallGraph for a specified (subview) of a Hugr.
+    /// Calls to functions outside the view will be dropped.
+    ///
+    /// # Panics
+    /// If the static source of a [Call](OpType::Call)/[LoadFunction](OpType::LoadFunction) has no
+    /// node in the call graph, i.e. is not a [FuncDefn](OpType::FuncDefn)/[FuncDecl](OpType::FuncDecl).
+    pub fn new(hugr: &impl HugrView) -> Self {
+        let mut g = Graph::default();
+        // `Some(root)` unless the root is a Module; only compared against non-function nodes
+        let non_func_root = (!hugr.get_optype(hugr.root()).is_module()).then_some(hugr.root());
+        let node_to_g = hugr
+            .nodes()
+            .filter_map(|n| {
+                let weight = match hugr.get_optype(n) {
+                    OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
+                    OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n),
+                    _ => (Some(n) == non_func_root).then_some(CallGraphNode::NonFuncRoot)?,
+                };
+                Some((n, g.add_node(weight)))
+            })
+            .collect::<HashMap<_, _>>();
+        for (func, cg_node) in node_to_g.iter() {
+            traverse(hugr, *cg_node, *func, &mut g, &node_to_g)
+        }
+        // Adds an edge from `enclosing_func` for every Call/LoadFunction beneath `node`,
+        // without descending into nested FuncDefns
+        fn traverse(
+            h: &impl HugrView,
+            enclosing_func: NodeIndex<u32>,
+            node: Node, // Nonstrict-descendant of `enclosing_func`
+            g: &mut Graph<CallGraphNode, CallGraphEdge>,
+            node_to_g: &HashMap<Node, NodeIndex<u32>>,
+        ) {
+            for ch in h.children(node) {
+                if h.get_optype(ch).is_func_defn() {
+                    // Nested functions have their own petgraph-node and are traversed from it
+                    continue;
+                };
+                traverse(h, enclosing_func, ch, g, node_to_g);
+                let weight = match h.get_optype(ch) {
+                    OpType::Call(_) => CallGraphEdge::Call(ch),
+                    OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(ch),
+                    _ => continue,
+                };
+                if let Some(target) = h.static_source(ch) {
+                    let target = *node_to_g
+                        .get(&target)
+                        .expect("static source of a Call/LoadFunction should be a function");
+                    g.add_edge(enclosing_func, target, weight);
+                }
+            }
+        }
+        CallGraph { g, node_to_g }
+    }
+
+    /// Allows access to the petgraph
+    pub fn graph(&self) -> &Graph<CallGraphNode, CallGraphEdge> {
+        &self.g
+    }
+
+    /// Convert a Hugr [Node] into a petgraph node index.
+    /// Result will be `None` if `n` is not a [FuncDefn](OpType::FuncDefn),
+    /// [FuncDecl](OpType::FuncDecl) or the non-[Module](OpType::Module) hugr root.
+    pub fn node_index(&self, n: Node) -> Option<NodeIndex<u32>> {
+        self.node_to_g.get(&n).copied()
+    }
+}
diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs
new file mode 100644
index 000000000..903d94c3a
--- /dev/null
+++ b/hugr-passes/src/dead_funcs.rs
@@ -0,0 +1,208 @@
+#![warn(missing_docs)]
+//! Pass for removing statically-unreachable functions from a Hugr
+
+use std::collections::HashSet;
+
+use hugr_core::{
+    hugr::hugrmut::HugrMut,
+    ops::{OpTag, OpTrait},
+    HugrView, Node,
+};
+use petgraph::visit::{Dfs, Walker};
+
+use crate::validation::{ValidatePassError, ValidationLevel};
+
+use super::call_graph::{CallGraph, CallGraphNode};
+
+#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
+/// Errors produced by [RemoveDeadFuncsPass].
+pub enum RemoveDeadFuncsError {
+    /// The node given as an entry point was not a FuncDefn child of the Module root
+    #[error("Node {0} was not a FuncDefn child of the Module root")]
+    InvalidEntryPoint(Node),
+    #[error(transparent)]
+    #[allow(missing_docs)]
+    ValidationError(#[from] ValidatePassError),
+}
+
+/// Finds the functions reachable via the call graph `cg` of `h`: from the `entry_points`
+/// if `h` is Module-rooted, otherwise from the root of `h` (which is itself included).
+///
+/// # Errors
+/// [RemoveDeadFuncsError::InvalidEntryPoint] if any entry point is not a FuncDefn child
+/// of a Module root (so any entry point at all is an error if the root is not a Module).
+fn reachable_funcs<'a>(
+    cg: &'a CallGraph,
+    h: &'a impl HugrView,
+    entry_points: impl IntoIterator<Item = Node>,
+) -> Result<impl Iterator<Item = Node> + 'a, RemoveDeadFuncsError> {
+    let g = cg.graph();
+    let mut entry_points = entry_points.into_iter();
+    let searcher = if h.get_optype(h.root()).is_module() {
+        // Start from an empty search and seed it with each (validated) entry point
+        let mut d = Dfs::empty(g);
+        for n in entry_points {
+            if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.root()) {
+                return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
+            }
+            d.stack.push(cg.node_index(n).unwrap())
+        }
+        d
+    } else {
+        if let Some(n) = entry_points.next() {
+            // Can't be a child of the module root as there isn't a module root!
+            return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
+        }
+        Dfs::new(g, cg.node_index(h.root()).unwrap())
+    };
+    Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
+        CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
+        CallGraphNode::NonFuncRoot => h.root(),
+    }))
+}
+
+#[derive(Debug, Clone, Default)]
+/// A configuration for the Dead Function Removal pass.
+pub struct RemoveDeadFuncsPass {
+    validation: ValidationLevel,
+    entry_points: Vec<Node>,
+}
+
+impl RemoveDeadFuncsPass {
+    /// Sets the validation level used before and after the pass is run
+    pub fn validation_level(mut self, level: ValidationLevel) -> Self {
+        self.validation = level;
+        self
+    }
+
+    /// Adds new entry points - these must be [FuncDefn] nodes
+    /// that are children of the [Module] at the root of the Hugr.
+    ///
+    /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
+    /// [Module]: hugr_core::ops::OpType::Module
+    pub fn with_module_entry_points(
+        mut self,
+        entry_points: impl IntoIterator<Item = Node>,
+    ) -> Self {
+        self.entry_points.extend(entry_points);
+        self
+    }
+
+    /// Runs the pass (see [remove_dead_funcs]) with this configuration
+    pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
+        self.validation.run_validated_pass(hugr, |hugr: &mut H, _| {
+            remove_dead_funcs(hugr, self.entry_points.iter().copied())
+        })
+    }
+}
+
+/// Delete from the Hugr any functions that are not used by either [Call] or
+/// [LoadFunction] nodes in reachable parts.
+///
+/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
+/// which must be children of the root. Note that if `entry_points` is empty, this will
+/// result in all functions in the module being removed.
+///
+/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
+///
+/// # Errors
+/// * If there are any `entry_points` but the root of the hugr is not a [Module]
+/// * If any node in `entry_points` is
+///     * not a [FuncDefn], or
+///     * not a child of the root
+///
+/// [Call]: hugr_core::ops::OpType::Call
+/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
+/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
+/// [Module]: hugr_core::ops::OpType::Module
+pub fn remove_dead_funcs(
+    h: &mut impl HugrMut,
+    entry_points: impl IntoIterator<Item = Node>,
+) -> Result<(), RemoveDeadFuncsError> {
+    let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::<HashSet<_>>();
+    let unreachable = h
+        .nodes()
+        .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n))
+        .collect::<Vec<_>>();
+    for n in unreachable {
+        // Removing a function also removes any functions nested inside it, so a node
+        // may already have gone along with an enclosing function removed earlier.
+        if h.contains_node(n) {
+            h.remove_subtree(n);
+        }
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod test {
+    use std::collections::HashMap;
+
+    use itertools::Itertools;
+    use rstest::rstest;
+
+    use hugr_core::builder::{
+        Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
+    };
+    use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};
+
+    use super::RemoveDeadFuncsPass;
+
+    #[rstest]
+    #[case([], vec![])] // No entry_points removes everything!
+    #[case(["main"], vec!["from_main", "main"])]
+    #[case(["from_main"], vec!["from_main"])]
+    #[case(["other1"], vec!["other1", "other2"])]
+    #[case(["other2"], vec!["other2"])]
+    #[case(["other1", "other2"], vec!["other1", "other2"])]
+    fn remove_dead_funcs_entry_points(
+        #[case] entry_points: impl IntoIterator<Item = &'static str>,
+        #[case] retained_funcs: Vec<&'static str>,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let mut hb = ModuleBuilder::new();
+        let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
+        let o2inp = o2.input_wires();
+        let o2 = o2.finish_with_outputs(o2inp)?;
+        let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;
+
+        let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
+        o1.finish_with_outputs(o1c.outputs())?;
+
+        let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
+        let f_inp = fm.input_wires();
+        let fm = fm.finish_with_outputs(f_inp)?;
+        let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
+        let mc = m.call(fm.handle(), &[], m.input_wires())?;
+        m.finish_with_outputs(mc.outputs())?;
+
+        let mut hugr = hb.finish_hugr()?;
+
+        let avail_funcs = hugr
+            .nodes()
+            .filter_map(|n| {
+                hugr.get_optype(n)
+                    .as_func_defn()
+                    .map(|fd| (fd.name.clone(), n))
+            })
+            .collect::<HashMap<_, _>>();
+
+        RemoveDeadFuncsPass::default()
+            .with_module_entry_points(
+                entry_points
+                    .into_iter()
+                    .map(|name| *avail_funcs.get(name).unwrap())
+                    .collect::<Vec<_>>(),
+            )
+            .run(&mut hugr)
+            .unwrap();
+
+        let remaining_funcs = hugr
+            .nodes()
+            .filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str()))
+            .sorted()
+            .collect_vec();
+        assert_eq!(remaining_funcs, retained_funcs);
+        Ok(())
+    }
+}
diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs
index 55e058047..ffc739933 100644
--- a/hugr-passes/src/lib.rs
+++ b/hugr-passes/src/lib.rs
@@ -1,20 +1,30 @@
 //! Compilation passes acting on the HUGR program representation.
 
+pub mod call_graph;
 pub mod const_fold;
 pub mod dataflow;
+mod dead_funcs;
+pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsPass};
 pub mod force_order;
 mod half_node;
 pub mod lower;
 pub mod merge_bbs;
 mod monomorphize;
 // TODO: Deprecated re-export. Remove on a breaking release.
+#[deprecated(
+    since = "0.14.1",
+    note = "Use `hugr::algorithms::RemoveDeadFuncsPass` instead."
+)]
+#[allow(deprecated)]
+pub use monomorphize::remove_polyfuncs;
+// TODO: Deprecated re-export. Remove on a breaking release.
 #[deprecated(
     since = "0.14.1",
     note = "Use `hugr::algorithms::MonomorphizePass` instead."
 )]
 #[allow(deprecated)]
 pub use monomorphize::monomorphize;
-pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
+pub use monomorphize::{MonomorphizeError, MonomorphizePass};
 pub mod nest_cfgs;
 pub mod non_local;
 pub mod validation;
diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs
index 9c1e69aca..95d26c557 100644
--- a/hugr-passes/src/monomorphize.rs
+++ b/hugr-passes/src/monomorphize.rs
@@ -19,8 +19,9 @@ use thiserror::Error;
 ///
 /// If the Hugr is [Module](OpType::Module)-rooted,
 /// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
-///     - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
-/// * else, the originals are removed (they are invisible from outside the Hugr).
+///     - [crate::remove_dead_funcs] can be used when no other Hugr will be linked in that might instantiate these
+/// * else, the originals are removed (they are invisible from outside the Hugr); however, note
+///   that this behaviour is expected to change in a future release to match Module-rooted Hugrs.
 ///
 /// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
 /// signature then the HUGR will not be modified.
@@ -45,6 +46,7 @@ fn monomorphize_ref(h: &mut impl HugrMut) {
     if !is_polymorphic_funcdefn(h.get_optype(root)) {
         mono_scan(h, root, None, &mut HashMap::new());
         if !h.get_optype(root).is_module() {
+            #[allow(deprecated)] // TODO remove in next breaking release and update docs
             remove_polyfuncs_ref(h);
         }
     }
@@ -54,13 +56,21 @@ fn monomorphize_ref(h: &mut impl HugrMut) {
 /// calls from *monomorphic* code, this will make the Hugr invalid (call [monomorphize]
 /// first).
 ///
-/// TODO replace this with a more general remove-unused-functions pass
-///
+/// Deprecated: use [crate::remove_dead_funcs] instead.
+#[deprecated(
+    since = "0.14.1",
+    note = "Use hugr::algorithms::RemoveDeadFuncsPass instead"
+)]
 pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
+    #[allow(deprecated)] // we are in a deprecated function, so remove both at same time
     remove_polyfuncs_ref(&mut h);
     h
 }
 
+#[deprecated(
+    since = "0.14.1",
+    note = "Use hugr::algorithms::RemoveDeadFuncsPass instead"
+)]
 fn remove_polyfuncs_ref(h: &mut impl HugrMut) {
     let mut pfs_to_delete = Vec::new();
     let mut to_scan = Vec::from_iter(h.children(h.root()));
@@ -377,9 +387,9 @@ mod test {
     use hugr_core::{Hugr, HugrView, Node};
     use rstest::rstest;
 
-    use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass};
+    use crate::remove_dead_funcs;
 
-    use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs};
+    use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass};
 
     fn pair_type(ty: Type) -> Type {
         Type::new_tuple(vec![ty.clone(), ty])
@@ -443,7 +453,7 @@ mod test {
             let trip = fb.add_dataflow_op(tag, [elem1, elem2, elem])?;
             fb.finish_with_outputs(trip.outputs())?
}; - { + let mn = { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?; let [elem] = fb.input_wires_arr(); @@ -453,8 +463,8 @@ mod test { let pair = fb.call(db.handle(), &[usize_t().into()], [elem])?; let pty = pair_type(usize_t()).into(); let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr(); - fb.finish_with_outputs([res1, res2])?; - } + fb.finish_with_outputs([res1, res2])? + }; let mut hugr = mb.finish_hugr()?; assert_eq!( hugr.nodes() @@ -487,7 +497,8 @@ mod test { assert_eq!(mono2, mono); // Idempotent - let nopoly = remove_polyfuncs(mono); + let mut nopoly = mono; + remove_dead_funcs(&mut nopoly, [mn.node()])?; let mut funcs = list_funcs(&nopoly); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); @@ -709,7 +720,7 @@ mod test { }; MonomorphizePass::default().run(&mut hugr).unwrap(); - remove_polyfuncs_ref(&mut hugr); + remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));