Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fold integer operations #759

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
bffed99
wip: constant folding
ss2165 Nov 13, 2023
1a27d54
start moving folding to op_def
ss2165 Nov 20, 2023
b84766b
thread through folding methods
ss2165 Nov 23, 2023
8ee49da
integer addition tests passing
ss2165 Nov 23, 2023
520de7c
remove FoldOutput
ss2165 Nov 24, 2023
1d656d6
Merge branch 'main' into feat/const-fold2
ss2165 Dec 18, 2023
9398d9d
refactor int folding to separate repo
ss2165 Dec 18, 2023
7b955a9
add tuple and sum constant folding
ss2165 Dec 18, 2023
6cb3c62
simplify test code
ss2165 Dec 18, 2023
0500624
wip: fold finder
ss2165 Dec 20, 2023
8f554e0
chore(deps): bump actions/upload-artifact from 3 to 4 (#751)
dependabot[bot] Dec 20, 2023
215eb40
chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752)
dependabot[bot] Dec 20, 2023
ff26546
fix: case node should not have an external signature (#749)
ss2165 Dec 20, 2023
64b9199
refactor: move hugr equality check out for reuse
ss2165 Dec 20, 2023
6d7d440
feat: implement RemoveConst and RemoveConstIgnore
ss2165 Dec 21, 2023
cdde503
use remove rewrites while folding
ss2165 Dec 21, 2023
114524c
alllow candidate node specification in find_consts
ss2165 Dec 21, 2023
a087fbc
add exhaustive fold pass
ss2165 Dec 21, 2023
07768b2
refactor!: use enum op traits for floats + conversions
ss2165 Dec 21, 2023
9a81260
Merge branch 'refactor/fops-enum' into feat/const-fold2
ss2165 Dec 21, 2023
658adf4
add folding definitions for float ops
ss2165 Dec 21, 2023
2c0e75b
refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
dc7ff13
refactor: const ConstF64::new
ss2165 Dec 21, 2023
aa73ab2
feat: implement folding for conversion ops
ss2165 Dec 21, 2023
a519f34
fixup! refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
a7a4088
Merge branch 'main' into feat/const-fold2
ss2165 Dec 21, 2023
46075c2
implement bigger tests and fix unearthed bugs
ss2165 Dec 21, 2023
df854e8
Revert "refactor: move hugr equality check out for reuse"
ss2165 Dec 22, 2023
ba81e7b
feat: implement RemoveConst and RemoveConstIgnore
ss2165 Dec 21, 2023
09ce1c9
remove conversion foldin
ss2165 Dec 22, 2023
5a372c7
Merge branch 'main' into feat/const-fold-floats
ss2165 Dec 22, 2023
26bc5ff
add rust version guards
ss2165 Dec 22, 2023
b513ace
Merge branch 'feat/const-rewrites' into feat/const-fold-floats
ss2165 Dec 22, 2023
5a71f75
docs: add public method docstrings
ss2165 Dec 22, 2023
6fa7eb9
add some docstrings and comments
ss2165 Dec 22, 2023
7381432
remove integer folding
ss2165 Dec 22, 2023
3bfda50
Revert "remove integer folding"
ss2165 Dec 22, 2023
0e0411f
remove unused imports
ss2165 Dec 22, 2023
8e88f3e
add docstrings and simplify
ss2165 Dec 22, 2023
41fa47a
Merge branch 'feat/const-fold-floats' into feat/fold-ints
ss2165 Dec 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Algorithms using the Hugr.

pub mod const_fold;
mod half_node;
pub mod nest_cfgs;
352 changes: 352 additions & 0 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,352 @@
//! Constant folding routines.

use std::collections::{BTreeSet, HashMap};

use itertools::Itertools;

use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::{ConstFoldResult, ExtensionRegistry},
hugr::{
rewrite::consts::{RemoveConst, RemoveConstIgnore},
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp, OpType},
type_row,
types::{FunctionType, Type, TypeEnum},
values::Value,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
fn out_row(consts: impl IntoIterator<Item = Const>) -> ConstFoldResult {
let vec = consts
.into_iter()
.enumerate()
.map(|(i, c)| (i.into(), c))
.collect();
Some(vec)
}

/// Sort folding inputs with [`IncomingPort`] as key
fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> {
let mut v: Vec<_> = consts.iter().collect();
v.sort_by_key(|(i, _)| i);
v
}

/// Sort some input constants by port and just return the constants.
pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
sort_by_in_port(consts)
.into_iter()
.map(|(_, c)| c)
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
let op = op.as_leaf_op()?;

match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())])
}
LeafOp::UnpackTuple { .. } => {
let c = &consts.first()?.1;

if let Value::Tuple { vs } = c.value() {
if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() {
return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| {
Const::new(v.clone(), t.clone())
.expect("types should already have been checked")
}));
}
}
None // could panic
}

LeafOp::Tag { tag, variants } => out_row([Const::new(
Value::sum(*tag, consts.first()?.1.value().clone()),
Type::new_sum(variants.clone()),
)
.unwrap()]),
LeafOp::CustomOp(_) => {
let ext_op = op.as_extension_op()?;

ext_op.constant_fold(consts)
}
_ => None,
}
}

/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Const>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Const::const_type).cloned().collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
.map(|c| b.add_load_const(c).unwrap())
.collect_vec();

