Skip to content

Commit

Permalink
refactor: avoid hugr clone in simple replace
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 27, 2024
1 parent c87d4f3 commit 44faf10
Showing 1 changed file with 100 additions and 58 deletions.
158 changes: 100 additions & 58 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,87 @@ impl Rewrite for SimpleReplacement {
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.subgraph.get_parent(h);
let Self {
subgraph,
replacement,
nu_inp,
nu_out,
} = self;
let parent = subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
return Err(SimpleReplacementError::InvalidParentNode());
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in self.subgraph.nodes() {
for node in subgraph.nodes() {
if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() {
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}

let replacement_output_node = replacement
.get_io(replacement.root())
.expect("parent already checked.")[1];

// 3. Do the replacement.
// 3.1. Insert the replacement as a whole.
// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
//
// We delay creating these connections to avoid them getting mixed with
// the pre-existing ones in the following logic.
//
// Existing connections to the removed subgraph will be automatically
// removed when the nodes are removed.

// 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
let nu_inp_connects: Vec<_> = nu_inp
.iter()
.filter(|&((rep_inp_node, _), _)| {
replacement.get_optype(*rep_inp_node).tag() != OpTag::Output
})
.map(
|((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| {
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
(
rem_inp_pred_node,
rem_inp_pred_port,
// the new input node will be updated after insertion
rep_inp_node,
rep_inp_port,
)
},
)
.collect();

// 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
let nu_out_connects: Vec<_> = nu_out
.iter()
.filter_map(|((rem_out_node, rem_out_port), rep_out_port)| {
let (rep_out_pred_node, rep_out_pred_port) = replacement
.single_linked_output(replacement_output_node, *rep_out_port)
.unwrap();
(replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({
(
// the new output node will be updated after insertion
rep_out_pred_node,
rep_out_pred_port,
rem_out_node,
rem_out_port,
)
})
})
.collect();

// 3.3. Insert the replacement as a whole.
let InsertionResult {
new_root,
node_map: index_map,
} = h.insert_hugr(parent, self.replacement.clone());
} = h.insert_hugr(parent, replacement);

// remove the Input and Output nodes from the replacement graph
let replace_children = h.children(new_root).collect::<Vec<Node>>();
Expand All @@ -97,60 +161,39 @@ impl Rewrite for SimpleReplacement {
// remove the replacement root (which now has no children and no edges)
h.remove_node(new_root);

// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
//
// We delay creating these connections to avoid them getting mixed with
// the pre-existing ones in the following logic.
//
// Existing connections to the removed subgraph will be automatically
// removed when the nodes are removed.
let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new();

// 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
let new_inp_node = index_map.get(rep_inp_node).unwrap();
connect.insert((
rem_inp_pred_node,
rem_inp_pred_port,
*new_inp_node,
*rep_inp_port,
));
}
}
let replacement_output_node = self
.replacement
.get_io(self.replacement.root())
.expect("parent already checked.")[1];
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
let (rep_out_pred_node, rep_out_pred_port) = self
.replacement
.single_linked_output(replacement_output_node, *rep_out_port)
.unwrap();
if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input {
let new_out_node = index_map.get(&rep_out_pred_node).unwrap();
connect.insert((
*new_out_node,
rep_out_pred_port,
*rem_out_node,
*rem_out_port,
));
}
}
// 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
// 3.4. Update replacement nodes according to insertion mapping and load in to
// connection set.
let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> =
HashSet::with_capacity(nu_inp_connects.len() + nu_out_connects.len() + nu_out.len());

connect.extend(nu_inp_connects.into_iter().map(
|(src_node, src_port, tgt_node, tgt_port)| {
(
src_node,
src_port,
*index_map.get(tgt_node).unwrap(),
*tgt_port,
)
},
));

connect.extend(nu_out_connects.into_iter().map(
|(src_node, src_port, tgt_node, tgt_port)| {
(
*index_map.get(&src_node).unwrap(),
src_port,
*tgt_node,
*tgt_port,
)
},
));

// 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
// to p1.
//
// i.e. the replacement graph has direct edges between the input and output nodes.
for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out {
let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port));
for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out {
let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let (rem_inp_pred_node, rem_inp_pred_port) = h
Expand All @@ -175,9 +218,8 @@ impl Rewrite for SimpleReplacement {
h.connect(src_node, src_port, tgt_node, tgt_port);
});

// 3.5. Remove all nodes in self.removal and edges between them.
Ok(self
.subgraph
// 3.6. Remove all nodes in subgraph and edges between them.
Ok(subgraph
.nodes()
.iter()
.map(|&node| (node, h.remove_node(node)))
Expand Down

0 comments on commit 44faf10

Please sign in to comment.