Skip to content

Commit

Permalink
fix: hierarchical simple replacement using insert_hugr (#1718)
Browse files Browse the repository at this point in the history
Closes #1715
  • Loading branch information
ss2165 authored Nov 26, 2024
1 parent c5c8a6f commit 6a75f4c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
2 changes: 1 addition & 1 deletion hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ pub(crate) mod test {
pub(super) const QB: Type = crate::extension::prelude::QB_T;

/// Wire up inputs of a Dataflow container to the outputs.
pub(super) fn n_identity<T: DataflowSubContainer>(
pub(crate) fn n_identity<T: DataflowSubContainer>(
dataflow_builder: T,
) -> Result<T::ContainerHandle, BuildError> {
let w = dataflow_builder.input_wires();
Expand Down
109 changes: 75 additions & 34 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
use std::collections::{HashMap, HashSet};

use crate::hugr::hugrmut::InsertionResult;
pub use crate::hugr::internal::HugrMutInternals;
use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite};
use crate::hugr::{HugrMut, HugrView, Rewrite};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::{Hugr, IncomingPort, Node, OutgoingPort};
use thiserror::Error;

use super::inline_dfg::InlineDFGError;

/// Specification of a simple replacement operation.
#[derive(Debug, Clone)]
pub struct SimpleReplacement {
Expand Down Expand Up @@ -62,7 +66,7 @@ impl Rewrite for SimpleReplacement {
unimplemented!()
}

fn apply(mut self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
Expand All @@ -75,39 +79,23 @@ impl Rewrite for SimpleReplacement {
}
}
// 3. Do the replacement.
// 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes.
// Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self).
let mut index_map: HashMap<Node, Node> = HashMap::new();
let replacement_nodes = self
.replacement
.children(self.replacement.root())
.collect::<Vec<Node>>();
// slice of nodes omitting Input and Output:
let replacement_inner_nodes = &replacement_nodes[2..];
let self_output_node = h.children(parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
let op: &OpType = self.replacement.get_optype(node);
let new_node = h.add_node_after(self_output_node, op.clone());
index_map.insert(node, new_node);

// Move the metadata
let meta: Option<NodeMetadataMap> = self.replacement.take_node_metadata(node);
h.overwrite_node_metadata(new_node, meta);
// 3.1. Insert the replacement as a whole.
let InsertionResult {
new_root,
node_map: index_map,
} = h.insert_hugr(parent, self.replacement.clone());

// remove the Input and Output nodes from the replacement graph
let replace_children = h.children(new_root).collect::<Vec<Node>>();
for &io in &replace_children[..2] {
h.remove_node(io);
}
// Add edges between all newly added nodes matching those in replacement.
for &node in replacement_inner_nodes {
let new_node = index_map.get(&node).unwrap();
for outport in self.replacement.node_outputs(node) {
for target in self.replacement.linked_inputs(node, outport) {
if self.replacement.get_optype(target.0).tag() != OpTag::Output {
let new_target = index_map.get(&target.0).unwrap();
h.connect(*new_node, outport, *new_target, target.1);
}
}
}
// make all replacement top level children children of the parent
for &child in &replace_children[2..] {
h.set_parent(child, parent);
}
// remove the replacement root (which now has no children and no edges)
h.remove_node(new_root);

// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
Expand Down Expand Up @@ -136,6 +124,10 @@ impl Rewrite for SimpleReplacement {
));
}
}
let replacement_output_node = self
.replacement
.get_io(self.replacement.root())
.expect("parent already checked.")[1];
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
Expand Down Expand Up @@ -213,6 +205,9 @@ pub enum SimpleReplacementError {
/// Node in replacement graph is invalid.
#[error("A node in the replacement graph is invalid.")]
InvalidReplacementNode(),
/// Inlining replacement failed.
#[error("Inlining replacement failed: {0}")]
InliningFailed(#[from] InlineDFGError),
}

#[cfg(test)]
Expand All @@ -221,11 +216,12 @@ pub(in crate::hugr::rewrite) mod test {
use rstest::{fixture, rstest};
use std::collections::{HashMap, HashSet};

use crate::builder::test::n_identity;
use crate::builder::{
endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::BOOL_T;
use crate::extension::prelude::{BOOL_T, QB_T};
use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Rewrite};
Expand Down Expand Up @@ -774,6 +770,51 @@ pub(in crate::hugr::rewrite) mod test {
assert_eq!(hugr.node_count(), 4);
}

#[rstest]
fn test_nested_replace(dfg_hugr2: Hugr) {
// replace a node with a hugr with children

let mut h = dfg_hugr2;
let h_node = h
.nodes()
.find(|node: &Node| *h.get_optype(*node) == h_gate().into())
.unwrap();

// build a nested identity dfg
let mut nest_build = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap();
let [input] = nest_build.input_wires_arr();
let inner_build = nest_build.dfg_builder_endo([(QB_T, input)]).unwrap();
let inner_dfg = n_identity(inner_build).unwrap();
let inner_dfg_node = inner_dfg.node();
let replacement = nest_build
.finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)])
.unwrap();
let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
let nu_inp = vec![(
(inner_dfg_node, IncomingPort::from(0)),
(h_node, IncomingPort::from(0)),
)]
.into_iter()
.collect();

let nu_out = vec![(
(h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
IncomingPort::from(0),
)]
.into_iter()
.collect();

let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);

assert_eq!(h.node_count(), 4);

rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
h.update_validate(&PRELUDE_REGISTRY)
.unwrap_or_else(|e| panic!("{e}"));

assert_eq!(h.node_count(), 6);
}

use crate::hugr::rewrite::replace::Replacement;
fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement {
use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
Expand Down

0 comments on commit 6a75f4c

Please sign in to comment.