Skip to content

Commit

Permalink
feat(hugr-llvm): Emit more int ops (#1835)
Browse files Browse the repository at this point in the history
More work on #1702. Adds `ine`, `iabs`, `imax_{s,u}`, `imin_{s,u}`,
`ishl`, `ishr`, `ixor`, `ior`, `inot`, `iand`

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
croyzor and ss2165 authored Jan 6, 2025
1 parent cad7484 commit 2fa09ee
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 34 deletions.
9 changes: 9 additions & 0 deletions hugr-llvm/src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ impl<'c> Emission<'c> {
Ok(gv.as_int(false))
}

/// JIT and execute the function named `entry` in the inner module.
///
/// That function must take no arguments and return an `i64`.
pub fn exec_i64(&self, entry: impl AsRef<str>) -> Result<i64> {
let gv = self.exec_impl(entry)?;
let x: u64 = gv.as_int(true).try_into().unwrap();
Ok(x as i64)
}

/// JIT and execute the function named `entry` in the inner module.
///
/// That function must take no arguments and return an `f64`.
Expand Down
288 changes: 254 additions & 34 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use hugr_core::{
HugrView,
};
use inkwell::{
types::{BasicTypeEnum, IntType},
types::{BasicType, BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum},
};

use crate::{
custom::CodegenExtsBuilder,
emit::{
emit_value, func::EmitFuncContext, ops::emit_custom_binary_op, ops::emit_custom_unary_op,
EmitOpArgs,
emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op,
ops::emit_custom_unary_op, EmitOpArgs,
},
types::TypingSession,
};
Expand Down Expand Up @@ -95,7 +95,102 @@ fn emit_int_op<'c, H: HugrView>(
.build_int_neg(arg.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.abs.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let true_ = ctx.iw_context().bool_type().const_all_ones();
let r = ctx
.builder()
.build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.smax.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.umax.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.smin.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let intr = get_intrinsic(
ctx.get_current_module(),
"llvm.umin.i64",
[ctx.iw_context().i64_type().as_basic_type_enum()],
)?;
let r = ctx
.builder()
.build_call(
intr,
&[lhs.into_int_value().into(), rhs.into_int_value().into()],
"",
)?
.try_as_basic_value()
.unwrap_left();
Ok(vec![r])
}),
IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
.as_basic_value_enum()])
}),
IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
Expand All @@ -104,6 +199,30 @@ fn emit_int_op<'c, H: HugrView>(
IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
Ok(vec![ctx
.builder()
.build_not(arg.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
Ok(vec![ctx
.builder()
.build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
}
Expand Down Expand Up @@ -169,8 +288,11 @@ mod test {
use hugr_core::{
builder::{Dataflow, DataflowSubContainer},
extension::prelude::bool_t,
std_extensions::arithmetic::{int_ops, int_types::INT_TYPES},
types::TypeRow,
std_extensions::arithmetic::{
int_ops,
int_types::{ConstInt, INT_TYPES},
},
types::Type,
Hugr,
};
use rstest::rstest;
Expand All @@ -179,53 +301,52 @@ mod test {
check_emission,
emit::test::SimpleHugrConfig,
extension::int::add_int_extensions,
test::{llvm_ctx, TestContext},
test::{exec_ctx, llvm_ctx, TestContext},
};

fn test_binary_int_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
test_binary_int_op_with_results(name, log_width, vec![ty.clone()])
test_int_op_with_results::<2>(name, log_width, None, ty.clone())
}

fn test_binary_icmp_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
test_binary_int_op_with_results(name, log_width, vec![bool_t()])
test_int_op_with_results::<2>(name, log_width, None, bool_t())
}
fn test_binary_int_op_with_results(

fn test_int_op_with_results<const N: usize>(
// N is the number of inputs to the hugr
name: impl AsRef<str>,
log_width: u8,
output_types: impl Into<TypeRow>,
inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op
output_type: Type,
) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
let input_tys = if inputs.is_some() {
vec![]
} else {
itertools::repeat_n(ty.clone(), N).collect()
};
SimpleHugrConfig::new()
.with_ins(vec![ty.clone(), ty.clone()])
.with_outs(output_types.into())
.with_extensions(STD_REG.clone())
.finish(|mut hugr_builder| {
let [in1, in2] = hugr_builder.input_wires_arr();
let ext_op = int_ops::EXTENSION
.instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
.unwrap();
let outputs = hugr_builder
.add_dataflow_op(ext_op, [in1, in2])
.unwrap()
.outputs();
hugr_builder.finish_with_outputs(outputs).unwrap()
})
}

fn test_unary_int_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
SimpleHugrConfig::new()
.with_ins(vec![ty.clone()])
.with_outs(vec![ty.clone()])
.with_ins(input_tys)
.with_outs(vec![output_type])
.with_extensions(STD_REG.clone())
.finish(|mut hugr_builder| {
let [in1] = hugr_builder.input_wires_arr();
let input_wires = match inputs {
None => hugr_builder.input_wires_arr::<N>().to_vec(),
Some(inputs) => {
let mut input_wires = Vec::new();
inputs.into_iter().for_each(|i| {
let w = hugr_builder.add_load_value(i);
input_wires.push(w);
});
input_wires
}
};
let ext_op = int_ops::EXTENSION
.instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
.unwrap();
let outputs = hugr_builder
.add_dataflow_op(ext_op, [in1])
.add_dataflow_op(ext_op, input_wires.into_iter())
.unwrap()
.outputs();
hugr_builder.finish_with_outputs(outputs).unwrap()
Expand All @@ -235,7 +356,8 @@ mod test {
#[rstest]
fn test_neg_emission(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_unary_int_op("ineg", 2);
let ty = INT_TYPES[2].clone();
let hugr = test_int_op_with_results::<1>("ineg", 2, None, ty.clone());
check_emission!("ineg", hugr, llvm_ctx);
}

Expand All @@ -256,4 +378,102 @@ mod test {
let hugr = test_binary_icmp_op(op.clone(), width);
check_emission!(op.clone(), hugr, llvm_ctx);
}

#[rstest]
#[case::imax("imax_u", 1, 2, 2)]
#[case::imax("imax_u", 2, 1, 2)]
#[case::imax("imax_u", 2, 2, 2)]
#[case::imin("imin_u", 1, 2, 1)]
#[case::imin("imin_u", 2, 1, 1)]
#[case::imin("imin_u", 2, 2, 2)]
#[case::ishl("ishl", 73, 1, 146)]
#[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
#[case::ishr("ishr", 73, 1, 36)]
#[case::ior("ior", 6, 9, 15)]
#[case::ior("ior", 6, 15, 15)]
#[case::ixor("ixor", 6, 9, 15)]
#[case::ixor("ixor", 6, 15, 9)]
#[case::ixor("ixor", 15, 6, 9)]
#[case::iand("iand", 6, 15, 6)]
#[case::iand("iand", 15, 6, 6)]
#[case::iand("iand", 15, 15, 15)]
fn test_exec_unsigned_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] lhs: u64,
#[case] rhs: u64,
#[case] result: u64,
) {
exec_ctx.add_extensions(add_int_extensions);
let ty = &INT_TYPES[6].clone();
let inputs = [
ConstInt::new_u(6, lhs).unwrap(),
ConstInt::new_u(6, rhs).unwrap(),
];
let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone());
assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
}

#[rstest]
#[case::imax("imax_s", 1, 2, 2)]
#[case::imax("imax_s", 2, 1, 2)]
#[case::imax("imax_s", 2, 2, 2)]
#[case::imax("imax_s", -1, -2, -1)]
#[case::imax("imax_s", -2, -1, -1)]
#[case::imax("imax_s", -2, -2, -2)]
#[case::imin("imin_s", 1, 2, 1)]
#[case::imin("imin_s", 2, 1, 1)]
#[case::imin("imin_s", 2, 2, 2)]
#[case::imin("imin_s", -1, -2, -2)]
#[case::imin("imin_s", -2, -1, -2)]
#[case::imin("imin_s", -2, -2, -2)]
fn test_exec_signed_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] lhs: i64,
#[case] rhs: i64,
#[case] result: i64,
) {
exec_ctx.add_extensions(add_int_extensions);
let ty = &INT_TYPES[6].clone();
let inputs = [
ConstInt::new_s(6, lhs).unwrap(),
ConstInt::new_s(6, rhs).unwrap(),
];
let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone());
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}

#[rstest]
#[case::iabs("iabs", 42, 42)]
#[case::iabs("iabs", -42, 42)]
fn test_exec_signed_unary_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] arg: i64,
#[case] result: i64,
) {
exec_ctx.add_extensions(add_int_extensions);
let input = ConstInt::new_s(6, arg).unwrap();
let ty = INT_TYPES[6].clone();
let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone());
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}

#[rstest]
// LHS: Most significant bit (2^63), RHS: All the other bits combined
#[case::inot("inot", 9223372036854775808, 9223372036854775807)]
#[case::inot("inot", 1, 0)]
fn test_exec_unsigned_unary_op(
mut exec_ctx: TestContext,
#[case] op: String,
#[case] arg: u64,
#[case] result: u64,
) {
exec_ctx.add_extensions(add_int_extensions);
let input = ConstInt::new_u(6, arg).unwrap();
let ty = INT_TYPES[6].clone();
let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone());
assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
}
}
7 changes: 7 additions & 0 deletions hugr-llvm/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ impl TestContext {
emission.exec_u64(entry_point).unwrap()
}

pub fn exec_hugr_i64(&self, hugr: THugrView, entry_point: impl AsRef<str>) -> i64 {
let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap();
emission.verify().unwrap();

emission.exec_i64(entry_point).unwrap()
}

pub fn exec_hugr_f64(&self, hugr: THugrView, entry_point: impl AsRef<str>) -> f64 {
let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap();

Expand Down

0 comments on commit 2fa09ee

Please sign in to comment.