Skip to content

Commit

Permalink
feat(hugr-llvm): Emit more int ops
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Jan 3, 2025
1 parent ffb8395 commit 865af3f
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 38 deletions.
9 changes: 9 additions & 0 deletions hugr-llvm/src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ impl<'c> Emission<'c> {
Ok(gv.as_int(false))
}

/// JIT and execute the function named `entry` in the inner module.
///
/// That function must take no arguments and return an `i64`.
pub fn exec_i64(&self, entry: impl AsRef<str>) -> Result<i64> {
let gv = self.exec_impl(entry)?;
let x: u64 = gv.as_int(true).try_into().unwrap();
Ok(x as i64)
}

/// JIT and execute the function named `entry` in the inner module.
///
/// That function must take no arguments and return an `f64`.
Expand Down
195 changes: 157 additions & 38 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ use inkwell::{
use crate::{
custom::CodegenExtsBuilder,
emit::{
emit_value, func::EmitFuncContext, ops::emit_custom_binary_op, ops::emit_custom_unary_op,
EmitOpArgs,
get_intrinsic,
emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op,
ops::emit_custom_unary_op, EmitOpArgs,
},
types::TypingSession,
};
Expand Down Expand Up @@ -97,42 +96,98 @@ fn emit_int_op<'c, H: HugrView>(
.as_basic_value_enum()])
}),
IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
let intr = get_intrinsic(ctx.get_current_module(), "llvm.abs.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()])?;
let r = ctx.builder().build_call(intr, &[arg.into_int_value().into()], "")?.try_as_basic_value().unwrap_left();
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.abs.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let true_ = ctx.iw_context().bool_type().const_all_ones();
let r = ctx
.builder()
.build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(ctx.get_current_module(), "llvm.smax.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?;
let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left();
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.smax.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(ctx.get_current_module(), "llvm.umax.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?;
let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left();
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.umax.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(ctx.get_current_module(), "llvm.smin.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?;
let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left();
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.smin.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(ctx.get_current_module(), "llvm.umin.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?;
let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left();
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.umin.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs,rhs), _| {
IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
.builder()
.build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs,rhs), _| {
IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
.as_basic_value_enum()])
.builder()
.build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
.as_basic_value_enum()])
}),
IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
Expand All @@ -146,21 +201,21 @@ fn emit_int_op<'c, H: HugrView>(
IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
.builder()
.build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
.builder()
.build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
Ok(vec![ctx
.builder()
.build_not(arg.into_int_value(), "")?
.as_basic_value_enum()])
.builder()
.build_not(arg.into_int_value(), "")?
.as_basic_value_enum()])
}),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
Expand Down Expand Up @@ -228,7 +283,10 @@ mod test {
use hugr_core::{
builder::{Dataflow, DataflowSubContainer},
extension::prelude::bool_t,
std_extensions::arithmetic::{int_ops, int_types::{ConstInt, INT_TYPES}},
std_extensions::arithmetic::{
int_ops,
int_types::{ConstInt, INT_TYPES},
},
types::TypeRow,
Hugr,
};
Expand Down Expand Up @@ -275,18 +333,17 @@ mod test {
fn test_binary_int_op_with_results_inputs(
name: impl AsRef<str>,
log_width: u8,
inputs: &[u64],
inputs: Vec<ConstInt>,
output_types: impl Into<TypeRow>,
) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
SimpleHugrConfig::new()
.with_ins(vec![])
.with_outs(output_types.into())
.with_extensions(STD_REG.clone())
.finish(|mut hugr_builder| {
let mut input_wires = Vec::new();
inputs.iter().for_each(|i| {
let w = hugr_builder.add_load_value(ConstInt::new_u(6, *i).unwrap());
inputs.into_iter().for_each(|i| {
let w = hugr_builder.add_load_value(i);
input_wires.push(w);
});
let ext_op = int_ops::EXTENSION
Expand Down Expand Up @@ -348,10 +405,72 @@ mod test {
#[case::imax("imax_u", 1, 2, 2)]
#[case::imax("imax_u", 2, 1, 2)]
#[case::imax("imax_u", 2, 2, 2)]
fn test_exec_unsigned_op(mut exec_ctx: TestContext, #[case] op: String, #[case] lhs: u64, #[case] rhs: u64, #[case] result: u64) {
#[case::imin("imin_u", 1, 2, 1)]
#[case::imin("imin_u", 2, 1, 1)]
#[case::imin("imin_u", 2, 2, 2)]
#[case::ishl("ishl", 73, 1, 146)]
#[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
#[case::ishr("ishr", 73, 1, 36)]
fn test_exec_unsigned_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] lhs: u64,
#[case] rhs: u64,
#[case] result: u64,
) {
exec_ctx.add_extensions(add_int_extensions);
let ty = &INT_TYPES[6].clone();
let hugr = test_binary_int_op_with_results_inputs(op, 6, &[lhs,rhs], vec![ty.clone()]);
let args = vec![
ConstInt::new_u(6, lhs).unwrap(),
ConstInt::new_u(6, rhs).unwrap(),
];
let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]);
assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
}

#[rstest]
#[case::imax("imax_s", 1, 2, 2)]
#[case::imax("imax_s", 2, 1, 2)]
#[case::imax("imax_s", 2, 2, 2)]
#[case::imax("imax_s", -1, -2, -1)]
#[case::imax("imax_s", -2, -1, -1)]
#[case::imax("imax_s", -2, -2, -2)]
#[case::imin("imin_s", 1, 2, 1)]
#[case::imin("imin_s", 2, 1, 1)]
#[case::imin("imin_s", 2, 2, 2)]
#[case::imin("imin_s", -1, -2, -2)]
#[case::imin("imin_s", -2, -1, -2)]
#[case::imin("imin_s", -2, -2, -2)]
fn test_exec_signed_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] lhs: i64,
#[case] rhs: i64,
#[case] result: i64,
) {
exec_ctx.add_extensions(add_int_extensions);
let ty = &INT_TYPES[6].clone();
let args = vec![
ConstInt::new_s(6, lhs).unwrap(),
ConstInt::new_s(6, rhs).unwrap(),
];
let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]);
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}

#[rstest]
#[case::iabs("iabs", 42, 42)]
#[case::iabs("iabs", -42, 42)]
fn test_exec_signed_unary_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] arg: i64,
#[case] result: i64,
) {
exec_ctx.add_extensions(add_int_extensions);
let ty = &INT_TYPES[6].clone();
let args = vec![ConstInt::new_s(6, arg).unwrap()];
let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]);
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}
}
7 changes: 7 additions & 0 deletions hugr-llvm/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ impl TestContext {
emission.exec_u64(entry_point).unwrap()
}

pub fn exec_hugr_i64(&self, hugr: THugrView, entry_point: impl AsRef<str>) -> i64 {
let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap();
emission.verify().unwrap();

emission.exec_i64(entry_point).unwrap()
}

pub fn exec_hugr_f64(&self, hugr: THugrView, entry_point: impl AsRef<str>) -> f64 {
let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap();

Expand Down

0 comments on commit 865af3f

Please sign in to comment.