Skip to content

Commit 968c8b0

Browse files
ss2165dependabot[bot]acl-cqclmondadacqc-alec
authored
feat: constant folding for arithmetic conversion operations (#720)
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Alan Lawrence <[email protected]> Co-authored-by: Luca Mondada <[email protected]> Co-authored-by: Alec Edgington <[email protected]> Co-authored-by: Alan Lawrence <[email protected]> Co-authored-by: Agustín Borgna <[email protected]> Co-authored-by: Luca Mondada <[email protected]>
1 parent cf69e01 commit 968c8b0

File tree

3 files changed

+181
-12
lines changed

3 files changed

+181
-12
lines changed

src/algorithm/const_fold.rs

+33-12
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,24 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
217217
#[cfg(test)]
218218
mod test {
219219

220+
use super::*;
221+
use crate::extension::prelude::sum_with_error;
220222
use crate::extension::{ExtensionRegistry, PRELUDE};
221223
use crate::std_extensions::arithmetic;
222-
224+
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
223225
use crate::std_extensions::arithmetic::float_ops::FloatOps;
224226
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
225-
227+
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
226228
use rstest::rstest;
227229

228-
use super::*;
230+
/// int to constant
231+
fn i2c(b: u64) -> Const {
232+
Const::new(
233+
ConstIntU::new(5, b).unwrap().into(),
234+
INT_TYPES[5].to_owned(),
235+
)
236+
.unwrap()
237+
}
229238

230239
/// float to constant
231240
fn f2c(f: f64) -> Const {
@@ -244,19 +253,19 @@ mod test {
244253

245254
assert_eq!(&out[..], &[(0.into(), f2c(c))]);
246255
}
247-
248256
#[test]
249257
fn test_big() {
250258
/*
251-
Test hugr approximately calculates
252-
let x = (5.5, 3.25);
253-
x.0 - x.1 == 2.25
259+
Test approximately calculates
260+
let x = (5.6, 3.2);
261+
int(x.0 - x.1) == 2
254262
*/
263+
let sum_type = sum_with_error(INT_TYPES[5].to_owned());
255264
let mut build =
256-
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();
265+
DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap();
257266

258267
let tup = build
259-
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
268+
.add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)]))
260269
.unwrap();
261270

262271
let unpack = build
@@ -271,19 +280,31 @@ mod test {
271280
let sub = build
272281
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
273282
.unwrap();
283+
let to_int = build
284+
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
285+
.unwrap();
274286

275287
let reg = ExtensionRegistry::try_new([
276288
PRELUDE.to_owned(),
289+
arithmetic::int_types::EXTENSION.to_owned(),
277290
arithmetic::float_types::EXTENSION.to_owned(),
278291
arithmetic::float_ops::EXTENSION.to_owned(),
292+
arithmetic::conversions::EXTENSION.to_owned(),
279293
])
280294
.unwrap();
281-
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
282-
assert_eq!(h.node_count(), 7);
295+
let mut h = build
296+
.finish_hugr_with_outputs(to_int.outputs(), &reg)
297+
.unwrap();
298+
assert_eq!(h.node_count(), 8);
283299

284300
constant_fold_pass(&mut h, &reg);
285301

286-
assert_fully_folded(&h, &f2c(2.25));
302+
let expected = Value::Sum {
303+
tag: 0,
304+
value: Box::new(i2c(2).value().clone()),
305+
};
306+
let expected = Const::new(expected, sum_type).unwrap();
307+
assert_fully_folded(&h, &expected);
287308
}
288309
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
289310
// check the hugr just loads and returns a single const

src/std_extensions/arithmetic/conversions.rs

+14
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::{
1919
use super::int_types::int_tv;
2020
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
2121
use lazy_static::lazy_static;
22+
mod const_fold;
2223
/// The extension identifier.
2324
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
2425

@@ -63,8 +64,21 @@ impl MakeOpDef for ConvertOpDef {
6364
}
6465
.to_string()
6566
}
67+
68+
fn post_opdef(&self, def: &mut OpDef) {
69+
const_fold::set_fold(self, def)
70+
}
6671
}
6772