b.finish_hugr_with_outputs(outputs, reg).unwrap()
}

/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`,
/// return an iterator of possible constant folding rewrites. The
/// [`SimpleReplacement`] replaces an operation with constants that result from
/// evaluating it, the extension registry `reg` is used to validate the
/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the
/// LoadConstant nodes that could be removed.
pub fn find_consts<'a, 'r: 'a>(
hugr: &'a impl HugrView,
candidate_nodes: impl IntoIterator<Item = Node> + 'a,
reg: &'r ExtensionRegistry,
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveConstIgnore>)> + 'a {
// track nodes for operations that have already been considered for folding
let mut used_neighbours = BTreeSet::new();

candidate_nodes
.into_iter()
.filter_map(move |n| {
// only look at LoadConstant
hugr.get_optype(n).is_load_constant().then_some(())?;

let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?;
let neighbours = hugr
.linked_inputs(n, out_p)
.filter(|(n, _)| used_neighbours.insert(*n))
.collect_vec();
if neighbours.is_empty() {
// no uses of LoadConstant that haven't already been considered.
return None;
}
let fold_iter = neighbours
.into_iter()
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg));
Some(fold_iter)
})
.flatten()
}

/// Attempt to evaluate and generate rewrites for the operation at `op_node`
fn fold_op(
hugr: &impl HugrView,
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> {
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
let (con_op, load_n) = get_const(hugr, op_node, in_p)?;
Some(((in_p, con_op), RemoveConstIgnore(load_n)))
})
.unzip();
let neighbour_op = hugr.get_optype(op_node);
// attempt to evaluate op
let folded = fold_const(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
.into_iter()
.enumerate()
.filter_map(|(i, out)| {
// map from the ports the op was linked to, to the output ports of
// the replacement.
hugr.single_linked_input(op_node, out)
.map(|np| (np, i.into()))
})
.collect();
let replacement = const_graph(consts, reg);
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
.expect("Operation should form valid subgraph.");

let simple_replace = SimpleReplacement::new(
sibling_graph,
replacement,
// no inputs to replacement
HashMap::new(),
nu_out,
);
Some((simple_replace, removals))
}

/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant
/// and the LoadConstant node
fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> {
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
let load_op = hugr.get_optype(load_n).as_load_constant()?;
let const_node = hugr
.linked_outputs(load_n, load_op.constant_port())
.exactly_one()
.ok()?
.0;

let const_op = hugr.get_optype(const_node).as_const()?;

// TODO avoid const clone here
Some((const_op.clone(), load_n))
}

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
loop {
// would be preferable if the candidates were updated to be just the
// neighbouring nodes of those added.
let rewrites = find_consts(h, h.nodes(), reg).collect_vec();
if rewrites.is_empty() {
break;
}
for (replace, removes) in rewrites {
h.apply_rewrite(replace).unwrap();
for rem in removes {
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
if h.apply_rewrite(RemoveConst(const_node)).is_err() {
// const cannot be removed - no problem
continue;
}
}
}
}
}
}

#[cfg(test)]
mod test {

use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic;

use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_ops::IntOpDef;
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};

use rstest::rstest;

use super::*;

/// int to constant
fn i2c(b: u64) -> Const {
Const::new(
ConstIntU::new(5, b).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

/// float to constant
fn f2c(f: f64) -> Const {
ConstF64::new(f).into()
}

#[rstest]
#[case(0, 0, 0)]
#[case(0, 1, 1)]
#[case(23, 435, 458)]
// c = a + b
fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) {
let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))];
let add_op: OpType = IntOpDef::iadd.with_width(5).into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), i2c(c))]);
}

#[test]
fn test_fold() {
/*
Test hugr calculates
1 + 2 == 3
*/
let mut b = DFGBuilder::new(FunctionType::new(
type_row![],
vec![INT_TYPES[5].to_owned()],
))
.unwrap();

let one = b.add_load_const(i2c(1)).unwrap();
let two = b.add_load_const(i2c(2)).unwrap();

let add = b
.add_dataflow_op(IntOpDef::iadd.with_width(5), [one, two])
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
arithmetic::int_ops::EXTENSION.to_owned(),
])
.unwrap();
let mut h = b.finish_hugr_with_outputs(add.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 8);

let (repl, removes) = find_consts(&h, h.nodes(), &reg).exactly_one().ok().unwrap();
let [remove_1, remove_2] = removes.try_into().unwrap();

h.apply_rewrite(repl).unwrap();
for rem in [remove_1, remove_2] {
let const_node = h.apply_rewrite(rem).unwrap();
h.apply_rewrite(RemoveConst(const_node)).unwrap();
}

assert_fully_folded(&h, &i2c(3));
}

#[test]
fn test_big() {
/*
Test hugr approximately calculates
let x = (5.5, 3.25);
x.0 - x.1 == 2.25
*/
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
.unwrap();

let unpack = build
.add_dataflow_op(
LeafOp::UnpackTuple {
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE],
},
[tup],
)
.unwrap();

let sub = build
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 7);

constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &f2c(2.25));
}
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
let mut node_count = 0;

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c == expected_const => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}

assert_eq!(node_count, 4);
}
}
6 changes: 3 additions & 3 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -374,7 +374,7 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1))?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -173,7 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let wire = branch_1.add_load_const(ConstUsize::new(2))?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
Loading