Skip to content

Commit

Permalink
feat: Const::from_bool function (#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 authored Jan 9, 2024
1 parent 3930f10 commit 492daec
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ mod test {
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use crate::std_extensions::logic::{self, const_from_bool, NaryLogic};
use crate::std_extensions::logic::{self, NaryLogic};
use rstest::rstest;

/// int to constant
Expand Down Expand Up @@ -320,15 +320,15 @@ mod test {
) -> Result<(), Box<dyn std::error::Error>> {
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();

let ins = ins.map(|b| build.add_load_const(const_from_bool(b)).unwrap());
let ins = ins.map(|b| build.add_load_const(Const::from_bool(b)).unwrap());
let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?;

let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(logic_op.outputs(), &reg)?;
constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &const_from_bool(out));
assert_fully_folded(&h, &Const::from_bool(out));
Ok(())
}

Expand Down
10 changes: 10 additions & 0 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ impl Const {
Self::unit_sum(1, 2)
}

/// Generate a constant equivalent of a boolean,
/// see [`Const::true_val`] and [`Const::false_val`].
pub fn from_bool(b: bool) -> Self {
if b {
Self::true_val()
} else {
Self::false_val()
}
}

/// Constant "false" value, i.e. the first variant of Sum((), ()).
pub fn false_val() -> Self {
Self::unit_sum(0, 2)
Expand Down
6 changes: 1 addition & 5 deletions src/std_extensions/arithmetic/float_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,7 @@ impl ConstFold for CmpFold {
) -> ConstFoldResult {
let [f1, f2] = get_floats(consts)?;

let res = if (self.0)(f1, f2) {
ops::Const::true_val()
} else {
ops::Const::false_val()
};
let res = ops::Const::from_bool((self.0)(f1, f2));

Some(vec![(0.into(), res)])
}
Expand Down
12 changes: 2 additions & 10 deletions src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ impl MakeOpDef for NaryLogic {
NaryLogic::And => |consts: &_| {
let inps = read_inputs(consts)?;
let res = inps.into_iter().all(|x| x);
Some(vec![(0.into(), const_from_bool(res))])
Some(vec![(0.into(), ops::Const::from_bool(res))])
},
NaryLogic::Or => |consts: &_| {
let inps = read_inputs(consts)?;
let res = inps.into_iter().any(|x| x);
Some(vec![(0.into(), const_from_bool(res))])
Some(vec![(0.into(), ops::Const::from_bool(res))])
},
})
}
Expand Down Expand Up @@ -206,14 +206,6 @@ fn read_inputs(consts: &[(IncomingPort, ops::Const)]) -> Option<Vec<bool>> {
Some(inps)
}

pub(crate) fn const_from_bool(res: bool) -> ops::Const {
if res {
ops::Const::true_val()
} else {
ops::Const::false_val()
}
}

#[cfg(test)]
pub(crate) mod test {
use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME};
Expand Down

0 comments on commit 492daec

Please sign in to comment.