Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-llvm): Emit ipow #1839

Merged
merged 8 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ use inkwell::{
use crate::{
custom::CodegenExtsBuilder,
emit::{
emit_value, func::EmitFuncContext, get_intrinsic, ops::emit_custom_binary_op,
ops::emit_custom_unary_op, EmitOpArgs,
emit_value,
func::EmitFuncContext,
get_intrinsic,
ops::emit_custom_binary_op,
ops::emit_custom_unary_op,
EmitOpArgs,
},
types::TypingSession,
};
Expand Down Expand Up @@ -47,6 +51,63 @@ fn emit_icmp<'c, H: HugrView>(
})
}

/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
/// loop over the exponent, performing `imul`s instead.
/// The insertion pointer is expected to be pointing to the end of `launch_bb`.
fn emit_ipow<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
let done_bb = ctx.new_basic_block("done", None);
let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));

let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
ctx.builder().build_store(acc_p, lhs)?;
ctx.builder().build_store(exp_p, rhs)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

let zero = rhs.get_type().into_int_type().const_int(0, false);
// Assumes RHS type is the same as output type (which it should be)
let one = rhs.get_type().into_int_type().const_int(1, false);

// Block for just returning one
ctx.builder().position_at_end(return_one_bb);
ctx.builder().build_store(acc_p, one)?;
ctx.builder().build_unconditional_branch(done_bb)?;

ctx.builder().position_at_end(pow_bb);
let acc = ctx.builder().build_load(acc_p, "acc")?;
let exp = ctx.builder().build_load(exp_p, "exp")?;

// Special case if the exponent is 0 or 1
ctx.builder().build_switch(
exp.into_int_value(),
pow_body_bb,
&[(one, done_bb), (zero, return_one_bb)],
)?;

// Block that performs one `imul` and modifies the values in the store
ctx.builder().position_at_end(pow_body_bb);
let new_acc =
ctx.builder()
.build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
let new_exp = ctx
.builder()
.build_int_sub(exp.into_int_value(), one, "new_exp")?;
ctx.builder().build_store(acc_p, new_acc)?;
ctx.builder().build_store(exp_p, new_exp)?;
ctx.builder().build_unconditional_branch(pow_bb)?;

ctx.builder().position_at_end(done_bb);
let result = ctx.builder().build_load(acc_p, "result")?;
Ok(vec![result.as_basic_value_enum()])
})
}

fn emit_int_op<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
Expand Down Expand Up @@ -223,6 +284,7 @@ fn emit_int_op<'c, H: HugrView>(
.build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
.as_basic_value_enum()])
}),
IntOpDef::ipow => emit_ipow(context, args),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
}
Expand Down Expand Up @@ -364,6 +426,7 @@ mod test {
#[rstest]
#[case::iadd("iadd", 3)]
#[case::isub("isub", 6)]
#[case::ipow("ipow", 3)]
fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op(op.clone(), width);
Expand Down Expand Up @@ -397,6 +460,9 @@ mod test {
#[case::iand("iand", 6, 15, 6)]
#[case::iand("iand", 15, 6, 6)]
#[case::iand("iand", 15, 15, 15)]
#[case::ipow("ipow", 2, 3, 8)]
#[case::ipow("ipow", 42, 1, 42)]
#[case::ipow("ipow", 42, 0, 1)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add some similar tests to test_exec_signed_bin_op for negative bases?

fn test_exec_unsigned_bin_op(
mut exec_ctx: TestContext,
#[case] op: String,
Expand Down
41 changes: 41 additions & 0 deletions hugr-llvm/src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %0, i8* %acc_ptr, align 1
store i8 %1, i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %0
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
ret i8 %result
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0, i8 %1) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca i8, align 1
%"2_1" = alloca i8, align 1
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
store i8 %1, i8* %"2_1", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%"2_12" = load i8, i8* %"2_1", align 1
%acc_ptr = alloca i8, align 1
%exp_ptr = alloca i8, align 1
store i8 %"2_01", i8* %acc_ptr, align 1
store i8 %"2_12", i8* %exp_ptr, align 1
br label %pow

pow: ; preds = %pow_body, %entry_block
%acc = load i8, i8* %acc_ptr, align 1
%exp = load i8, i8* %exp_ptr, align 1
switch i8 %exp, label %pow_body [
i8 1, label %done
i8 0, label %power_of_zero
]

power_of_zero: ; preds = %pow
store i8 1, i8* %acc_ptr, align 1
br label %done

pow_body: ; preds = %pow
%new_acc = mul i8 %acc, %"2_01"
%new_exp = sub i8 %exp, 1
store i8 %new_acc, i8* %acc_ptr, align 1
store i8 %new_exp, i8* %exp_ptr, align 1
br label %pow

done: ; preds = %pow, %power_of_zero
%result = load i8, i8* %acc_ptr, align 1
store i8 %result, i8* %"4_0", align 1
%"4_03" = load i8, i8* %"4_0", align 1
store i8 %"4_03", i8* %"0", align 1
%"04" = load i8, i8* %"0", align 1
ret i8 %"04"
}
Loading