Skip to content

Commit

Permalink
feat(builder): handle constants for bin ops
Browse files Browse the repository at this point in the history
  • Loading branch information
0xLucqs committed Jul 18, 2024
1 parent 212d1a2 commit 88c6edf
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 20 deletions.
4 changes: 4 additions & 0 deletions examples/fib/fib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[no_mangle]
pub fn fib(a: u32, b: u32, n: u32) -> u32 {
if n == 0 { b } else { fib(b, a + b, n - 1) }
}
4 changes: 4 additions & 0 deletions examples/increment/increment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[no_mangle]
pub fn increment(left: u128) -> u128 {
left + u128::MAX / 2
}
41 changes: 29 additions & 12 deletions src/builder/function/binary.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
use inkwell::values::{AnyValue, InstructionValue};
use inkwell::values::{AnyValue, InstructionValue, IntValue};

use super::CairoFunctionBuilder;
use crate::builder::get_name;

impl<'ctx> CairoFunctionBuilder<'ctx> {
fn extract_const_int_value(val: IntValue) -> String {
let const_val = val.print_to_string()
.to_string()
.split_whitespace()
.last()
.unwrap()
// Sanity check
.parse::<u128>()
.expect("Rust doesn't handle numbers bigger than u128");
let ty = val.get_type().print_to_string().to_string();
format!("{const_val}_{ty}")
}
/// Translates an LLVM binary operation to cairo. This can be anything that expects exactly 1
/// operator with a left and right operand.
pub fn process_binary_op(&mut self, instruction: &InstructionValue<'ctx>, operator: &str) -> String {
pub fn process_binary_int_op(&mut self, instruction: &InstructionValue<'ctx>, operator: &str) -> String {
// Get th left operand.
let left = unsafe {
instruction
Expand All @@ -29,18 +41,23 @@ impl<'ctx> CairoFunctionBuilder<'ctx> {
// Save the result variable in our mapping to be able to use later.
self.variables.insert(basic_value_enum, instr_name.clone());
}

// The operand is either a variable or a constant so either we get it from our mapping or it's
// unnamed as it's translated into a literal.
let left_name = self
.variables
.get(&left)
.cloned()
.unwrap_or_else(|| get_name(left.get_name()).unwrap_or("left".to_owned()));
let right_name = self
.variables
.get(&right)
.cloned()
.unwrap_or_else(|| get_name(right.get_name()).unwrap_or("right".to_owned()));
let left_name = self.variables.get(&left).cloned().unwrap_or_else(|| {
if right.into_int_value().is_const() {
Self::extract_const_int_value(left.into_int_value())
} else {
unreachable!("Left operand should either be a variable or a constant")
}
});
let right_name = self.variables.get(&right).cloned().unwrap_or_else(|| {
if right.into_int_value().is_const() {
Self::extract_const_int_value(right.into_int_value())
} else {
unreachable!("Left should either be a variable or a constant")
}
});

format!("let {} = {} {} {};", instr_name, left_name, operator, right_name)
}
Expand Down
13 changes: 12 additions & 1 deletion src/builder/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ impl Display for CairoFunction {
#[derive(Default, Clone, PartialEq, Debug)]
pub struct CairoFunctionBody(Vec<String>);

impl CairoFunctionBody {
pub fn new(body: Vec<String>) -> Self {
Self(body)
}
}

impl CairoFunctionBody {
pub fn push_line(&mut self, line: String) {
self.0.push(line)
Expand All @@ -47,7 +53,7 @@ pub struct CairoFunctionSignature {
}

impl CairoFunctionSignature {
fn new(name: String, parameters: Vec<CairoParameter>, return_type: String) -> Self {
pub fn new(name: String, parameters: Vec<CairoParameter>, return_type: String) -> Self {
Self { name, parameters: CairoParameters(parameters), return_type }
}
}
Expand All @@ -72,6 +78,11 @@ pub struct CairoParameter {
pub(crate) name: String,
pub(crate) ty: String,
}
impl CairoParameter {
pub fn new(name: String, ty: String) -> Self {
Self { name, ty }
}
}

impl Display for CairoParameter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down
13 changes: 11 additions & 2 deletions src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ pub struct CairoBuilder<'ctx> {

#[derive(Default, Clone, PartialEq, Debug)]
pub struct CairoFunctions(Vec<CairoFunction>);
impl CairoFunctions {
pub fn functions(&self) -> &[CairoFunction] {
&self.0
}

pub fn count_functions(&self) -> usize {
self.0.len()
}
}

impl CairoFunctions {
pub fn push_function(&mut self, function: CairoFunction) {
Expand Down Expand Up @@ -51,8 +60,8 @@ impl<'ctx> CairoBuilder<'ctx> {
for instruction in bb.get_instructions() {
// Get the opcode of the instruction
let code_line = match instruction.get_opcode() {
InstructionOpcode::Add => function_builder.process_binary_op(&instruction, "+"),
InstructionOpcode::Sub => function_builder.process_binary_op(&instruction, "-"),
InstructionOpcode::Add => function_builder.process_binary_int_op(&instruction, "+"),
InstructionOpcode::Sub => function_builder.process_binary_int_op(&instruction, "-"),
InstructionOpcode::Return => function_builder.process_return(&instruction),
_ => "".to_owned(),
};
Expand Down
63 changes: 58 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::path::Path;

use builder::CairoBuilder;
use builder::{CairoBuilder, CairoFunctions};
use inkwell::context::Context;
use inkwell::memory_buffer::MemoryBuffer;

pub mod builder;

pub fn compile(path: &str) {
pub fn compile(path: &str) -> CairoFunctions {
// Initialize LLVM context
let context = Context::create();
// Parse the LLVM IR
Expand All @@ -21,16 +21,69 @@ pub fn compile(path: &str) {
let translated_func = builder.translate_function(&func);
builder.functions.push_function(translated_func);
});
// println!("Compiling LLVM IR {}", module.to_string());
println!("Cairo code:\n{}", builder.functions);
builder.functions
}

#[cfg(test)]
mod tests {
use builder::function::{CairoFunctionBody, CairoFunctionSignature, CairoParameter};

use super::*;

#[test]
fn it_compiles() {
compile("examples/add/add.ll");
println!("Cairo code:\n{}", compile("examples/increment/increment.ll"));
}

#[test]
fn test_add() {
let expected_name = "add".to_owned();
let expected_return_type = "i64".to_owned();
let expected_params = vec![
CairoParameter::new("left".to_owned(), "i64".to_owned()),
CairoParameter::new("right".to_owned(), "i64".to_owned()),
];
let code = compile("examples/add/add.ll");

// Check number of functions generated
assert_eq!(code.count_functions(), 1, "Add function should generate exactly 1 function");
let function = code.functions().first().unwrap();
// Check function signature
assert_eq!(
function.signature,
CairoFunctionSignature::new(expected_name, expected_params, expected_return_type)
);

// Check function body
assert_eq!(
function.body,
CairoFunctionBody::new(vec!["let _0 = right + left;".to_owned(), "return _0;".to_owned()])
);
}

#[test]
fn test_increment() {
let expected_name = "increment".to_owned();
let expected_return_type = "i128".to_owned();
let expected_params = vec![CairoParameter::new("left".to_owned(), "i128".to_owned())];
let code = compile("examples/increment/increment.ll");

// Check number of functions generated
assert_eq!(code.count_functions(), 1, "Add function should generate exactly 1 function");
let function = code.functions().first().unwrap();
// Check function signature
assert_eq!(
function.signature,
CairoFunctionSignature::new(expected_name, expected_params, expected_return_type)
);

// Check function body
assert_eq!(
function.body,
CairoFunctionBody::new(vec![
"let _0 = left + 170141183460469231731687303715884105727_i128;".to_owned(),
"return _0;".to_owned()
])
);
}
}

0 comments on commit 88c6edf

Please sign in to comment.