Skip to content

Commit

Permalink
feat: add MonomorphizePass and deprecate monomorphize (#1809)
Browse files Browse the repository at this point in the history
Co-authored-by: Douglas Wilson <[email protected]>
  • Loading branch information
ss2165 and doug-q authored Dec 18, 2024
1 parent e065d70 commit 7ee173c
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 33 deletions.
9 changes: 8 additions & 1 deletion hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ mod half_node;
pub mod lower;
pub mod merge_bbs;
mod monomorphize;
pub use monomorphize::{monomorphize, remove_polyfuncs};
// TODO: Deprecated re-export. Remove on a breaking release.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
)]
#[allow(deprecated)]
pub use monomorphize::monomorphize;
pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
pub mod nest_cfgs;
pub mod non_local;
pub mod validation;
Expand Down
128 changes: 96 additions & 32 deletions hugr-passes/src/monomorphize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use hugr_core::{
Node,
};

use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType};
use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType};
use itertools::Itertools as _;
use thiserror::Error;

/// Replaces calls to polymorphic functions with calls to new monomorphic
/// instantiations of the polymorphic ones.
Expand All @@ -28,26 +29,25 @@ use itertools::Itertools as _;
/// children of the root node. We make best effort to ensure that names (derived
/// from parent function names and concrete type args) of new functions are unique
/// whenever the names of their parents are unique, but this is not guaranteed.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
)]
// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`.
pub fn monomorphize(mut h: Hugr) -> Hugr {
let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}"));

// We clone the extension registry because we will need a reference to
// create our mutable substitutions. This is cannot cause a problem because
// we will not be adding any new types or extension ops to the HUGR.
#[cfg(debug_assertions)]
validate(&h);
monomorphize_ref(&mut h);
h
}

fn monomorphize_ref(h: &mut impl HugrMut) {
let root = h.root();
// If the root is a polymorphic function, then there are no external calls, so nothing to do
if !is_polymorphic_funcdefn(h.get_optype(root)) {
mono_scan(&mut h, root, None, &mut HashMap::new());
mono_scan(h, root, None, &mut HashMap::new());
if !h.get_optype(root).is_module() {
return remove_polyfuncs(h);
remove_polyfuncs_ref(h);
}
}
#[cfg(debug_assertions)]
validate(&h);
h
}

/// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have
Expand All @@ -57,6 +57,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
/// TODO replace this with a more general remove-unused-functions pass
/// <https://github.com/CQCL/hugr/issues/1753>
pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
remove_polyfuncs_ref(&mut h);
h
}