73+
impl ConvertOpDef {
74+
/// Initialise a conversion op with an integer log width type argument.
75+
pub fn with_width(self, log_width: u8) -> ConvertOpType {
76+
ConvertOpType {
77+
def: self,
78+
log_width: log_width as u64,
79+
}
80+
}
81+
}
6882
/// Concrete convert operation with integer width set.
6983
#[derive(Debug, Clone, PartialEq)]
7084
pub struct ConvertOpType {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use crate::{
2+
extension::{
3+
prelude::{sum_with_error, ConstError},
4+
ConstFold, ConstFoldResult, OpDef,
5+
},
6+
ops,
7+
std_extensions::arithmetic::{
8+
float_types::ConstF64,
9+
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES},
10+
},
11+
types::ConstTypeError,
12+
values::{CustomConst, Value},
13+
IncomingPort,
14+
};
15+
16+
use super::ConvertOpDef;
17+
18+
pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) {
19+
use ConvertOpDef::*;
20+
21+
match op {
22+
trunc_u => def.set_constant_folder(TruncU),
23+
trunc_s => def.set_constant_folder(TruncS),
24+
convert_u => def.set_constant_folder(ConvertU),
25+
convert_s => def.set_constant_folder(ConvertS),
26+
}
27+
}
28+
29+
fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> {
30+
let [(_, c)] = consts else {
31+
return None;
32+
};
33+
c.get_custom_value()
34+
}
35+
36+
fn fold_trunc(
37+
type_args: &[crate::types::TypeArg],
38+
consts: &[(IncomingPort, ops::Const)],
39+
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>,
40+
) -> ConstFoldResult {
41+
let f: &ConstF64 = get_input(consts)?;
42+
let f = f.value();
43+
let [arg] = type_args else {
44+
return None;
45+
};
46+
let log_width = get_log_width(arg).ok()?;
47+
let int_type = INT_TYPES[log_width as usize].to_owned();
48+
let sum_type = sum_with_error(int_type.clone());
49+
let err_value = || {
50+
let err_val = ConstError {
51+
signal: 0,
52+
message: "Can't truncate non-finite float".to_string(),
53+
};
54+
let sum_val = Value::Sum {
55+
tag: 1,
56+
value: Box::new(err_val.into()),
57+
};
58+
59+
ops::Const::new(sum_val, sum_type.clone()).unwrap()
60+
};
61+
let out_const: ops::Const = if !f.is_finite() {
62+
err_value()
63+
} else {
64+
let cv = convert(f, log_width);
65+
if let Ok(cv) = cv {
66+
let sum_val = Value::Sum {
67+
tag: 0,
68+
value: Box::new(cv),
69+
};
70+
71+
ops::Const::new(sum_val, sum_type).unwrap()
72+
} else {
73+
err_value()
74+
}
75+
};
76+
77+
Some(vec![(0.into(), out_const)])
78+
}
79+
80+
struct TruncU;
81+
82+
impl ConstFold for TruncU {
83+
fn fold(
84+
&self,
85+
type_args: &[crate::types::TypeArg],
86+
consts: &[(IncomingPort, ops::Const)],
87+
) -> ConstFoldResult {
88+
fold_trunc(type_args, consts, |f, log_width| {
89+
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into)
90+
})
91+
}
92+
}
93+
94+
struct TruncS;
95+
96+
impl ConstFold for TruncS {
97+
fn fold(
98+
&self,
99+
type_args: &[crate::types::TypeArg],
100+
consts: &[(IncomingPort, ops::Const)],
101+
) -> ConstFoldResult {
102+
fold_trunc(type_args, consts, |f, log_width| {
103+
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into)
104+
})
105+
}
106+
}
107+
108+
struct ConvertU;
109+
110+
impl ConstFold for ConvertU {
111+
fn fold(
112+
&self,
113+
_type_args: &[crate::types::TypeArg],
114+
consts: &[(IncomingPort, ops::Const)],
115+
) -> ConstFoldResult {
116+
let u: &ConstIntU = get_input(consts)?;
117+
let f = u.value() as f64;
118+
Some(vec![(0.into(), ConstF64::new(f).into())])
119+
}
120+
}
121+
122+
struct ConvertS;
123+
124+
impl ConstFold for ConvertS {
125+
fn fold(
126+
&self,
127+
_type_args: &[crate::types::TypeArg],
128+
consts: &[(IncomingPort, ops::Const)],
129+
) -> ConstFoldResult {
130+
let u: &ConstIntS = get_input(consts)?;
131+
let f = u.value() as f64;
132+
Some(vec![(0.into(), ConstF64::new(f).into())])
133+
}
134+
}

0 commit comments

Comments
 (0)