diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index a4507e156..c7a9c461a 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -7,7 +7,14 @@ mod half_node; pub mod lower; pub mod merge_bbs; mod monomorphize; -pub use monomorphize::{monomorphize, remove_polyfuncs}; +// TODO: Deprecated re-export. Remove on a breaking release. +#[deprecated( + since = "0.14.1", + note = "Use `hugr::algorithms::MonomorphizePass` instead." +)] +#[allow(deprecated)] +pub use monomorphize::monomorphize; +pub use monomorphize::{remove_polyfuncs, MonomorphizePass}; pub mod nest_cfgs; pub mod non_local; pub mod validation; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 5d5eafca6..748e59803 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -12,6 +12,7 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; +use thiserror::Error; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -28,6 +29,11 @@ use itertools::Itertools as _; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. +#[deprecated( + since = "0.14.1", + note = "Use `hugr::algorithms::MonomorphizePass` instead." +)] +// TODO: Deprecated. Remove on a breaking release. pub fn monomorphize(mut h: Hugr) -> Hugr { let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}")); @@ -59,18 +65,7 @@ fn monomorphize_ref(h: &mut impl HugrMut) { /// TODO replace this with a more general remove-unused-functions pass /// pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { - let mut pfs_to_delete = Vec::new(); - let mut to_scan = Vec::from_iter(h.children(h.root())); - while let Some(n) = to_scan.pop() { - if is_polymorphic_funcdefn(h.get_optype(n)) { - pfs_to_delete.push(n) - } else { - to_scan.extend(h.children(n)); - } - } - for n in pfs_to_delete { - h.remove_subtree(n); - } + remove_polyfuncs_ref(&mut h); h } @@ -257,6 +252,57 @@ fn instantiate( mono_tgt } +use crate::validation::{ValidatePassError, ValidationLevel}; + +/// Replaces calls to polymorphic functions with calls to new monomorphic +/// instantiations of the polymorphic ones. +/// +/// If the Hugr is [Module](OpType::Module)-rooted, +/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them) +/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these +/// * else, the originals are removed (they are invisible from outside the Hugr). +/// +/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic +/// signature then the HUGR will not be modified. +/// +/// Monomorphic copies of polymorphic functions will be added to the HUGR as +/// children of the root node. We make best effort to ensure that names (derived +/// from parent function names and concrete type args) of new functions are unique +/// whenever the names of their parents are unique, but this is not guaranteed. +#[derive(Debug, Clone, Default)] +pub struct MonomorphizePass { + validation: ValidationLevel, +} + +#[derive(Debug, Error)] +#[non_exhaustive] +/// Errors produced by [MonomorphizePass]. +pub enum MonomorphizeError { + #[error(transparent)] + #[allow(missing_docs)] + ValidationError(#[from] ValidatePassError), +} + +impl MonomorphizePass { + /// Sets the validation level used before and after the pass is run. + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Run the Monomorphization pass. + fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { + monomorphize_ref(hugr); + Ok(()) + } + + /// Run the pass using specified configuration. + pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { + self.validation + .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) + } +} + struct TypeArgsList<'a>(&'a [TypeArg]); impl std::fmt::Display for TypeArgsList<'_> { @@ -339,7 +385,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs}; + use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass}; + + use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -359,7 +407,10 @@ mod test { DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); - let hugr2 = monomorphize(hugr.clone()); + let mut hugr2 = hugr.clone(); + MonomorphizePass::default() + .run_no_validate(&mut hugr2) + .unwrap(); assert_eq!(hugr, hugr2); } @@ -414,14 +465,15 @@ mod test { let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr(); fb.finish_with_outputs([res1, res2])?; } - let hugr = mb.finish_hugr()?; + let mut hugr = mb.finish_hugr()?; assert_eq!( hugr.nodes() .filter(|n| hugr.get_optype(*n).is_func_defn()) .count(), 3 ); - let mono = monomorphize(hugr); + MonomorphizePass::default().run_no_validate(&mut hugr)?; + let mono = hugr; mono.validate()?; let mut funcs = list_funcs(&mono); @@ -440,8 +492,10 @@ mod test { funcs.into_keys().sorted().collect_vec(), ["double", "main", "triple"] ); + let mut mono2 = mono.clone(); + MonomorphizePass::default().run_no_validate(&mut mono2)?; - assert_eq!(monomorphize(mono.clone()), mono); // Idempotent + assert_eq!(mono2, mono); // Idempotent let nopoly = remove_polyfuncs(mono); let mut funcs = list_funcs(&nopoly); @@ -544,9 +598,12 @@ mod test { .call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped]) .unwrap() .outputs_arr(); - let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); + let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - let mono_hugr = monomorphize(hugr); + MonomorphizePass::default() + .run_no_validate(&mut hugr) + .unwrap(); + let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); let pf2_name = mangle_inner_func("pf1", "pf2"); @@ -605,8 +662,9 @@ mod test { .outputs_arr(); let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); - let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - let mono_hugr = monomorphize(hugr); + let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); + MonomorphizePass::default().run_no_validate(&mut hugr)?; + let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); @@ -623,7 +681,7 @@ mod test { #[test] fn load_function() { - let hugr = { + let mut hugr = { let mut module_builder = ModuleBuilder::new(); let foo = { let builder = module_builder @@ -662,9 +720,12 @@ mod test { module_builder.finish_hugr().unwrap() }; - let mono_hugr = remove_polyfuncs(monomorphize(hugr)); + MonomorphizePass::default() + .run_no_validate(&mut hugr) + .unwrap(); + remove_polyfuncs_ref(&mut hugr); - let funcs = list_funcs(&mono_hugr); + let funcs = list_funcs(&hugr); assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd))); }