fn remove_polyfuncs_ref(h: &mut impl HugrMut) {
let mut pfs_to_delete = Vec::new();
let mut to_scan = Vec::from_iter(h.children(h.root()));
while let Some(n) = to_scan.pop() {
Expand All @@ -69,7 +74,6 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
for n in pfs_to_delete {
h.remove_subtree(n);
}
h
}

fn is_polymorphic(fd: &FuncDefn) -> bool {
Expand All @@ -93,7 +97,7 @@ type Instantiations = HashMap<Node, HashMap<Vec<TypeArg>, Node>>;
/// Optionally copies the subtree into a new location whilst applying a substitution.
/// The subtree should be monomorphic after the substitution (if provided) has been applied.
fn mono_scan(
h: &mut Hugr,
h: &mut impl HugrMut,
parent: Node,
mut subst_into: Option<&mut Instantiating>,
cache: &mut Instantiations,
Expand Down Expand Up @@ -161,7 +165,7 @@ fn mono_scan(
}

fn instantiate(
h: &mut Hugr,
h: &mut impl HugrMut,
poly_func: Node,
type_args: Vec<TypeArg>,
mono_sig: Signature,
Expand Down Expand Up @@ -218,20 +222,20 @@ fn instantiate(
// 'ext' edges by copying every node before recursing on any of them,
// 'dom' edges would *also* require recursing in dominator-tree preorder.
for (&old_ch, &new_ch) in node_map.iter() {
for inport in h.node_inputs(old_ch).collect::<Vec<_>>() {
for in_port in h.node_inputs(old_ch).collect::<Vec<_>>() {
// Edges from monomorphized functions to their calls already added during mono_scan()
// as these depend not just on the original FuncDefn but also the TypeArgs
if h.linked_outputs(new_ch, inport).next().is_some() {
if h.linked_outputs(new_ch, in_port).next().is_some() {
continue;
};
let srcs = h.linked_outputs(old_ch, inport).collect::<Vec<_>>();
let srcs = h.linked_outputs(old_ch, in_port).collect::<Vec<_>>();
for (src, outport) in srcs {
// Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside
h.connect(
node_map.get(&src).copied().unwrap_or(src),
outport,
new_ch,
inport,
in_port,
);
}
}
Expand All @@ -240,6 +244,57 @@ fn instantiate(
mono_tgt
}

use crate::validation::{ValidatePassError, ValidationLevel};

/// Replaces calls to polymorphic functions with calls to new monomorphic
/// instantiations of the polymorphic ones.
///
/// If the Hugr is [Module](OpType::Module)-rooted,
/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
/// * else, the originals are removed (they are invisible from outside the Hugr).
///
/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
/// signature then the HUGR will not be modified.
///
/// Monomorphic copies of polymorphic functions will be added to the HUGR as
/// children of the root node. We make best effort to ensure that names (derived
/// from parent function names and concrete type args) of new functions are unique
/// whenever the names of their parents are unique, but this is not guaranteed.
#[derive(Debug, Clone, Default)]
pub struct MonomorphizePass {
validation: ValidationLevel,
}

#[derive(Debug, Error)]
#[non_exhaustive]
/// Errors produced by [MonomorphizePass].
pub enum MonomorphizeError {
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
}

impl MonomorphizePass {
/// Sets the validation level used before and after the pass is run.
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Run the Monomorphization pass.
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> {
monomorphize_ref(hugr);
Ok(())
}

/// Run the pass using specified configuration.
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), MonomorphizeError> {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}
}

struct TypeArgsList<'a>(&'a [TypeArg]);

impl std::fmt::Display for TypeArgsList<'_> {
Expand Down Expand Up @@ -322,7 +377,9 @@ mod test {
use hugr_core::{Hugr, HugrView, Node};
use rstest::rstest;

use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs};
use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass};

use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs};

fn pair_type(ty: Type) -> Type {
Type::new_tuple(vec![ty.clone(), ty])
Expand All @@ -342,7 +399,8 @@ mod test {
DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap();
let hugr2 = monomorphize(hugr.clone());
let mut hugr2 = hugr.clone();
MonomorphizePass::default().run(&mut hugr2).unwrap();
assert_eq!(hugr, hugr2);
}

Expand Down Expand Up @@ -397,14 +455,15 @@ mod test {
let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr();
fb.finish_with_outputs([res1, res2])?;
}
let hugr = mb.finish_hugr()?;
let mut hugr = mb.finish_hugr()?;
assert_eq!(
hugr.nodes()
.filter(|n| hugr.get_optype(*n).is_func_defn())
.count(),
3
);
let mono = monomorphize(hugr);
MonomorphizePass::default().run(&mut hugr)?;
let mono = hugr;
mono.validate()?;

let mut funcs = list_funcs(&mono);
Expand All @@ -423,8 +482,10 @@ mod test {
funcs.into_keys().sorted().collect_vec(),
["double", "main", "triple"]
);
let mut mono2 = mono.clone();
MonomorphizePass::default().run(&mut mono2)?;

assert_eq!(monomorphize(mono.clone()), mono); // Idempotent
assert_eq!(mono2, mono); // Idempotent

let nopoly = remove_polyfuncs(mono);
let mut funcs = list_funcs(&nopoly);
Expand Down Expand Up @@ -527,9 +588,10 @@ mod test {
.call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped])
.unwrap()
.outputs_arr();
let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();
let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();

let mono_hugr = monomorphize(hugr);
MonomorphizePass::default().run(&mut hugr).unwrap();
let mono_hugr = hugr;
mono_hugr.validate().unwrap();
let funcs = list_funcs(&mono_hugr);
let pf2_name = mangle_inner_func("pf1", "pf2");
Expand Down Expand Up @@ -588,8 +650,9 @@ mod test {
.outputs_arr();
let mono = mono.finish_with_outputs([a, b]).unwrap();
let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap();
let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
let mono_hugr = monomorphize(hugr);
let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
MonomorphizePass::default().run(&mut hugr)?;
let mono_hugr = hugr;

let mut funcs = list_funcs(&mono_hugr);
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
Expand All @@ -606,7 +669,7 @@ mod test {

#[test]
fn load_function() {
let hugr = {
let mut hugr = {
let mut module_builder = ModuleBuilder::new();
let foo = {
let builder = module_builder
Expand Down Expand Up @@ -645,9 +708,10 @@ mod test {
module_builder.finish_hugr().unwrap()
};

let mono_hugr = remove_polyfuncs(monomorphize(hugr));
MonomorphizePass::default().run(&mut hugr).unwrap();
remove_polyfuncs_ref(&mut hugr);

let funcs = list_funcs(&mono_hugr);
let funcs = list_funcs(&hugr);
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
}

Expand Down

0 comments on commit 7ee173c

Please sign in to comment.