Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add runtime checks for builder position #436

Merged
merged 12 commits into from
Sep 4, 2023
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ llvm13-0 = ["llvm-sys-130"]
llvm14-0 = ["llvm-sys-140"]
llvm15-0 = ["llvm-sys-150"]
llvm16-0 = ["llvm-sys-160"]
# Don't link aganist LLVM libraries. This is useful if another dependency is
# Don't link against LLVM libraries. This is useful if another dependency is
# installing LLVM. See llvm-sys for more details. We can't enable a single
# `no-llvm-linking` feature across the board of llvm versions, as it'll cause
# cargo to try and download and compile them all. See
Expand Down Expand Up @@ -140,6 +140,7 @@ llvm-sys-160 = { package = "llvm-sys", version = "160.1.0", optional = true }
once_cell = "1.16"
parking_lot = "0.12"
static-alloc = { version = "0.2", optional = true }
thiserror = "1.0.48"

[dev-dependencies]
regex = "1"
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ Documentation is automatically [deployed here](https://thedan64.github.io/inkwel
### Tari's [llvm-sys example](https://gitlab.com/taricorp/llvm-sys.rs/blob/6411edb2fed1a805b7ec5029afc9c3ae1cf6c842/examples/jit-function.rs) written in safe code<sup>1</sup> with Inkwell:

```rust
use inkwell::OptimizationLevel;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::execution_engine::{ExecutionEngine, JitFunction};
use inkwell::module::Module;
use inkwell::OptimizationLevel;

use std::error::Error;

/// Convenience type alias for the `sum` function.
Expand Down Expand Up @@ -89,16 +90,15 @@ impl<'ctx> CodeGen<'ctx> {
let y = function.get_nth_param(1)?.into_int_value();
let z = function.get_nth_param(2)?.into_int_value();

let sum = self.builder.build_int_add(x, y, "sum");
let sum = self.builder.build_int_add(sum, z, "sum");
let sum = self.builder.build_int_add(x, y, "sum").unwrap();
let sum = self.builder.build_int_add(sum, z, "sum").unwrap();

self.builder.build_return(Some(&sum));
self.builder.build_return(Some(&sum)).unwrap();

unsafe { self.execution_engine.get_function("sum").ok() }
}
}


fn main() -> Result<(), Box<dyn Error>> {
let context = Context::create();
let module = context.create_module("sum");
Expand Down
6 changes: 3 additions & 3 deletions examples/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ impl<'ctx> CodeGen<'ctx> {
let y = function.get_nth_param(1)?.into_int_value();
let z = function.get_nth_param(2)?.into_int_value();

let sum = self.builder.build_int_add(x, y, "sum");
let sum = self.builder.build_int_add(sum, z, "sum");
let sum = self.builder.build_int_add(x, y, "sum").unwrap();
let sum = self.builder.build_int_add(sum, z, "sum").unwrap();

self.builder.build_return(Some(&sum));
self.builder.build_return(Some(&sum)).unwrap();

unsafe { self.execution_engine.get_function("sum").ok() }
}
Expand Down
75 changes: 44 additions & 31 deletions examples/kaleidoscope/implementation_typed_pointers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
None => builder.position_at_end(entry),
}

builder.build_alloca(self.context.f64_type(), name)
builder.build_alloca(self.context.f64_type(), name).unwrap()
}

/// Compiles the specified `Expr` into an LLVM `FloatValue`.
Expand All @@ -864,7 +864,7 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
Expr::Number(nb) => Ok(self.context.f64_type().const_float(nb)),

Expr::Variable(ref name) => match self.variables.get(name.as_str()) {
Some(var) => Ok(self.builder.build_load(*var, name.as_str()).into_float_value()),
Some(var) => Ok(self.builder.build_load(*var, name.as_str()).unwrap().into_float_value()),
None => Err("Could not find a matching variable."),
},

Expand All @@ -884,7 +884,7 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {

let alloca = self.create_entry_block_alloca(var_name);

self.builder.build_store(alloca, initial_val);
self.builder.build_store(alloca, initial_val).unwrap();

if let Some(old_binding) = self.variables.remove(var_name) {
old_bindings.push(old_binding);
Expand All @@ -909,44 +909,48 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
ref right,
} => {
if op == '=' {
// handle assignement
// handle assignment
let var_name = match *left.borrow() {
Expr::Variable(ref var_name) => var_name,
_ => {
return Err("Expected variable as left-hand operator of assignement.");
return Err("Expected variable as left-hand operator of assignment.");
},
};

let var_val = self.compile_expr(right)?;
let var = self.variables.get(var_name.as_str()).ok_or("Undefined variable.")?;

self.builder.build_store(*var, var_val);
self.builder.build_store(*var, var_val).unwrap();

Ok(var_val)
} else {
let lhs = self.compile_expr(left)?;
let rhs = self.compile_expr(right)?;

match op {
'+' => Ok(self.builder.build_float_add(lhs, rhs, "tmpadd")),
'-' => Ok(self.builder.build_float_sub(lhs, rhs, "tmpsub")),
'*' => Ok(self.builder.build_float_mul(lhs, rhs, "tmpmul")),
'/' => Ok(self.builder.build_float_div(lhs, rhs, "tmpdiv")),
'+' => Ok(self.builder.build_float_add(lhs, rhs, "tmpadd").unwrap()),
'-' => Ok(self.builder.build_float_sub(lhs, rhs, "tmpsub").unwrap()),
'*' => Ok(self.builder.build_float_mul(lhs, rhs, "tmpmul").unwrap()),
'/' => Ok(self.builder.build_float_div(lhs, rhs, "tmpdiv").unwrap()),
'<' => Ok({
let cmp = self
.builder
.build_float_compare(FloatPredicate::ULT, lhs, rhs, "tmpcmp");
.build_float_compare(FloatPredicate::ULT, lhs, rhs, "tmpcmp")
.unwrap();

self.builder
.build_unsigned_int_to_float(cmp, self.context.f64_type(), "tmpbool")
.unwrap()
}),
'>' => Ok({
let cmp = self
.builder
.build_float_compare(FloatPredicate::ULT, rhs, lhs, "tmpcmp");
.build_float_compare(FloatPredicate::ULT, rhs, lhs, "tmpcmp")
.unwrap();

self.builder
.build_unsigned_int_to_float(cmp, self.context.f64_type(), "tmpbool")
.unwrap()
}),

custom => {
Expand All @@ -959,6 +963,7 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
match self
.builder
.build_call(fun, &[lhs.into(), rhs.into()], "tmpbin")
.unwrap()
.try_as_basic_value()
.left()
{
Expand Down Expand Up @@ -988,6 +993,7 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
match self
.builder
.build_call(fun, argsv.as_slice(), "tmp")
.unwrap()
.try_as_basic_value()
.left()
{
Expand All @@ -1010,33 +1016,34 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
let cond = self.compile_expr(cond)?;
let cond = self
.builder
.build_float_compare(FloatPredicate::ONE, cond, zero_const, "ifcond");
.build_float_compare(FloatPredicate::ONE, cond, zero_const, "ifcond")
.unwrap();

// build branch
let then_bb = self.context.append_basic_block(parent, "then");
let else_bb = self.context.append_basic_block(parent, "else");
let cont_bb = self.context.append_basic_block(parent, "ifcont");

self.builder.build_conditional_branch(cond, then_bb, else_bb);
self.builder.build_conditional_branch(cond, then_bb, else_bb).unwrap();

// build then block
self.builder.position_at_end(then_bb);
let then_val = self.compile_expr(consequence)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.build_unconditional_branch(cont_bb).unwrap();

let then_bb = self.builder.get_insert_block().unwrap();

// build else block
self.builder.position_at_end(else_bb);
let else_val = self.compile_expr(alternative)?;
self.builder.build_unconditional_branch(cont_bb);
self.builder.build_unconditional_branch(cont_bb).unwrap();

let else_bb = self.builder.get_insert_block().unwrap();

// emit merge block
self.builder.position_at_end(cont_bb);

let phi = self.builder.build_phi(self.context.f64_type(), "iftmp");
let phi = self.builder.build_phi(self.context.f64_type(), "iftmp").unwrap();

phi.add_incoming(&[(&then_val, then_bb), (&else_val, else_bb)]);

Expand All @@ -1055,12 +1062,12 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
let start_alloca = self.create_entry_block_alloca(var_name);
let start = self.compile_expr(start)?;

self.builder.build_store(start_alloca, start);
self.builder.build_store(start_alloca, start).unwrap();

// go from current block to loop block
let loop_bb = self.context.append_basic_block(parent, "loop");

self.builder.build_unconditional_branch(loop_bb);
self.builder.build_unconditional_branch(loop_bb).unwrap();
self.builder.position_at_end(loop_bb);

let old_val = self.variables.remove(var_name.as_str());
Expand All @@ -1079,22 +1086,28 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
// compile end condition
let end_cond = self.compile_expr(end)?;

let curr_var = self.builder.build_load(start_alloca, var_name);
let curr_var = self.builder.build_load(start_alloca, var_name).unwrap();
let next_var = self
.builder
.build_float_add(curr_var.into_float_value(), step, "nextvar");
.build_float_add(curr_var.into_float_value(), step, "nextvar")
.unwrap();

self.builder.build_store(start_alloca, next_var);
self.builder.build_store(start_alloca, next_var).unwrap();

let end_cond = self.builder.build_float_compare(
FloatPredicate::ONE,
end_cond,
self.context.f64_type().const_float(0.0),
"loopcond",
);
let end_cond = self
.builder
.build_float_compare(
FloatPredicate::ONE,
end_cond,
self.context.f64_type().const_float(0.0),
"loopcond",
)
.unwrap();
let after_bb = self.context.append_basic_block(parent, "afterloop");

self.builder.build_conditional_branch(end_cond, loop_bb, after_bb);
self.builder
.build_conditional_branch(end_cond, loop_bb, after_bb)
.unwrap();
self.builder.position_at_end(after_bb);

self.variables.remove(var_name);
Expand Down Expand Up @@ -1153,15 +1166,15 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> {
let arg_name = proto.args[i].as_str();
let alloca = self.create_entry_block_alloca(arg_name);

self.builder.build_store(alloca, arg);
self.builder.build_store(alloca, arg).unwrap();

self.variables.insert(proto.args[i].clone(), alloca);
}

// compile body
let body = self.compile_expr(self.function.body.as_ref().unwrap())?;

self.builder.build_return(Some(&body));
self.builder.build_return(Some(&body)).unwrap();

// return the whole thing after verification and optimization
if function.verify(true) {
Expand Down
10 changes: 5 additions & 5 deletions src/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl<'ctx> BasicBlock<'ctx> {
unsafe { FunctionValue::new(LLVMGetBasicBlockParent(self.basic_block)) }
}

/// Gets the `BasicBlock` preceeding the current one, in its own scope, if any.
/// Gets the `BasicBlock` preceding the current one, in its own scope, if any.
///
/// # Example
/// ```no_run
Expand Down Expand Up @@ -295,9 +295,9 @@ impl<'ctx> BasicBlock<'ctx> {
/// let entry = context.append_basic_block(fn_value, "entry");
/// builder.position_at_end(entry);
///
/// let var = builder.build_alloca(i32_type, "some_number");
/// builder.build_store(var, i32_type.const_int(1 as u64, false));
/// builder.build_return(None);
/// let var = builder.build_alloca(i32_type, "some_number").unwrap();
/// builder.build_store(var, i32_type.const_int(1 as u64, false)).unwrap();
/// builder.build_return(None).unwrap();
///
/// let block = fn_value.get_first_basic_block().unwrap();
/// let some_number = block.get_instruction_with_name("some_number");
Expand Down Expand Up @@ -494,7 +494,7 @@ impl<'ctx> BasicBlock<'ctx> {
/// let bb1 = context.append_basic_block(fn_val, "bb1");
/// let bb2 = context.append_basic_block(fn_val, "bb2");
/// builder.position_at_end(entry);
/// let branch_inst = builder.build_unconditional_branch(bb1);
/// let branch_inst = builder.build_unconditional_branch(bb1).unwrap();
///
/// bb1.replace_all_uses_with(&bb2);
///
Expand Down
Loading