Skip to content

Commit

Permalink
feat: implement RemoveConst and RemoveConstIgnore (#757)
Browse files Browse the repository at this point in the history
as per spec

refactor!: allow Into<Const> for builder.add_const

BREAKING_CHANGES: existing CustomConst.into() calls will error

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
  • Loading branch information
5 people committed Jan 3, 2024
1 parent 4b6123e commit 9500803
Showing 1 changed file with 19 additions and 37 deletions.
56 changes: 19 additions & 37 deletions src/hugr/rewrite/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,31 @@ use crate::{
hugr::{HugrError, HugrMut},
HugrView, Node,
};
#[rustversion::since(1.75)] // uses impl in return position
use itertools::Itertools;
use thiserror::Error;

use super::Rewrite;

/// Remove a [`crate::ops::LoadConstant`] node with no outputs.
/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
#[derive(Debug, Clone)]
pub struct RemoveConstIgnore(pub Node);

/// Error from an [`RemoveConstIgnore`] operation.
/// Error from an [`RemoveConst`] or [`RemoveConstIgnore`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveConstIgnoreError {
pub enum RemoveError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not LoadConst).")]
#[error("Node is invalid (either not in HUGR or not correct operation).")]
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Not connected to a Const.
#[error("Node: {0:?} is not connected to a Const node.")]
NoConst(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

#[rustversion::since(1.75)] // uses impl in return position
impl Rewrite for RemoveConstIgnore {
type Error = RemoveConstIgnoreError;
type Error = RemoveError;

// The Const node the LoadConstant was connected to.
type ApplyResult = Node;
Expand All @@ -48,14 +43,14 @@ impl Rewrite for RemoveConstIgnore {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
return Err(RemoveConstIgnoreError::InvalidNode(node));
return Err(RemoveError::InvalidNode(node));
}

if h.out_value_types(node)
.next()
.is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some())
{
return Err(RemoveConstIgnoreError::ValueUsed(node));
return Err(RemoveError::ValueUsed(node));
}

Ok(())
Expand All @@ -67,7 +62,8 @@ impl Rewrite for RemoveConstIgnore {
let source = h
.input_neighbours(node)
.exactly_one()
.map_err(|_| RemoveConstIgnoreError::NoConst(node))?;
.ok()
.expect("Validation should check a Const is connected to LoadConstant.");
h.remove_node(node)?;

Ok(source)
Expand All @@ -82,22 +78,8 @@ impl Rewrite for RemoveConstIgnore {
#[derive(Debug, Clone)]
pub struct RemoveConst(pub Node);

/// Error from an [`RemoveConst`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveConstError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not Const).")]
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

impl Rewrite for RemoveConst {
type Error = RemoveConstError;
type Error = RemoveError;

// The parent of the Const node.
type ApplyResult = Node;
Expand All @@ -110,11 +92,11 @@ impl Rewrite for RemoveConst {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
return Err(RemoveConstError::InvalidNode(node));
return Err(RemoveError::InvalidNode(node));
}

if h.output_neighbours(node).next().is_some() {
return Err(RemoveConstError::ValueUsed(node));
return Err(RemoveError::ValueUsed(node));
}

Ok(())
Expand All @@ -123,20 +105,19 @@ impl Rewrite for RemoveConst {
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let source = h
let parent = h
.get_parent(node)
.expect("Const node without a parent shouldn't happen.");
h.remove_node(node)?;

Ok(source)
Ok(parent)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

#[rustversion::since(1.75)] // uses impl in return position
#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -169,17 +150,18 @@ mod test {
dfg_build.finish_sub_container()?;

let mut h = build.finish_prelude_hugr()?;
// nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple
assert_eq!(h.node_count(), 8);
let tup_node = tup.node();
// can't remove invalid node
assert_eq!(
h.apply_rewrite(RemoveConst(tup_node)),
Err(RemoveConstError::InvalidNode(tup_node))
Err(RemoveError::InvalidNode(tup_node))
);

assert_eq!(
h.apply_rewrite(RemoveConstIgnore(tup_node)),
Err(RemoveConstIgnoreError::InvalidNode(tup_node))
Err(RemoveError::InvalidNode(tup_node))
);
let load_1_node = load_1.node();
let load_2_node = load_2.node();
Expand All @@ -202,7 +184,7 @@ mod test {
// can't remove nodes in use
assert_eq!(
h.apply_rewrite(remove_1.clone()),
Err(RemoveConstIgnoreError::ValueUsed(load_1_node))
Err(RemoveError::ValueUsed(load_1_node))
);

// remove the use
Expand All @@ -215,7 +197,7 @@ mod test {
// still can't remove const, in use by second load
assert_eq!(
h.apply_rewrite(remove_con.clone()),
Err(RemoveConstError::ValueUsed(con_node))
Err(RemoveError::ValueUsed(con_node))
);

// remove second use
Expand Down

0 comments on commit 9500803

Please sign in to comment.