From 7ee173c0cd3daba98131d48fd8b2c499d2a85aa7 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 18 Dec 2024 12:54:37 +0000 Subject: [PATCH] feat: add MonomorphizePass and deprecate monomorphize (#1809) Co-authored-by: Douglas Wilson <141026920+doug-q@users.noreply.github.com> --- hugr-passes/src/lib.rs | 9 ++- hugr-passes/src/monomorphize.rs | 128 ++++++++++++++++++++++++-------- 2 files changed, 104 insertions(+), 33 deletions(-) diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index a4507e156..55e058047 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -7,7 +7,14 @@ mod half_node; pub mod lower; pub mod merge_bbs; mod monomorphize; -pub use monomorphize::{monomorphize, remove_polyfuncs}; +// TODO: Deprecated re-export. Remove on a breaking release. +#[deprecated( + since = "0.14.1", + note = "Use `hugr::algorithms::MonomorphizePass` instead." +)] +#[allow(deprecated)] +pub use monomorphize::monomorphize; +pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass}; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index d7a9bb7ad..9c1e69aca 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -10,8 +10,9 @@ use hugr_core::{ Node, }; -use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType}; +use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; +use thiserror::Error; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -28,26 +29,25 @@ use itertools::Itertools as _; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. +#[deprecated( + since = "0.14.1", + note = "Use `hugr::algorithms::MonomorphizePass` instead." +)] +// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. pub fn monomorphize(mut h: Hugr) -> Hugr { - let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}")); - - // We clone the extension registry because we will need a reference to - // create our mutable substitutions. This is cannot cause a problem because - // we will not be adding any new types or extension ops to the HUGR. - #[cfg(debug_assertions)] - validate(&h); + monomorphize_ref(&mut h); + h +} +fn monomorphize_ref(h: &mut impl HugrMut) { let root = h.root(); // If the root is a polymorphic function, then there are no external calls, so nothing to do if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(&mut h, root, None, &mut HashMap::new()); + mono_scan(h, root, None, &mut HashMap::new()); if !h.get_optype(root).is_module() { - return remove_polyfuncs(h); + remove_polyfuncs_ref(h); } } - #[cfg(debug_assertions)] - validate(&h); - h } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -57,6 +57,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr { /// TODO replace this with a more general remove-unused-functions pass /// pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { + remove_polyfuncs_ref(&mut h); + h +} + +fn remove_polyfuncs_ref(h: &mut impl HugrMut) { let mut pfs_to_delete = Vec::new(); let mut to_scan = Vec::from_iter(h.children(h.root())); while let Some(n) = to_scan.pop() { @@ -69,7 +74,6 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { for n in pfs_to_delete { h.remove_subtree(n); } - h } fn is_polymorphic(fd: &FuncDefn) -> bool { @@ -93,7 +97,7 @@ type Instantiations = HashMap, Node>>; /// Optionally copies the subtree into a new location whilst applying a substitution. /// The subtree should be monomorphic after the substitution (if provided) has been applied. fn mono_scan( - h: &mut Hugr, + h: &mut impl HugrMut, parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, @@ -161,7 +165,7 @@ fn mono_scan( } fn instantiate( - h: &mut Hugr, + h: &mut impl HugrMut, poly_func: Node, type_args: Vec, mono_sig: Signature, @@ -218,20 +222,20 @@ fn instantiate( // 'ext' edges by copying every node before recursing on any of them, // 'dom' edges would *also* require recursing in dominator-tree preorder. for (&old_ch, &new_ch) in node_map.iter() { - for inport in h.node_inputs(old_ch).collect::>() { + for in_port in h.node_inputs(old_ch).collect::>() { // Edges from monomorphized functions to their calls already added during mono_scan() // as these depend not just on the original FuncDefn but also the TypeArgs - if h.linked_outputs(new_ch, inport).next().is_some() { + if h.linked_outputs(new_ch, in_port).next().is_some() { continue; }; - let srcs = h.linked_outputs(old_ch, inport).collect::>(); + let srcs = h.linked_outputs(old_ch, in_port).collect::>(); for (src, outport) in srcs { // Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside h.connect( node_map.get(&src).copied().unwrap_or(src), outport, new_ch, - inport, + in_port, ); } } @@ -240,6 +244,57 @@ fn instantiate( mono_tgt } +use crate::validation::{ValidatePassError, ValidationLevel}; + +/// Replaces calls to polymorphic functions with calls to new monomorphic +/// instantiations of the polymorphic ones. +/// +/// If the Hugr is [Module](OpType::Module)-rooted, +/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them) +/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these +/// * else, the originals are removed (they are invisible from outside the Hugr). +/// +/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic +/// signature then the HUGR will not be modified. +/// +/// Monomorphic copies of polymorphic functions will be added to the HUGR as +/// children of the root node. We make best effort to ensure that names (derived +/// from parent function names and concrete type args) of new functions are unique +/// whenever the names of their parents are unique, but this is not guaranteed. +#[derive(Debug, Clone, Default)] +pub struct MonomorphizePass { + validation: ValidationLevel, +} + +#[derive(Debug, Error)] +#[non_exhaustive] +/// Errors produced by [MonomorphizePass]. +pub enum MonomorphizeError { + #[error(transparent)] + #[allow(missing_docs)] + ValidationError(#[from] ValidatePassError), +} + +impl MonomorphizePass { + /// Sets the validation level used before and after the pass is run. + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Run the Monomorphization pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { + monomorphize_ref(hugr); + Ok(()) + } + + /// Run the pass using specified configuration. + pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } +} + struct TypeArgsList<'a>(&'a [TypeArg]); impl std::fmt::Display for TypeArgsList<'_> { @@ -322,7 +377,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs}; + use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass}; + + use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -342,7 +399,8 @@ mod test { DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); - let hugr2 = monomorphize(hugr.clone()); + let mut hugr2 = hugr.clone(); + MonomorphizePass::default().run(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -397,14 +455,15 @@ mod test { let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr(); fb.finish_with_outputs([res1, res2])?; } - let hugr = mb.finish_hugr()?; + let mut hugr = mb.finish_hugr()?; assert_eq!( hugr.nodes() .filter(|n| hugr.get_optype(*n).is_func_defn()) .count(), 3 ); - let mono = monomorphize(hugr); + MonomorphizePass::default().run(&mut hugr)?; + let mono = hugr; mono.validate()?; let mut funcs = list_funcs(&mono); @@ -423,8 +482,10 @@ mod test { funcs.into_keys().sorted().collect_vec(), ["double", "main", "triple"] ); + let mut mono2 = mono.clone(); + MonomorphizePass::default().run(&mut mono2)?; - assert_eq!(monomorphize(mono.clone()), mono); // Idempotent + assert_eq!(mono2, mono); // Idempotent let nopoly = remove_polyfuncs(mono); let mut funcs = list_funcs(&nopoly); @@ -527,9 +588,10 @@ mod test { .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped]) .unwrap() .outputs_arr(); - let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); + let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - let mono_hugr = monomorphize(hugr); + MonomorphizePass::default().run(&mut hugr).unwrap(); + let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); let pf2_name = mangle_inner_func("pf1", "pf2"); @@ -588,8 +650,9 @@ mod test { .outputs_arr(); let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); - let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - let mono_hugr = monomorphize(hugr); + let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); + MonomorphizePass::default().run(&mut hugr)?; + let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); @@ -606,7 +669,7 @@ mod test { #[test] fn load_function() { - let hugr = { + let mut hugr = { let mut module_builder = ModuleBuilder::new(); let foo = { let builder = module_builder @@ -645,9 +708,10 @@ mod test { module_builder.finish_hugr().unwrap() }; - let mono_hugr = remove_polyfuncs(monomorphize(hugr)); + MonomorphizePass::default().run(&mut hugr).unwrap(); + remove_polyfuncs_ref(&mut hugr); - let funcs = list_funcs(&mono_hugr); + let funcs = list_funcs(&hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); }