From fc609a202353c1c01eec4f316f919b2d39ffcd85 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 27 Nov 2024 14:13:45 +0000 Subject: [PATCH] feat: add HugrView::first_child and HugrMut::remove_subtree (#1721) * Clarify doc of HugrMut::remove_node * Add HugrView::first_child, we have this anyway and it's useful because Rust's non-lexical lifetimes don't go far enough * Add HugrMut::remove_subtree, test closes #1663 --- hugr-core/src/hugr/hugrmut.rs | 46 ++++++++++++++++++++++++++++++++++- hugr-core/src/hugr/rewrite.rs | 5 +--- hugr-core/src/hugr/views.rs | 6 +++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 3c538c357..3d9edc050 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -119,6 +119,8 @@ pub trait HugrMut: HugrMutInternals { } /// Remove a node from the graph and return the node weight. + /// Note that if the node has children, they are not removed; this leaves + /// the Hugr in an invalid state. See [Self::remove_subtree]. /// /// # Panics /// @@ -129,6 +131,19 @@ pub trait HugrMut: HugrMutInternals { self.hugr_mut().remove_node(node) } + /// Remove a node from the graph, along with all its descendants in the hierarchy. + /// + /// # Panics + /// + /// If the node is not in the graph, or is the root (this would leave an empty Hugr). + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + while let Some(ch) = self.first_child(node) { + self.remove_subtree(ch) + } + self.hugr_mut().remove_node(node); + } + /// Connect two nodes at the given ports. /// /// # Panics @@ -524,7 +539,7 @@ mod test { PRELUDE_REGISTRY, }, macros::type_row, - ops::{self, dataflow::IOTrait}, + ops::{self, dataflow::IOTrait, FuncDefn, Input, Output}, types::{Signature, Type}, }; @@ -583,4 +598,33 @@ mod test { hugr.remove_metadata(root, "meta"); assert_eq!(hugr.get_metadata(root, "meta"), None); } + + #[test] + fn remove_subtree() { + let mut hugr = Hugr::default(); + let root = hugr.root(); + let [foo, bar] = ["foo", "bar"].map(|name| { + let fd = hugr.add_node_with_parent( + root, + FuncDefn { + name: name.to_string(), + signature: Signature::new_endo(NAT).into(), + }, + ); + let inp = hugr.add_node_with_parent(fd, Input::new(NAT)); + let out = hugr.add_node_with_parent(fd, Output::new(NAT)); + hugr.connect(inp, 0, out, 0); + fd + }); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 7); + + hugr.remove_subtree(foo); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 4); + + hugr.remove_subtree(bar); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 1); + } } diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index dd26b1ac2..3354fc820 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -70,14 +70,11 @@ impl Rewrite for Transactional { let mut backup = Hugr::new(h.root_type().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply(h); - fn first_child(h: &impl HugrView) -> Option { - h.children(h.root()).next() - } if r.is_err() { // Try to restore backup. h.replace_op(h.root(), backup.root_type().clone()) .expect("The root replacement should always match the old root type"); - while let Some(child) = first_child(h) { + while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } h.insert_from_view(h.root(), &backup); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 7d744c150..442625e33 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -261,6 +261,12 @@ pub trait HugrView: HugrInternals { /// Return iterator over the direct children of node. fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; + /// Returns the first child of the specified node (if it is a parent). + /// Useful because `x.children().next()` leaves x borrowed. + fn first_child(&self, node: Node) -> Option { + self.children(node).next() + } + /// Iterates over neighbour nodes in the given direction. /// May contain duplicates if the graph has multiple links between nodes. fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone;