From 865af3f31a25226e94ce9a51fa4e530204c65f69 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 3 Jan 2025 15:14:05 +0000 Subject: [PATCH] feat(hugr-llvm): Emit more int ops --- hugr-llvm/src/emit/test.rs | 9 ++ hugr-llvm/src/extension/int.rs | 195 ++++++++++++++++++++++++++------- hugr-llvm/src/test.rs | 7 ++ 3 files changed, 173 insertions(+), 38 deletions(-) diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index ef4ad6f4a..259abfa69 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -78,6 +78,15 @@ impl<'c> Emission<'c> { Ok(gv.as_int(false)) } + /// JIT and execute the function named `entry` in the inner module. + /// + /// That function must take no arguments and return an `i64`. + pub fn exec_i64(&self, entry: impl AsRef) -> Result { + let gv = self.exec_impl(entry)?; + let x: u64 = gv.as_int(true).try_into().unwrap(); + Ok(x as i64) + } + /// JIT and execute the function named `entry` in the inner module. /// /// That function must take no arguments and return an `f64`. diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index a66eae57c..3fe205d17 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -15,9 +15,8 @@ use inkwell::{ use crate::{ custom::CodegenExtsBuilder, emit::{ - emit_value, func::EmitFuncContext, ops::emit_custom_binary_op, ops::emit_custom_unary_op, - EmitOpArgs, - get_intrinsic, + emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op, + ops::emit_custom_unary_op, EmitOpArgs, }, types::TypingSession, }; @@ -97,42 +96,98 @@ fn emit_int_op<'c, H: HugrView>( .as_basic_value_enum()]) }), IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| { - let intr = get_intrinsic(ctx.get_current_module(), "llvm.abs.i64", - [ctx.iw_context().i64_type().as_basic_type_enum()])?; - let r = ctx.builder().build_call(intr, &[arg.into_int_value().into()], "")?.try_as_basic_value().unwrap_left(); + let intr = get_intrinsic( + ctx.get_current_module(), + "llvm.abs.i64", + [ctx.iw_context().i64_type().as_basic_type_enum()], + )?; + let true_ = ctx.iw_context().bool_type().const_all_ones(); + let r = ctx + .builder() + .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")? + .try_as_basic_value() + .unwrap_left(); Ok(vec![r]) }), IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { - let intr = get_intrinsic(ctx.get_current_module(), "llvm.smax.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?; - let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left(); + let intr = get_intrinsic( + ctx.get_current_module(), + "llvm.smax.i64", + [ctx.iw_context().i64_type().as_basic_type_enum()], + )?; + let r = ctx + .builder() + .build_call( + intr, + &[lhs.into_int_value().into(), rhs.into_int_value().into()], + "", + )? + .try_as_basic_value() + .unwrap_left(); Ok(vec![r]) }), IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { - let intr = get_intrinsic(ctx.get_current_module(), "llvm.umax.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?; - let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left(); + let intr = get_intrinsic( + ctx.get_current_module(), + "llvm.umax.i64", + [ctx.iw_context().i64_type().as_basic_type_enum()], + )?; + let r = ctx + .builder() + .build_call( + intr, + &[lhs.into_int_value().into(), rhs.into_int_value().into()], + "", + )? + .try_as_basic_value() + .unwrap_left(); Ok(vec![r]) }), IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { - let intr = get_intrinsic(ctx.get_current_module(), "llvm.smin.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?; - let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left(); + let intr = get_intrinsic( + ctx.get_current_module(), + "llvm.smin.i64", + [ctx.iw_context().i64_type().as_basic_type_enum()], + )?; + let r = ctx + .builder() + .build_call( + intr, + &[lhs.into_int_value().into(), rhs.into_int_value().into()], + "", + )? + .try_as_basic_value() + .unwrap_left(); Ok(vec![r]) }), IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { - let intr = get_intrinsic(ctx.get_current_module(), "llvm.umin.i64", [ctx.iw_context().i64_type().as_basic_type_enum()])?; - let r = ctx.builder().build_call(intr, &[lhs.into_int_value().into(), rhs.into_int_value().into()], "")?.try_as_basic_value().unwrap_left(); + let intr = get_intrinsic( + ctx.get_current_module(), + "llvm.umin.i64", + [ctx.iw_context().i64_type().as_basic_type_enum()], + )?; + let r = ctx + .builder() + .build_call( + intr, + &[lhs.into_int_value().into(), rhs.into_int_value().into()], + "", + )? + .try_as_basic_value() + .unwrap_left(); Ok(vec![r]) }), - IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs,rhs), _| { + IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { Ok(vec![ctx - .builder() - .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")? - .as_basic_value_enum()]) + .builder() + .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")? + .as_basic_value_enum()]) }), - IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs,rhs), _| { + IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { Ok(vec![ctx - .builder() - .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")? - .as_basic_value_enum()]) + .builder() + .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")? + .as_basic_value_enum()]) }), IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ), IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE), @@ -146,21 +201,21 @@ fn emit_int_op<'c, H: HugrView>( IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE), IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { Ok(vec![ctx - .builder() - .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")? - .as_basic_value_enum()]) + .builder() + .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")? + .as_basic_value_enum()]) }), IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { Ok(vec![ctx - .builder() - .build_or(lhs.into_int_value(), rhs.into_int_value(), "")? - .as_basic_value_enum()]) + .builder() + .build_or(lhs.into_int_value(), rhs.into_int_value(), "")? + .as_basic_value_enum()]) }), IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| { Ok(vec![ctx - .builder() - .build_not(arg.into_int_value(), "")? - .as_basic_value_enum()]) + .builder() + .build_not(arg.into_int_value(), "")? + .as_basic_value_enum()]) }), _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())), } @@ -228,7 +283,10 @@ mod test { use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::bool_t, - std_extensions::arithmetic::{int_ops, int_types::{ConstInt, INT_TYPES}}, + std_extensions::arithmetic::{ + int_ops, + int_types::{ConstInt, INT_TYPES}, + }, types::TypeRow, Hugr, }; @@ -275,18 +333,17 @@ mod test { fn test_binary_int_op_with_results_inputs( name: impl AsRef, log_width: u8, - inputs: &[u64], + inputs: Vec, output_types: impl Into, ) -> Hugr { - let ty = &INT_TYPES[log_width as usize]; SimpleHugrConfig::new() .with_ins(vec![]) .with_outs(output_types.into()) .with_extensions(STD_REG.clone()) .finish(|mut hugr_builder| { let mut input_wires = Vec::new(); - inputs.iter().for_each(|i| { - let w = hugr_builder.add_load_value(ConstInt::new_u(6, *i).unwrap()); + inputs.into_iter().for_each(|i| { + let w = hugr_builder.add_load_value(i); input_wires.push(w); }); let ext_op = int_ops::EXTENSION @@ -348,10 +405,72 @@ mod test { #[case::imax("imax_u", 1, 2, 2)] #[case::imax("imax_u", 2, 1, 2)] #[case::imax("imax_u", 2, 2, 2)] - fn test_exec_unsigned_op(mut exec_ctx: TestContext, #[case] op: String, #[case] lhs: u64, #[case] rhs: u64, #[case] result: u64) { + #[case::imin("imin_u", 1, 2, 1)] + #[case::imin("imin_u", 2, 1, 1)] + #[case::imin("imin_u", 2, 2, 2)] + #[case::ishl("ishl", 73, 1, 146)] + #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)] + #[case::ishr("ishr", 73, 1, 36)] + fn test_exec_unsigned_op( + mut exec_ctx: TestContext, + #[case] op: String, + #[case] lhs: u64, + #[case] rhs: u64, + #[case] result: u64, + ) { exec_ctx.add_extensions(add_int_extensions); let ty = &INT_TYPES[6].clone(); - let hugr = test_binary_int_op_with_results_inputs(op, 6, &[lhs,rhs], vec![ty.clone()]); + let args = vec![ + ConstInt::new_u(6, lhs).unwrap(), + ConstInt::new_u(6, rhs).unwrap(), + ]; + let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]); assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result); } + + #[rstest] + #[case::imax("imax_s", 1, 2, 2)] + #[case::imax("imax_s", 2, 1, 2)] + #[case::imax("imax_s", 2, 2, 2)] + #[case::imax("imax_s", -1, -2, -1)] + #[case::imax("imax_s", -2, -1, -1)] + #[case::imax("imax_s", -2, -2, -2)] + #[case::imin("imin_s", 1, 2, 1)] + #[case::imin("imin_s", 2, 1, 1)] + #[case::imin("imin_s", 2, 2, 2)] + #[case::imin("imin_s", -1, -2, -2)] + #[case::imin("imin_s", -2, -1, -2)] + #[case::imin("imin_s", -2, -2, -2)] + fn test_exec_signed_bin_op( + mut exec_ctx: TestContext, + #[case] op: String, + #[case] lhs: i64, + #[case] rhs: i64, + #[case] result: i64, + ) { + exec_ctx.add_extensions(add_int_extensions); + let ty = &INT_TYPES[6].clone(); + let args = vec![ + ConstInt::new_s(6, lhs).unwrap(), + ConstInt::new_s(6, rhs).unwrap(), + ]; + let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]); + assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result); + } + + #[rstest] + #[case::iabs("iabs", 42, 42)] + #[case::iabs("iabs", -42, 42)] + fn test_exec_signed_unary_op( + mut exec_ctx: TestContext, + #[case] op: String, + #[case] arg: i64, + #[case] result: i64, + ) { + exec_ctx.add_extensions(add_int_extensions); + let ty = &INT_TYPES[6].clone(); + let args = vec![ConstInt::new_s(6, arg).unwrap()]; + let hugr = test_binary_int_op_with_results_inputs(op, 6, args, vec![ty.clone()]); + assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result); + } } diff --git a/hugr-llvm/src/test.rs b/hugr-llvm/src/test.rs index 6268deee5..9fb67e5c7 100644 --- a/hugr-llvm/src/test.rs +++ b/hugr-llvm/src/test.rs @@ -160,6 +160,13 @@ impl TestContext { emission.exec_u64(entry_point).unwrap() } + pub fn exec_hugr_i64(&self, hugr: THugrView, entry_point: impl AsRef) -> i64 { + let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap(); + emission.verify().unwrap(); + + emission.exec_i64(entry_point).unwrap() + } + pub fn exec_hugr_f64(&self, hugr: THugrView, entry_point: impl AsRef) -> f64 { let emission = Emission::emit_hugr(hugr.fat_root().unwrap(), self.get_emit_hugr()).unwrap();