-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: constant folding for arithmetic conversion operations #720
Merged
Merged
Changes from all commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
bffed99
wip: constant folding
ss2165 1a27d54
start moving folding to op_def
ss2165 b84766b
thread through folding methods
ss2165 8ee49da
integer addition tests passing
ss2165 520de7c
remove FoldOutput
ss2165 1d656d6
Merge branch 'main' into feat/const-fold2
ss2165 9398d9d
refactor int folding to separate repo
ss2165 7b955a9
add tuple and sum constant folding
ss2165 6cb3c62
simplify test code
ss2165 0500624
wip: fold finder
ss2165 8f554e0
chore(deps): bump actions/upload-artifact from 3 to 4 (#751)
dependabot[bot] 215eb40
chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752)
dependabot[bot] ff26546
fix: case node should not have an external signature (#749)
ss2165 64b9199
refactor: move hugr equality check out for reuse
ss2165 6d7d440
feat: implement RemoveConst and RemoveConstIgnore
ss2165 cdde503
use remove rewrites while folding
ss2165 114524c
alllow candidate node specification in find_consts
ss2165 a087fbc
add exhaustive fold pass
ss2165 07768b2
refactor!: use enum op traits for floats + conversions
ss2165 9a81260
Merge branch 'refactor/fops-enum' into feat/const-fold2
ss2165 658adf4
add folding definitions for float ops
ss2165 2c0e75b
refactor: ERROR_CUSTOM_TYPE
ss2165 dc7ff13
refactor: const ConstF64::new
ss2165 aa73ab2
feat: implement folding for conversion ops
ss2165 a519f34
fixup! refactor: ERROR_CUSTOM_TYPE
ss2165 a7a4088
Merge branch 'main' into feat/const-fold2
ss2165 46075c2
implement bigger tests and fix unearthed bugs
ss2165 df854e8
Revert "refactor: move hugr equality check out for reuse"
ss2165 ba81e7b
feat: implement RemoveConst and RemoveConstIgnore
ss2165 09ce1c9
remove conversion foldin
ss2165 5a372c7
Merge branch 'main' into feat/const-fold-floats
ss2165 26bc5ff
add rust version guards
ss2165 b513ace
Merge branch 'feat/const-rewrites' into feat/const-fold-floats
ss2165 1348891
Revert "remove conversion foldin"
ss2165 5a71f75
docs: add public method docstrings
ss2165 6fa7eb9
add some docstrings and comments
ss2165 7381432
remove integer folding
ss2165 3bfda50
Revert "remove integer folding"
ss2165 0e0411f
remove unused imports
ss2165 8e88f3e
add docstrings and simplify
ss2165 dea6085
Merge branch 'feat/const-fold-floats' into feat/const-fold2
ss2165 48eb430
Merge branch 'feat/fold-ints' into feat/const-fold2
ss2165 41fa47a
Merge branch 'feat/const-fold-floats' into feat/fold-ints
ss2165 ccf789e
Merge branch 'feat/fold-ints' into feat/const-fold2
ss2165 0c060fb
Merge branch 'main' into feat/const-fold-floats
lmondada 4e24c28
Merge branch 'main' into feat/const-fold2
ss2165 4bca931
docs: Spec clarifications (#738)
cqc-alec 3193cdb
docs: Spec updates (#741)
cqc-alec d0513c4
docs: [spec] Remove references to causal cone and Order edges from In…
acl-cqc 89f1827
chore: remove rustversion (#764)
ss2165 4b6123e
ci: Setup release-plz and related files (#765)
aborgna-q 9500803
feat: implement RemoveConst and RemoveConstIgnore (#757)
ss2165 2c6abc6
Merge branch 'main' into feat/const-fold-floats
ss2165 905ef01
address minor review comments
ss2165 a6928e0
Merge branch 'feat/const-fold-floats' into feat/const-fold2
ss2165 6e36684
remove integer folding
ss2165 b0c686d
Merge branch 'main' into feat/const-fold2
ss2165 8f693ac
Update src/std_extensions/arithmetic/conversions.rs
ss2165 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 134 additions & 0 deletions
134
src/std_extensions/arithmetic/conversions/const_fold.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
use crate::{ | ||
extension::{ | ||
prelude::{sum_with_error, ConstError}, | ||
ConstFold, ConstFoldResult, OpDef, | ||
}, | ||
ops, | ||
std_extensions::arithmetic::{ | ||
float_types::ConstF64, | ||
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES}, | ||
}, | ||
types::ConstTypeError, | ||
values::{CustomConst, Value}, | ||
IncomingPort, | ||
}; | ||
|
||
use super::ConvertOpDef; | ||
|
||
pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) { | ||
use ConvertOpDef::*; | ||
|
||
match op { | ||
trunc_u => def.set_constant_folder(TruncU), | ||
trunc_s => def.set_constant_folder(TruncS), | ||
convert_u => def.set_constant_folder(ConvertU), | ||
convert_s => def.set_constant_folder(ConvertS), | ||
} | ||
} | ||
|
||
fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> { | ||
let [(_, c)] = consts else { | ||
return None; | ||
}; | ||
c.get_custom_value() | ||
} | ||
|
||
fn fold_trunc( | ||
type_args: &[crate::types::TypeArg], | ||
consts: &[(IncomingPort, ops::Const)], | ||
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>, | ||
) -> ConstFoldResult { | ||
let f: &ConstF64 = get_input(consts)?; | ||
let f = f.value(); | ||
let [arg] = type_args else { | ||
return None; | ||
}; | ||
let log_width = get_log_width(arg).ok()?; | ||
let int_type = INT_TYPES[log_width as usize].to_owned(); | ||
let sum_type = sum_with_error(int_type.clone()); | ||
let err_value = || { | ||
let err_val = ConstError { | ||
signal: 0, | ||
message: "Can't truncate non-finite float".to_string(), | ||
}; | ||
let sum_val = Value::Sum { | ||
tag: 1, | ||
value: Box::new(err_val.into()), | ||
}; | ||
|
||
ops::Const::new(sum_val, sum_type.clone()).unwrap() | ||
}; | ||
let out_const: ops::Const = if !f.is_finite() { | ||
err_value() | ||
} else { | ||
let cv = convert(f, log_width); | ||
if let Ok(cv) = cv { | ||
let sum_val = Value::Sum { | ||
tag: 0, | ||
value: Box::new(cv), | ||
}; | ||
|
||
ops::Const::new(sum_val, sum_type).unwrap() | ||
} else { | ||
err_value() | ||
} | ||
}; | ||
|
||
Some(vec![(0.into(), out_const)]) | ||
} | ||
|
||
struct TruncU; | ||
|
||
impl ConstFold for TruncU { | ||
fn fold( | ||
&self, | ||
type_args: &[crate::types::TypeArg], | ||
consts: &[(IncomingPort, ops::Const)], | ||
) -> ConstFoldResult { | ||
fold_trunc(type_args, consts, |f, log_width| { | ||
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into) | ||
}) | ||
} | ||
} | ||
|
||
struct TruncS; | ||
|
||
impl ConstFold for TruncS { | ||
fn fold( | ||
&self, | ||
type_args: &[crate::types::TypeArg], | ||
consts: &[(IncomingPort, ops::Const)], | ||
) -> ConstFoldResult { | ||
fold_trunc(type_args, consts, |f, log_width| { | ||
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into) | ||
}) | ||
} | ||
} | ||
|
||
struct ConvertU; | ||
|
||
impl ConstFold for ConvertU { | ||
fn fold( | ||
&self, | ||
_type_args: &[crate::types::TypeArg], | ||
consts: &[(IncomingPort, ops::Const)], | ||
) -> ConstFoldResult { | ||
let u: &ConstIntU = get_input(consts)?; | ||
let f = u.value() as f64; | ||
Some(vec![(0.into(), ConstF64::new(f).into())]) | ||
} | ||
} | ||
|
||
struct ConvertS; | ||
|
||
impl ConstFold for ConvertS { | ||
fn fold( | ||
&self, | ||
_type_args: &[crate::types::TypeArg], | ||
consts: &[(IncomingPort, ops::Const)], | ||
) -> ConstFoldResult { | ||
let u: &ConstIntS = get_input(consts)?; | ||
let f = u.value() as f64; | ||
Some(vec![(0.into(), ConstF64::new(f).into())]) | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if
f
is negative? Doesas u64
cause a panic?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as u64
maps all negative floats to 0, should I add a panic?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, this is a can of worms. The spec for
trunc_u
currently says "Returns an error when the float is non-finite or cannot be exactly stored in N bits". This should probably say "cannot be exactly represented as au<N>
". But what we're doing here is rounding, which is much more forgiving. WASM seems to say that the result is undefined if the number is negative, but rounded otherwise, which makes me think we should change the definition oftrunc_{u,s}
in the spec.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds like this could be a follow up issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll make one.