diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 3367907bd..e2e2e6619 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -15,8 +15,12 @@ use inkwell::{ use crate::{ custom::CodegenExtsBuilder, emit::{ - emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op, - ops::emit_custom_unary_op, EmitOpArgs, + emit_value, + func::EmitFuncContext, + get_intrinsic, + ops::emit_custom_binary_op, + ops::emit_custom_unary_op, + EmitOpArgs, }, types::TypingSession, }; @@ -47,6 +51,63 @@ fn emit_icmp<'c, H: HugrView>( }) } +/// Emit an ipow operation. This isn't directly supported in llvm, so we do a +/// loop over the exponent, performing `imul`s instead. +/// The insertion pointer is expected to be pointing to the end of `launch_bb`. +fn emit_ipow<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, ExtensionOp, H>, +) -> Result<()> { + emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| { + let done_bb = ctx.new_basic_block("done", None); + let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb)); + let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb)); + let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb)); + + let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?; + let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?; + ctx.builder().build_store(acc_p, lhs)?; + ctx.builder().build_store(exp_p, rhs)?; + ctx.builder().build_unconditional_branch(pow_bb)?; + + let zero = rhs.get_type().into_int_type().const_int(0, false); + // Assumes RHS type is the same as output type (which it should be) + let one = rhs.get_type().into_int_type().const_int(1, false); + + // Block for just returning one + ctx.builder().position_at_end(return_one_bb); + ctx.builder().build_store(acc_p, one)?; + ctx.builder().build_unconditional_branch(done_bb)?; + + ctx.builder().position_at_end(pow_bb); + let acc = ctx.builder().build_load(acc_p, "acc")?; + let exp = ctx.builder().build_load(exp_p, "exp")?; + + // Special case if the exponent is 0 or 1 + ctx.builder().build_switch( + exp.into_int_value(), + pow_body_bb, + &[(one, done_bb), (zero, return_one_bb)], + )?; + + // Block that performs one `imul` and modifies the values in the store + ctx.builder().position_at_end(pow_body_bb); + let new_acc = + ctx.builder() + .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?; + let new_exp = ctx + .builder() + .build_int_sub(exp.into_int_value(), one, "new_exp")?; + ctx.builder().build_store(acc_p, new_acc)?; + ctx.builder().build_store(exp_p, new_exp)?; + ctx.builder().build_unconditional_branch(pow_bb)?; + + ctx.builder().position_at_end(done_bb); + let result = ctx.builder().build_load(acc_p, "result")?; + Ok(vec![result.as_basic_value_enum()]) + }) +} + fn emit_int_op<'c, H: HugrView>( context: &mut EmitFuncContext<'c, '_, H>, args: EmitOpArgs<'c, '_, ExtensionOp, H>, @@ -223,6 +284,7 @@ fn emit_int_op<'c, H: HugrView>( .build_and(lhs.into_int_value(), rhs.into_int_value(), "")? .as_basic_value_enum()]) }), + IntOpDef::ipow => emit_ipow(context, args), _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())), } } @@ -364,6 +426,7 @@ mod test { #[rstest] #[case::iadd("iadd", 3)] #[case::isub("isub", 6)] + #[case::ipow("ipow", 3)] fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) { llvm_ctx.add_extensions(add_int_extensions); let hugr = test_binary_int_op(op.clone(), width); @@ -397,6 +460,9 @@ mod test { #[case::iand("iand", 6, 15, 6)] #[case::iand("iand", 15, 6, 6)] #[case::iand("iand", 15, 15, 15)] + #[case::ipow("ipow", 2, 3, 8)] + #[case::ipow("ipow", 42, 1, 42)] + #[case::ipow("ipow", 42, 0, 1)] fn test_exec_unsigned_bin_op( mut exec_ctx: TestContext, #[case] op: String, diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@llvm14.snap new file mode 100644 index 000000000..4646a42b5 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@llvm14.snap @@ -0,0 +1,41 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0, i8 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %acc_ptr = alloca i8, align 1 + %exp_ptr = alloca i8, align 1 + store i8 %0, i8* %acc_ptr, align 1 + store i8 %1, i8* %exp_ptr, align 1 + br label %pow + +pow: ; preds = %pow_body, %entry_block + %acc = load i8, i8* %acc_ptr, align 1 + %exp = load i8, i8* %exp_ptr, align 1 + switch i8 %exp, label %pow_body [ + i8 1, label %done + i8 0, label %power_of_zero + ] + +power_of_zero: ; preds = %pow + store i8 1, i8* %acc_ptr, align 1 + br label %done + +pow_body: ; preds = %pow + %new_acc = mul i8 %acc, %0 + %new_exp = sub i8 %exp, 1 + store i8 %new_acc, i8* %acc_ptr, align 1 + store i8 %new_exp, i8* %exp_ptr, align 1 + br label %pow + +done: ; preds = %pow, %power_of_zero + %result = load i8, i8* %acc_ptr, align 1 + ret i8 %result +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..5f2b7e68e --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__ipow@pre-mem2reg@llvm14.snap @@ -0,0 +1,53 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0, i8 %1) { +alloca_block: + %"0" = alloca i8, align 1 + %"2_0" = alloca i8, align 1 + %"2_1" = alloca i8, align 1 + %"4_0" = alloca i8, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8 %0, i8* %"2_0", align 1 + store i8 %1, i8* %"2_1", align 1 + %"2_01" = load i8, i8* %"2_0", align 1 + %"2_12" = load i8, i8* %"2_1", align 1 + %acc_ptr = alloca i8, align 1 + %exp_ptr = alloca i8, align 1 + store i8 %"2_01", i8* %acc_ptr, align 1 + store i8 %"2_12", i8* %exp_ptr, align 1 + br label %pow + +pow: ; preds = %pow_body, %entry_block + %acc = load i8, i8* %acc_ptr, align 1 + %exp = load i8, i8* %exp_ptr, align 1 + switch i8 %exp, label %pow_body [ + i8 1, label %done + i8 0, label %power_of_zero + ] + +power_of_zero: ; preds = %pow + store i8 1, i8* %acc_ptr, align 1 + br label %done + +pow_body: ; preds = %pow + %new_acc = mul i8 %acc, %"2_01" + %new_exp = sub i8 %exp, 1 + store i8 %new_acc, i8* %acc_ptr, align 1 + store i8 %new_exp, i8* %exp_ptr, align 1 + br label %pow + +done: ; preds = %pow, %power_of_zero + %result = load i8, i8* %acc_ptr, align 1 + store i8 %result, i8* %"4_0", align 1 + %"4_03" = load i8, i8* %"4_0", align 1 + store i8 %"4_03", i8* %"0", align 1 + %"04" = load i8, i8* %"0", align 1 + ret i8 %"04" +}