Skip to content

Commit

Permalink
feat: Emit ipow op
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Jan 6, 2025
1 parent 2fa09ee commit ef6ab9e
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 2 deletions.
70 changes: 68 additions & 2 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ use inkwell::{
use crate::{
custom::CodegenExtsBuilder,
emit::{
emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op,
ops::emit_custom_unary_op, EmitOpArgs,
emit_value,
func::EmitFuncContext,
get_intrinsic,
ops::emit_custom_binary_op,
ops::emit_custom_unary_op,
EmitOpArgs,
},
types::TypingSession,
};
Expand Down Expand Up @@ -47,6 +51,63 @@ fn emit_icmp<'c, H: HugrView>(
})
}

/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
/// loop over the exponent, performing `imul`s instead.
/// The insertion pointer is expected to be pointing to the end of `launch_bb`.
fn emit_ipow<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let done_bb = ctx.new_basic_block("done", None);
let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));

let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
ctx.builder().build_store(acc_p, lhs)?;
ctx.builder().build_store(exp_p, rhs)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

let zero = rhs.get_type().into_int_type().const_int(0, false);
// Assumes RHS type is the same as output type (which it should be)
let one = rhs.get_type().into_int_type().const_int(1, false);

// Block for just returning one
ctx.builder().position_at_end(return_one_bb);
ctx.builder().build_store(acc_p, one)?;
ctx.builder().build_unconditional_branch(done_bb)?;

ctx.builder().position_at_end(pow_bb);
let acc = ctx.builder().build_load(acc_p, "acc")?;
let exp = ctx.builder().build_load(exp_p, "exp")?;

// Special case if the exponent is 0 or 1
ctx.builder().build_switch(
exp.into_int_value(),
pow_body_bb,
&[(one, done_bb), (zero, return_one_bb)],
)?;

// Block that performs one `imul` and modifies the values in the store
ctx.builder().position_at_end(pow_body_bb);
let new_acc =
ctx.builder()
.build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
let new_exp = ctx
.builder()
.build_int_sub(exp.into_int_value(), one, "new_exp")?;
ctx.builder().build_store(acc_p, new_acc)?;
ctx.builder().build_store(exp_p, new_exp)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

ctx.builder().position_at_end(done_bb);
let result = ctx.builder().build_load(acc_p, "result")?;
Ok(vec![result.as_basic_value_enum()])
})
}

fn emit_int_op<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
Expand Down Expand Up @@ -223,6 +284,7 @@ fn emit_int_op<'c, H: HugrView>(
.build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ipow => emit_ipow(context, args),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
}
Expand Down Expand Up @@ -364,6 +426,7 @@ mod test {
#[rstest]
#[case::iadd("iadd", 3)]
#[case::isub("isub", 6)]
#[case::ipow("ipow", 3)]
fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op(op.clone(), width);
Expand Down Expand Up @@ -397,6 +460,9 @@ mod test {
#[case::iand("iand", 6, 15, 6)]
#[case::iand("iand", 15, 6, 6)]
#[case::iand("iand", 15, 15, 15)]
#[case::ipow("ipow", 2, 3, 8)]
#[case::ipow("ipow", 42, 1, 42)]
#[case::ipow("ipow", 42, 0, 1)]
fn test_exec_unsigned_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
Expand Down
41 changes: 41 additions & 0 deletions hugr-llvm/src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %0, i8* %acc_ptr, align 1
store i8 %1, i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %0
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
ret i8 %result
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca i8, align 1
%"2_1" = alloca i8, align 1
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
store i8 %1, i8* %"2_1", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%"2_12" = load i8, i8* %"2_1", align 1
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %"2_01", i8* %acc_ptr, align 1
store i8 %"2_12", i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %"2_01"
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
store i8 %result, i8* %"4_0", align 1
%"4_03" = load i8, i8* %"4_0", align 1
store i8 %"4_03", i8* %"0", align 1
%"04" = load i8, i8* %"0", align 1
ret i8 %"04"
}

0 comments on commit ef6ab9e

Please sign in to comment.