diff --git a/pliron-llvm/src/bin/llvm-opt.rs b/pliron-llvm/src/bin/llvm-opt.rs index 1524a20..7b0f642 100644 --- a/pliron-llvm/src/bin/llvm-opt.rs +++ b/pliron-llvm/src/bin/llvm-opt.rs @@ -3,8 +3,10 @@ use std::{path::PathBuf, process::ExitCode}; use inkwell::{context::Context as IWContext, module::Module as IWModule}; use clap::Parser; -use pliron::{arg_error_noloc, context::Context, printable::Printable, result::Result}; -use pliron_llvm::from_inkwell; +use pliron::{ + arg_error_noloc, context::Context, printable::Printable, result::Result, verify_error_noloc, +}; +use pliron_llvm::{from_inkwell, to_inkwell}; #[derive(Parser)] #[command(version, about="LLVM Optimizer", long_about = None)] @@ -30,6 +32,12 @@ fn run(cli: Cli, ctx: &mut Context) -> Result<()> { let pliron_module = from_inkwell::convert_module(ctx, &module)?; println!("{}", pliron_module.disp(ctx)); + let iwctx = &IWContext::create(); + let module = to_inkwell::convert_module(ctx, iwctx, pliron_module)?; + module + .verify() + .map_err(|err| verify_error_noloc!("{}", err.to_string()))?; + if cli.text_output { module .print_to_file(&cli.output) diff --git a/pliron-llvm/src/from_inkwell.rs b/pliron-llvm/src/from_inkwell.rs index dfbb9fa..e3a9aec 100644 --- a/pliron-llvm/src/from_inkwell.rs +++ b/pliron-llvm/src/from_inkwell.rs @@ -8,8 +8,8 @@ use inkwell::{ module::Module as IWModule, types::{AnyType, AnyTypeEnum}, values::{ - AnyValue, AnyValueEnum, BasicValueEnum, FunctionValue, InstructionOpcode, InstructionValue, - PhiValue, + AnyValue, AnyValueEnum, BasicValue, BasicValueEnum, FunctionValue, InstructionOpcode, + InstructionValue, PhiValue, }, IntPredicate, }; @@ -19,9 +19,10 @@ use pliron::{ attributes::IntegerAttr, op_interfaces::{OneRegionInterface, OneResultInterface, SingleBlockRegionInterface}, ops::{FuncOp, ModuleOp}, - types::{FunctionType, IntegerType, Signedness}, + types::{FunctionType, IntegerType}, }, context::{Context, Ptr}, + identifier::Identifier, input_err_noloc, input_error_noloc, op::Op, operation::Operation, @@ -38,7 +39,7 @@ use crate::{ ops::{ AShrOp, AddOp, AllocaOp, AndOp, BitcastOp, BrOp, CondBrOp, ConstantOp, ICmpOp, LShrOp, LoadOp, MulOp, OrOp, ReturnOp, SDivOp, SRemOp, ShlOp, StoreOp, SubOp, UDivOp, URemOp, - XorOp, + UndefOp, XorOp, }, types::{ArrayType, PointerType, StructErr, StructType, VoidType}, }; @@ -57,7 +58,7 @@ pub fn convert_type(ctx: &mut Context, ty: &AnyTypeEnum) -> Result> match ty { AnyTypeEnum::ArrayType(aty) => { let elem = convert_type(ctx, &aty.get_element_type().as_any_type_enum())?; - Ok(ArrayType::get(ctx, elem, aty.len() as usize).into()) + Ok(ArrayType::get(ctx, elem, aty.len() as u64).into()) } AnyTypeEnum::FloatType(_fty) => { todo!() @@ -89,14 +90,16 @@ pub fn convert_type(ctx: &mut Context, ty: &AnyTypeEnum) -> Result> let Some(name) = sty.get_name() else { return input_err_noloc!(StructErr::OpaqueAndAnonymousErr); }; - Ok(StructType::get_named(ctx, name.to_str_res()?, None)?.into()) + let name: Identifier = name.to_str_res()?.try_into()?; + Ok(StructType::get_named(ctx, name, None)?.into()) } else { let field_types: Vec<_> = sty .get_field_types_iter() .map(|ty| convert_type(ctx, &ty.as_any_type_enum())) .collect::>()?; if let Some(name) = sty.get_name() { - Ok(StructType::get_named(ctx, name.to_str_res()?, Some(field_types))?.into()) + let name: Identifier = name.to_str_res()?.try_into()?; + Ok(StructType::get_named(ctx, name, Some(field_types))?.into()) } else { Ok(StructType::get_unnamed(ctx, field_types).into()) } @@ -231,8 +234,11 @@ fn process_constant<'ctx>( BasicValueEnum::IntValue(iv) if iv.is_constant_int() => { // TODO: Zero extend or sign extend? let u64 = iv.get_zero_extended_constant().unwrap(); - let u64_ty = IntegerType::get(ctx, iv.get_type().get_bit_width(), Signedness::Signless); - let val_attr = IntegerAttr::new(u64_ty, ApInt::from_u64(u64)); + let int_ty = TypePtr::::from_ptr( + convert_type(ctx, &iv.get_type().as_any_type_enum())?, + ctx, + )?; + let val_attr = IntegerAttr::new(int_ty, ApInt::from_u64(u64)); let const_op = ConstantOp::new(ctx, Box::new(val_attr)); // Insert at the beginning of the entry block. const_op @@ -240,6 +246,15 @@ fn process_constant<'ctx>( .insert_at_front(cctx.entry_block, ctx); cctx.value_map.insert(any_val, const_op.get_result(ctx)); } + BasicValueEnum::IntValue(iv) if iv.is_undef() => { + let int_ty = convert_type(ctx, &iv.get_type().as_any_type_enum())?; + let undef_op = UndefOp::new(ctx, int_ty); + // Insert at the beginning of the entry block. + undef_op + .get_operation() + .insert_at_front(cctx.entry_block, ctx); + cctx.value_map.insert(any_val, undef_op.get_result(ctx)); + } BasicValueEnum::FloatValue(fv) if fv.is_const() => todo!(), BasicValueEnum::PointerValue(pv) if pv.is_const() => todo!(), BasicValueEnum::StructValue(_sv) => todo!(), @@ -286,10 +301,11 @@ fn get_operand(opds: &[T], idx: usize) -> Result { } /// Compute the arguments to be passed when branching from `src` to `dest`. -fn convert_branch_args( - cctx: &ConversionContext, - src_block: IWBasicBlock, - dst_block: IWBasicBlock, +fn convert_branch_args<'ctx>( + ctx: &mut Context, + cctx: &mut ConversionContext<'ctx>, + src_block: IWBasicBlock<'ctx>, + dst_block: IWBasicBlock<'ctx>, ) -> Result> { let mut args = vec![]; for inst in dst_block.get_instructions() { @@ -301,6 +317,7 @@ fn convert_branch_args( src_block.get_name().to_str_res().unwrap().to_string() )); }; + process_constant(ctx, cctx, incoming_val.as_basic_value_enum())?; let Some(m_incoming_val) = cctx.value_map.get(&incoming_val.as_any_value_enum()) else { return input_err_noloc!(ConversionErr::UndefinedValue( incoming_val @@ -358,26 +375,29 @@ fn convert_instruction<'ctx>( "Conditional branch must have two successors" ); let true_dest_opds = convert_branch_args( + ctx, cctx, inst.get_parent().unwrap(), - inst.get_operand(1).unwrap().unwrap_right(), + inst.get_operand(2).unwrap().unwrap_right(), )?; let false_dest_opds = convert_branch_args( + ctx, cctx, inst.get_parent().unwrap(), - inst.get_operand(2).unwrap().unwrap_right(), + inst.get_operand(1).unwrap().unwrap_right(), )?; Ok(CondBrOp::new( ctx, get_operand(opds, 0)?, - get_operand(succs, 0)?, - true_dest_opds, get_operand(succs, 1)?, + true_dest_opds, + get_operand(succs, 0)?, false_dest_opds, ) .get_operation()) } else { let dest_opds = convert_branch_args( + ctx, cctx, inst.get_parent().unwrap(), inst.get_operand(0).unwrap().unwrap_right(), @@ -445,7 +465,14 @@ fn convert_instruction<'ctx>( } InstructionOpcode::PtrToInt => todo!(), InstructionOpcode::Resume => todo!(), - InstructionOpcode::Return => Ok(ReturnOp::new(ctx, get_operand(opds, 0)?).get_operation()), + InstructionOpcode::Return => { + let retval = if inst.get_num_operands() == 1 { + Some(get_operand(opds, 0)?) + } else { + None + }; + Ok(ReturnOp::new(ctx, retval).get_operation()) + } InstructionOpcode::SDiv => { let (lhs, rhs) = (get_operand(opds, 0)?, get_operand(opds, 1)?); Ok(SDivOp::new(ctx, lhs, rhs).get_operation()) @@ -502,7 +529,7 @@ fn convert_block<'ctx>( ) -> Result<()> { for inst in block.get_instructions() { let inst_val = inst.as_any_value_enum(); - if inst_val.is_phi_value() { + if inst.get_opcode() == InstructionOpcode::Phi { let ty = convert_type(ctx, &inst.get_type().as_any_type_enum())?; let arg_idx = m_block.deref_mut(ctx).add_argument(ty); cctx.value_map diff --git a/pliron-llvm/src/lib.rs b/pliron-llvm/src/lib.rs index 75ed3a2..f601d37 100644 --- a/pliron-llvm/src/lib.rs +++ b/pliron-llvm/src/lib.rs @@ -9,6 +9,7 @@ pub mod attributes; pub mod from_inkwell; pub mod op_interfaces; pub mod ops; +pub mod to_inkwell; pub mod types; /// Register LLVM dialect, its ops, types and attributes into context. diff --git a/pliron-llvm/src/ops.rs b/pliron-llvm/src/ops.rs index 9a353fb..3e5f6a5 100644 --- a/pliron-llvm/src/ops.rs +++ b/pliron-llvm/src/ops.rs @@ -53,10 +53,28 @@ use super::{ #[def_op("llvm.return")] pub struct ReturnOp {} impl ReturnOp { - pub fn new(ctx: &mut Context, value: Value) -> Self { - let op = Operation::new(ctx, Self::get_opid_static(), vec![], vec![value], vec![], 0); + /// Create a new [ReturnOp] + pub fn new(ctx: &mut Context, value: Option) -> Self { + let op = Operation::new( + ctx, + Self::get_opid_static(), + vec![], + value.into_iter().collect(), + vec![], + 0, + ); ReturnOp { op } } + + /// Get the returned value, if it exists. + pub fn retval(&self, ctx: &Context) -> Option { + let op = &*self.get_operation().deref(ctx); + if op.get_num_operands() == 1 { + op.get_operand(0) + } else { + None + } + } } impl_canonical_syntax!(ReturnOp); impl_verify_succ!(ReturnOp); @@ -242,6 +260,16 @@ impl ICmpOp { .set(icmp_op::ATTR_KEY_PREDICATE.clone(), pred); ICmpOp { op } } + + /// Get the predicate + pub fn predicate(&self, ctx: &Context) -> ICmpPredicateAttr { + self.get_operation() + .deref(ctx) + .attributes + .get::(&icmp_op::ATTR_KEY_PREDICATE) + .unwrap() + .clone() + } } impl Verify for ICmpOp { @@ -330,7 +358,7 @@ impl Verify for AllocaOp { impl_op_interface!(OneResultInterface for AllocaOp {}); impl_op_interface!(OneOpdInterface for AllocaOp {}); impl_op_interface!(PointerTypeResult for AllocaOp { - fn result_pointee_type(&self,ctx: &Context) -> Ptr { + fn result_pointee_type(&self, ctx: &Context) -> Ptr { self.op .deref(ctx) .attributes @@ -478,6 +506,11 @@ impl CondBrOp { ), } } + + /// Get the condition value for the branch. + pub fn condition(&self, ctx: &Context) -> Value { + self.op.deref(ctx).get_operand(0).unwrap() + } } impl_canonical_syntax!(CondBrOp); impl_verify_succ!(CondBrOp); @@ -655,7 +688,7 @@ impl GetElementPtrOp { let GepIndex::Constant(i) = idx else { return arg_err_noloc!(GetElementPtrOpErr::IndicesErr); }; - if i as usize >= st.num_fields() { + if st.is_opaque() || i as usize >= st.num_fields() { return arg_err_noloc!(GetElementPtrOpErr::IndicesErr); } indexed_type_inner(ctx, st.field_type(i as usize), idx_itr) @@ -877,6 +910,35 @@ impl CallOpInterface for CallOp { impl_canonical_syntax!(CallOp); impl_verify_succ!(CallOp); +/// Undefined value of a type. +/// See MLIR's [llvm.mlir.undef](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirundef-llvmundefop). +/// +/// Results: +/// +/// | result | description | +/// |-----|-------| +/// | `result` | any type | +#[def_op("llvm.undef")] +pub struct UndefOp {} +impl_canonical_syntax!(UndefOp); +impl_verify_succ!(UndefOp); +impl_op_interface!(OneResultInterface for UndefOp {}); + +impl UndefOp { + /// Create a new [UndefOp]. + pub fn new(ctx: &mut Context, result_ty: Ptr) -> Self { + let op = Operation::new( + ctx, + Self::get_opid_static(), + vec![result_ty], + vec![], + vec![], + 0, + ); + UndefOp { op } + } +} + /// Numeric constant. /// See MLIR's [llvm.mlir.constant](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirconstant-llvmconstantop). /// @@ -977,5 +1039,6 @@ pub fn register(ctx: &mut Context) { StoreOp::register(ctx, StoreOp::parser_fn); CallOp::register(ctx, CallOp::parser_fn); ConstantOp::register(ctx, ConstantOp::parser_fn); + UndefOp::register(ctx, UndefOp::parser_fn); ReturnOp::register(ctx, ReturnOp::parser_fn); } diff --git a/pliron-llvm/src/to_inkwell.rs b/pliron-llvm/src/to_inkwell.rs new file mode 100644 index 0000000..d61604a --- /dev/null +++ b/pliron-llvm/src/to_inkwell.rs @@ -0,0 +1,769 @@ +//! Translate from pliron's LLVM dialec to [inkwell] + +use apint::ApInt; +use pliron::{ + basic_block::BasicBlock, + builtin::{ + attributes::{FloatAttr, IntegerAttr}, + op_interfaces::{ + BranchOpInterface, OneOpdInterface, OneRegionInterface, OneResultInterface, + SingleBlockRegionInterface, SymbolOpInterface, + }, + ops::{FuncOp, ModuleOp}, + types::{FunctionType, IntegerType}, + }, + common_traits::Named, + context::{Context, Ptr}, + decl_op_interface, decl_type_interface, impl_op_interface, impl_type_interface, input_err, + input_err_noloc, input_error, input_error_noloc, + linked_list::{ContainsLinkedList, LinkedList}, + location::Located, + op::{op_cast, Op}, + operation::Operation, + r#type::{type_cast, Type, TypeObj, TypePtr, Typed}, + result::Result, + use_def_lists::Value, + utils::traversals::region::topological_order, +}; + +use inkwell::{ + basic_block::BasicBlock as IWBasicBlock, + builder::Builder, + context::Context as IWContext, + module::Module as IWModule, + types::{ + self as iwtypes, AnyType, AnyTypeEnum, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, + IntType, + }, + values::{ + AnyValue, AnyValueEnum, BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue, + }, + IntPredicate, +}; + +use rustc_hash::FxHashMap; +use thiserror::Error; + +use crate::{ + attributes::ICmpPredicateAttr, + op_interfaces::PointerTypeResult, + ops::{ + AddOp, AllocaOp, AndOp, BitcastOp, BrOp, CondBrOp, ConstantOp, ICmpOp, LoadOp, MulOp, OrOp, + ReturnOp, SDivOp, SRemOp, ShlOp, StoreOp, SubOp, UDivOp, URemOp, UndefOp, XorOp, + }, + types::{ArrayType, PointerType, StructType, VoidType}, +}; + +/// Mapping from pliron entities to inkwell entities. +pub struct ConversionContext<'ctx> { + // A map from inkwell's Values to pliron's Values. + value_map: FxHashMap>, + // A map from inkwell's basic blocks to plirons'. + block_map: FxHashMap, IWBasicBlock<'ctx>>, + // The active LLVM / inkwell [Builder]. + builder: Builder<'ctx>, +} + +impl<'ctx> ConversionContext<'ctx> { + pub fn new(iwctx: &'ctx IWContext) -> Self { + Self { + value_map: FxHashMap::default(), + block_map: FxHashMap::default(), + builder: iwctx.create_builder(), + } + } +} + +#[derive(Error, Debug)] +pub enum ToLLVMErr { + #[error("Type {0} does not have a conversion to LLVM type implemented")] + MissingTypeConversion(String), + #[error("Operation {0} does not have a conversion to LLVM instruction implemented")] + MissingOpConversion(String), + #[error("Array element type must be a basic type")] + ArrayElementTypeNotBasic, + #[error("Function type's return type and argument types must all be a basic type")] + FunctionTypeComponentNotBasic, + #[error("Struct field type must be a basic type")] + StructFieldTypeNotBasic, + #[error("FuncOp must implement ToLLVMType and have FunctionType")] + FuncOpTypeErr, + #[error("PHI argument must be a basic type")] + PhiTypeNotBasic, + #[error("Definition for value {0} not seen yet")] + UndefinedValue(String), + #[error("Block definition {0} not seen yet")] + UndefinedBlock(String), + #[error("Integer instruction must have integer operands")] + IntOpValueErr, + #[error("AllocaOp pointee type must be basic type")] + AllocaOpTypeNotBasic, + #[error("AllocaOp size operand must be IntType")] + AllocaOpSizeNotInt, + #[error("BitcastOp operand must be basic")] + BitcastOpOpdNotBasic, + #[error("BitcastOp result must be a basic type")] + BitcastOpResultNotBasicType, + #[error("Number of block args in the source dialect equal the number of PHIs in target IR")] + NumBlockArgsNumPhisMismatch, + #[error("Value passed to block argument must be BasicValue")] + BranchArgNotBasic, + #[error("Conditional branch's condition must be an IntValue")] + CondBranchCondNotInt, + #[error("Load operand must be a PointerVal")] + LoadOpdNotPointer, + #[error("Loaded value must have BasicType")] + LoadedValueNotBasicType, + #[error("Stored value must have BasicType")] + StoredValueNotBasicType, + #[error("Store pointer must be PointerVal")] + StorePointerIncorrect, + #[error("ConstantOp must have integer or float value")] + ConstOpNotIntOrFloat, + #[error("ICmpOp operands must be IntValues")] + ICmpOpOpdNotInt, + #[error("ReturnOp must return a BasicType")] + ReturnOpOperandNotBasic, +} + +pub fn convert_ipredicate(pred: ICmpPredicateAttr) -> IntPredicate { + match pred { + ICmpPredicateAttr::EQ => IntPredicate::EQ, + ICmpPredicateAttr::NE => IntPredicate::NE, + ICmpPredicateAttr::UGT => IntPredicate::UGT, + ICmpPredicateAttr::UGE => IntPredicate::UGE, + ICmpPredicateAttr::ULT => IntPredicate::ULT, + ICmpPredicateAttr::ULE => IntPredicate::ULE, + ICmpPredicateAttr::SGT => IntPredicate::SGT, + ICmpPredicateAttr::SGE => IntPredicate::SGE, + ICmpPredicateAttr::SLT => IntPredicate::SLT, + ICmpPredicateAttr::SLE => IntPredicate::SLE, + } +} + +decl_type_interface! { + /// A type that implements this is convertible to an inkwell [AnyTypeEnum]. + ToLLVMType { + /// Convert from pliron [Type] to inkwell's [AnyTypeEnum]. + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext) -> Result>; + + fn verify(_type: &dyn Type, _ctx: &Context) -> Result<()> + where Self: Sized, + { + Ok(()) + } + } +} + +decl_op_interface! { + /// An [Op] that implements this is convertible to an inkwell [AnyValueEnum]. + ToLLVMValue { + /// Convert from pliron [Op] to inkwell's [AnyValueEnum]. + fn convert<'ctx>( + &self, ctx: &Context, + iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>) -> Result>; + + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> + where Self: Sized, + { + Ok(()) + } + } +} + +impl_type_interface!( + ToLLVMType for IntegerType { + fn convert<'ctx>( + &self, + _ctx: &Context, + iwctx: &'ctx IWContext) -> Result> + { + Ok(AnyTypeEnum::IntType(iwctx.custom_width_int_type(self.get_width()))) + } + } +); + +impl_type_interface!( + ToLLVMType for ArrayType { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext) -> Result> + { + let el_ty = convert_type(ctx, iwctx, self.elem_type())?; + let el_ty_basic: BasicTypeEnum = + TryFrom::try_from(el_ty) + .map_err(|_err| input_error_noloc!(ToLLVMErr::ArrayElementTypeNotBasic))?; + Ok(el_ty_basic.array_type( + self.size().try_into() + .expect("LLVM's ArrayType's size is u64, \ + but inkwell uses u32 and we can't fit it in u32")) + .as_any_type_enum() + ) + } + } +); + +impl_type_interface!( + ToLLVMType for FunctionType { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext) -> Result> + { + let args_tys: Vec<_> = + self.get_inputs() + .iter() + .map(|ty| { + let ty: BasicTypeEnum = + convert_type(ctx, iwctx, *ty)? + .try_into() + .map_err(|_err| { + input_error_noloc!(ToLLVMErr::FunctionTypeComponentNotBasic) + })?; + Ok(BasicMetadataTypeEnum::from(ty)) + }) + .collect::>()?; + let ret_ty: BasicTypeEnum = + self.get_results().first().map(|ty| convert_type(ctx, iwctx, *ty)) + .unwrap_or(Ok(iwctx.void_type().as_any_type_enum()))? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::FunctionTypeComponentNotBasic))?; + Ok(ret_ty.fn_type(&args_tys, false).as_any_type_enum()) + } + } +); + +impl_type_interface!( + ToLLVMType for VoidType { + fn convert<'ctx>( + &self, + _ctx: &Context, + iwctx: &'ctx IWContext) -> Result> + { + Ok(iwctx.void_type().as_any_type_enum()) + } + } +); + +impl_type_interface!( + ToLLVMType for PointerType { + fn convert<'ctx>( + &self, + _ctx: &Context, + _iwctx: &'ctx IWContext) -> Result> + { + // Ok(iwctx.ptr_type(0).as_any_type_enum()) + todo!() + } + } +); + +impl_type_interface!( + ToLLVMType for StructType { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext) -> Result> + { + if self.is_opaque() { + let name = self.name().expect("Opaqaue struct must have a name"); + Ok(iwctx.opaque_struct_type(name.as_str()).as_any_type_enum()) + } else { + let field_types = self + .fields() + .map(|fty| { + let ty = convert_type(ctx, iwctx, fty)?; + let ty_basic: BasicTypeEnum = ty.try_into() + .map_err(|_err| { + input_error_noloc!(ToLLVMErr::FunctionTypeComponentNotBasic) + })?; + Ok(ty_basic) + }) + .collect::>>()?; + if let Some(name) = self.name() { + let str_ty = iwctx.opaque_struct_type(name.as_str()); + str_ty.set_body(&field_types, false); + Ok(str_ty.as_any_type_enum()) + } else { + Ok(iwctx.struct_type(&field_types, false).as_any_type_enum()) + } + } + } + } +); + +/// Convert a pliron [Type] to inkwell [AnyTypeEnum]. +pub fn convert_type<'ctx>( + ctx: &Context, + iwctx: &'ctx IWContext, + ty: Ptr, +) -> Result> { + if let Some(converter) = type_cast::(&**ty.deref(ctx)) { + return converter.convert(ctx, iwctx); + } + input_err_noloc!(ToLLVMErr::MissingTypeConversion( + ty.deref(ctx).get_type_id().to_string() + )) +} + +fn convert_value_operand<'ctx>( + cctx: &mut ConversionContext<'ctx>, + ctx: &Context, + value: &Value, +) -> Result> { + match cctx.value_map.get(value) { + Some(v) => Ok(*v), + None => { + input_err_noloc!(ToLLVMErr::UndefinedValue(value.unique_name(ctx))) + } + } +} + +fn convert_block_operand<'ctx>( + cctx: &mut ConversionContext<'ctx>, + ctx: &Context, + block: Ptr, +) -> Result> { + match cctx.block_map.get(&block) { + Some(v) => Ok(*v), + None => { + input_err_noloc!(ToLLVMErr::UndefinedBlock(block.unique_name(ctx))) + } + } +} + +macro_rules! to_llvm_value_int_bin_op { + ( + $op_name:ident, $builder_method:ident + ) => { + impl_op_interface! (ToLLVMValue for $op_name { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let lhs = op.get_operand(0).unwrap(); + let rhs = op.get_operand(1).unwrap(); + let lhs: IntValue = convert_value_operand(cctx, ctx, &lhs)? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::IntOpValueErr))?; + let rhs: IntValue = convert_value_operand(cctx, ctx, &rhs)? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::IntOpValueErr))?; + let iw_op = cctx + .builder + .$builder_method(lhs, rhs, &self.get_result(ctx).unique_name(ctx)) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(iw_op.as_any_value_enum()) + } + }); + }; +} + +to_llvm_value_int_bin_op!(AddOp, build_int_add); +to_llvm_value_int_bin_op!(SubOp, build_int_sub); +to_llvm_value_int_bin_op!(MulOp, build_int_mul); +to_llvm_value_int_bin_op!(SDivOp, build_int_signed_div); +to_llvm_value_int_bin_op!(UDivOp, build_int_unsigned_div); +to_llvm_value_int_bin_op!(URemOp, build_int_unsigned_rem); +to_llvm_value_int_bin_op!(SRemOp, build_int_signed_rem); +to_llvm_value_int_bin_op!(AndOp, build_and); +to_llvm_value_int_bin_op!(OrOp, build_or); +to_llvm_value_int_bin_op!(XorOp, build_xor); +to_llvm_value_int_bin_op!(ShlOp, build_left_shift); + +impl_op_interface! (ToLLVMValue for AllocaOp { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let ty: BasicTypeEnum = convert_type(ctx, iwctx, self.result_pointee_type(ctx))? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::AllocaOpTypeNotBasic))?; + let size: IntValue = convert_value_operand(cctx, ctx, &self.get_operand(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::AllocaOpSizeNotInt))?; + let alloca_op = cctx + .builder + .build_array_alloca(ty, size, &self.get_result(ctx).unique_name(ctx)) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(alloca_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for BitcastOp { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let arg: BasicValueEnum = convert_value_operand(cctx, ctx, &self.get_operand(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::BitcastOpOpdNotBasic))?; + let ty: BasicTypeEnum = convert_type(ctx, iwctx, self.result_type(ctx))? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::BitcastOpResultNotBasicType))?; + let bitcast_op = cctx + .builder + .build_bitcast(arg, ty, &self.get_result(ctx).unique_name(ctx)) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(bitcast_op.as_any_value_enum()) + } +}); + +fn link_succ_operands_with_phis( + ctx: &Context, + cctx: &mut ConversionContext<'_>, + source_block: Ptr, + target_block: IWBasicBlock, + opds: Vec, +) -> Result<()> { + let mut phis = vec![]; + for inst in target_block.get_instructions() { + let Ok(phi) = TryInto::::try_into(inst) else { + break; + }; + phis.push(phi); + } + + if phis.len() != opds.len() { + return input_err!( + source_block.deref(ctx).loc(), + ToLLVMErr::NumBlockArgsNumPhisMismatch + ); + } + + let source_block_iw = convert_block_operand(cctx, ctx, source_block)?; + + for (idx, arg) in opds.iter().enumerate() { + let arg_iw: BasicValueEnum = convert_value_operand(cctx, ctx, arg)? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::BranchArgNotBasic))?; + phis[idx].add_incoming(&[(&arg_iw, source_block_iw)]); + } + Ok(()) +} + +impl_op_interface! (ToLLVMValue for BrOp { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let succ = op.get_successor(0).unwrap(); + let succ_iw = convert_block_operand(cctx, ctx, succ)?; + let branch_op = cctx + .builder + .build_unconditional_branch(succ_iw) + .map_err(|err| input_error!(op.loc(), err))?; + + // Link the arguments we pass to the block with the PHIs there. + link_succ_operands_with_phis( + ctx, + cctx, + op.get_container().expect("Unlinked operation"), + succ_iw, + self.successor_operands(ctx, 0), + )?; + + Ok(branch_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for CondBrOp { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let true_succ = op.get_successor(0).unwrap(); + let true_succ_iw = convert_block_operand(cctx, ctx, true_succ)?; + let false_succ = op.get_successor(1).unwrap(); + let false_succ_iw = convert_block_operand(cctx, ctx, false_succ)?; + let cond: IntValue = convert_value_operand(cctx, ctx, &self.condition(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::CondBranchCondNotInt))?; + + let branch_op = cctx + .builder + .build_conditional_branch(cond, true_succ_iw, false_succ_iw) + .map_err(|err| input_error!(op.loc(), err))?; + + // Link the arguments we pass to the block with the PHIs there. + link_succ_operands_with_phis( + ctx, + cctx, + op.get_container().expect("Unlinked operation"), + true_succ_iw, + self.successor_operands(ctx, 0), + )?; + link_succ_operands_with_phis( + ctx, + cctx, + op.get_container().expect("Unlinked operation"), + false_succ_iw, + self.successor_operands(ctx, 1), + )?; + + Ok(branch_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for LoadOp { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let pointee_ty: BasicTypeEnum = convert_type(ctx, iwctx, self.result_type(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::LoadedValueNotBasicType))?; + let ptr: PointerValue = convert_value_operand(cctx, ctx, &self.get_operand(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::LoadOpdNotPointer))?; + let load_op = cctx + .builder + .build_load(pointee_ty, ptr, &self.get_result(ctx).unique_name(ctx)) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(load_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for StoreOp { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let value: BasicValueEnum = convert_value_operand(cctx, ctx, &self.value_opd(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::StoredValueNotBasicType))?; + let ptr: PointerValue = convert_value_operand(cctx, ctx, &self.address_opd(ctx))? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::StorePointerIncorrect))?; + let store_op = cctx + .builder + .build_store(ptr, value) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(store_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for ICmpOp { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let predicate = convert_ipredicate(self.predicate(ctx)); + let lhs: IntValue = convert_value_operand(cctx, ctx, &op.get_operand(0).unwrap())? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::ICmpOpOpdNotInt))?; + let rhs: IntValue = convert_value_operand(cctx, ctx, &op.get_operand(1).unwrap())? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::ICmpOpOpdNotInt))?; + let icmp_op = cctx + .builder + .build_int_compare(predicate, lhs, rhs, &self.get_result(ctx).unique_name(ctx)) + .map_err(|err| input_error!(op.loc(), err))?; + Ok(icmp_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for ReturnOp { + fn convert<'ctx>( + &self, + ctx: &Context, + _iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let ret_op = + if let Some(retval) = self.retval(ctx) { + let retval: BasicValueEnum = convert_value_operand(cctx, ctx, &retval)? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::ReturnOpOperandNotBasic))?; + cctx.builder.build_return + (Some(&retval)).map_err(|err| input_error!(op.loc(), err))? + } else { + cctx.builder.build_return(None).map_err(|err| input_error!(op.loc(), err))? + }; + Ok(ret_op.as_any_value_enum()) + } +}); + +impl_op_interface! (ToLLVMValue for ConstantOp { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext, + _cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let op = self.get_operation().deref(ctx); + let value = self.get_value(ctx); + if let Some(int_val) = value.downcast_ref::() { + let int_ty = TypePtr::::from_ptr(int_val.get_type(ctx), ctx).unwrap(); + let int_ty_iw: iwtypes::IntType = + convert_type(ctx, iwctx, int_ty.into())? + .try_into() + .map_err(|_err| input_error!(op.loc(), ToLLVMErr::ConstOpNotIntOrFloat))?; + let ap_int_val: ApInt = int_val.clone().into(); + let const_val = + int_ty_iw.const_int(ap_int_val.resize_to_u64(), false); + Ok(const_val.as_any_value_enum()) + } else if let Some(_float_val) = value.downcast_ref::() { + todo!() + } else { + input_err!(op.loc(), ToLLVMErr::ConstOpNotIntOrFloat) + } + } +}); + +impl_op_interface! (ToLLVMValue for UndefOp { + fn convert<'ctx>( + &self, + ctx: &Context, + iwctx: &'ctx IWContext, + _cctx: &mut ConversionContext<'ctx>, + ) -> Result> { + let ty = convert_type(ctx, iwctx, self.result_type(ctx))?; + if let Ok(int_ty) = TryInto::::try_into(ty) { + Ok(int_ty.get_undef().as_any_value_enum()) + } else { + todo!() + } + } +}); + +/// Conver a pliron [BasicBlock] to inkwell [BasicBlock][IWBasicBlock]. +fn convert_block<'ctx>( + ctx: &Context, + iwctx: &'ctx IWContext, + cctx: &mut ConversionContext<'ctx>, + block: Ptr, +) -> Result<()> { + let iw_block = cctx.block_map[&block]; + cctx.builder.position_at_end(iw_block); + + for (op, loc) in block + .deref(ctx) + .iter(ctx) + .map(|op| (Operation::get_op(op, ctx), op.deref(ctx).loc())) + { + let Some(op_conv) = op_cast::(&*op) else { + return input_err!( + loc, + ToLLVMErr::MissingOpConversion(op.get_opid().to_string()) + ); + }; + let op_iw = op_conv.convert(ctx, iwctx, cctx)?; + // LLVM instructions have at most one result. + if let Some(res) = op.get_operation().deref(ctx).get_result(0) { + cctx.value_map.insert(res, op_iw); + } + } + + Ok(()) +} + +/// Convert a pliron [FuncOp] to inkwell [FunctionValue] +fn convert_function<'ctx>( + ctx: &Context, + iwctx: &'ctx IWContext, + module_iw: &IWModule<'ctx>, + func_op: FuncOp, +) -> Result> { + let func_ty = func_op.get_type(ctx).deref(ctx); + let func_ty_to_iw = type_cast::(&**func_ty) + .ok_or(input_error_noloc!(ToLLVMErr::FuncOpTypeErr))?; + let ty = func_ty_to_iw.convert(ctx, iwctx)?; + let func_ty: iwtypes::FunctionType = ty + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::FuncOpTypeErr))?; + + let cctx = &mut ConversionContext::new(iwctx); + let func_iw = module_iw.add_function(&func_op.get_symbol_name(ctx), func_ty, None); + + // Map all blocks, staring with entry. + let mut block_iter = func_op.get_region(ctx).deref(ctx).iter(ctx); + { + let entry = block_iter.next().expect("Missing entry block"); + assert!(entry == func_op.get_entry_block(ctx)); + // Map entry block arguments to inkwell function arguments. + for (arg_idx, arg) in entry.deref(ctx).arguments().enumerate() { + cctx.value_map.insert( + arg, + func_iw + .get_nth_param(arg_idx.try_into().unwrap()) + .unwrap() + .as_any_value_enum(), + ); + } + let iw_entry_block = iwctx.append_basic_block(func_iw, &entry.deref(ctx).unique_name(ctx)); + cctx.block_map.insert(entry, iw_entry_block); + } + for block in block_iter { + let iw_block = iwctx.append_basic_block(func_iw, &block.deref(ctx).unique_name(ctx)); + let builder = iwctx.create_builder(); + builder.position_at_end(iw_block); + for arg in block.deref(ctx).arguments() { + let arg_type: BasicTypeEnum = + convert_type(ctx, iwctx, arg.get_type(ctx))? + .try_into() + .map_err(|_err| input_error_noloc!(ToLLVMErr::PhiTypeNotBasic))?; + let phi = builder + .build_phi(arg_type, &arg.unique_name(ctx)) + .map_err(|err| input_error_noloc!(err))?; + cctx.value_map.insert(arg, phi.as_any_value_enum()); + } + cctx.block_map.insert(block, iw_block); + } + + // Convert within every block. + for block in topological_order(ctx, func_op.get_region(ctx)) { + convert_block(ctx, iwctx, cctx, block)?; + } + + Ok(func_iw) +} + +/// Convert pliron [ModuleOp] to inkwell [Module](IWModule). +pub fn convert_module<'ctx>( + ctx: &Context, + iwctx: &'ctx IWContext, + module: ModuleOp, +) -> Result> { + let mod_name = module.get_symbol_name(ctx); + let module_iw = iwctx.create_module(&mod_name); + + for op in module.get_body(ctx, 0).deref(ctx).iter(ctx) { + if let Some(func_op) = Operation::get_op(op, ctx).downcast_ref::() { + convert_function(ctx, iwctx, &module_iw, *func_op)?; + } + // TODO: Globals + } + + Ok(module_iw) +} diff --git a/pliron-llvm/src/types.rs b/pliron-llvm/src/types.rs index fdd16a0..bf0426d 100644 --- a/pliron-llvm/src/types.rs +++ b/pliron-llvm/src/types.rs @@ -35,7 +35,7 @@ use std::hash::Hash; #[def_type("llvm.struct")] #[derive(Debug)] pub struct StructType { - name: Option, + name: Option, fields: Option>>, } @@ -52,12 +52,12 @@ impl StructType { /// the named struct already exists and has its body set. pub fn get_named( ctx: &mut Context, - name: &str, + name: Identifier, fields: Option>>, ) -> Result> { let self_ptr = Type::register_instance( StructType { - name: Some(name.to_string()), + name: Some(name.clone()), // Uniquing happens only on the name, so this doesn't matter. fields: None, }, @@ -66,7 +66,7 @@ impl StructType { // Verify that we created a new or equivalent existing type. let mut self_ref = self_ptr.to_ptr().deref_mut(ctx); let self_ref = self_ref.downcast_mut::().unwrap(); - assert!(self_ref.name.as_ref().unwrap() == name); + assert!(self_ref.name.as_ref().unwrap() == &name); if let Some(fields) = fields { // We've been provided fields to be set. if let Some(existing_fields) = &self_ref.fields { @@ -95,10 +95,10 @@ impl StructType { } /// If a named struct already exists, get a pointer to it. - pub fn get_existing_named(ctx: &Context, name: &str) -> Option> { + pub fn get_existing_named(ctx: &Context, name: &Identifier) -> Option> { Type::get_instance( StructType { - name: Some(name.to_string()), + name: Some(name.clone()), // Named structs are uniqued only on the name. fields: None, }, @@ -127,14 +127,33 @@ impl StructType { self.name.is_some() } + /// Get this struct's name, if it has one. + pub fn name(&self) -> Option { + self.name.clone() + } + /// Get type of the idx'th field. pub fn field_type(&self, field_idx: usize) -> Ptr { - self.fields.as_ref().unwrap()[field_idx] + self.fields + .as_ref() + .expect("field_type shouldn't be called on opaque types")[field_idx] } /// Get the number of fields this struct has pub fn num_fields(&self) -> usize { - self.fields.as_ref().unwrap().len() + self.fields + .as_ref() + .expect("num_fields shouldn't be called on opaque types") + .len() + } + + /// Get an iterator over the fields of this struct + pub fn fields(&self) -> impl Iterator> + '_ { + self.fields + .as_ref() + .expect("fields shouldn't be called on opaque types") + .iter() + .cloned() } } @@ -170,7 +189,7 @@ impl Printable for StructType { thread_local! { // We use a vec instead of a HashMap hoping that this isn't // going to be large, in which case vec would be faster. - static IN_PRINTING: RefCell> = const { RefCell::new(vec![]) }; + static IN_PRINTING: RefCell> = const { RefCell::new(vec![]) }; } if let Some(name) = &self.name { let in_printing = IN_PRINTING.with(|f| f.borrow().contains(name)); @@ -256,7 +275,7 @@ impl Parsable for StructType { let (loc, name_opt, body_opt) = struct_parser.parse_stream(state_stream).into_result()?.0; let ctx = &mut state_stream.state.ctx; if let Some(name) = name_opt { - StructType::get_named(ctx, &name, body_opt) + StructType::get_named(ctx, name, body_opt) .map_err(|mut err| { err.set_loc(loc); err @@ -323,16 +342,16 @@ impl_verify_succ!(PointerType); #[derive(Hash, PartialEq, Eq, Debug)] pub struct ArrayType { elem: Ptr, - size: usize, + size: u64, } impl ArrayType { /// Get or create a new array type. - pub fn get(ctx: &mut Context, elem: Ptr, size: usize) -> TypePtr { + pub fn get(ctx: &mut Context, elem: Ptr, size: u64) -> TypePtr { Type::register_instance(ArrayType { elem, size }, ctx) } /// Get, if it already exists, an array type. - pub fn get_existing(ctx: &Context, elem: Ptr, size: usize) -> Option> { + pub fn get_existing(ctx: &Context, elem: Ptr, size: u64) -> Option> { Type::get_instance(ArrayType { elem, size }, ctx) } @@ -342,7 +361,7 @@ impl ArrayType { } /// Get array size. - pub fn size(&self) -> usize { + pub fn size(&self) -> u64 { self.size } } @@ -372,7 +391,7 @@ impl Parsable for ArrayType { combine::between( token('['), token(']'), - spaced((int_parser::(), spaced(token('x')), type_parser())), + spaced((int_parser::(), spaced(token('x')), type_parser())), ) .parse_stream(state_stream) .map(|(size, _, elem)| ArrayType::get(state_stream.state.ctx, elem, size)) @@ -494,6 +513,7 @@ mod tests { types::{IntegerType, Signedness}, }, context::{Context, Ptr}, + identifier::Identifier, impl_verify_succ, irfmt::parsers::{spaced, type_parser}, location, @@ -507,9 +527,12 @@ mod tests { fn test_struct() -> Result<()> { let mut ctx = Context::new(); let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signless).into(); + let linked_list_id = Identifier::try_new("LinkedList".into()).unwrap(); + let linked_list_2_id = Identifier::try_new("LinkedList2".into()).unwrap(); // Create an opaque struct since we want a recursive type. - let list_struct: Ptr = StructType::get_named(&mut ctx, "LinkedList", None)?.into(); + let list_struct: Ptr = + StructType::get_named(&mut ctx, linked_list_id.clone(), None)?.into(); assert!(list_struct .deref(&ctx) .downcast_ref::() @@ -518,18 +541,18 @@ mod tests { let list_struct_ptr = TypedPointerType::get(&mut ctx, list_struct).into(); let fields = vec![int64_ptr, list_struct_ptr]; // Set the struct body now. - StructType::get_named(&mut ctx, "LinkedList", Some(fields))?; + StructType::get_named(&mut ctx, linked_list_id.clone(), Some(fields))?; assert!(!list_struct .deref(&ctx) .downcast_ref::() .unwrap() .is_opaque()); - let list_struct_2 = StructType::get_existing_named(&ctx, "LinkedList") + let list_struct_2 = StructType::get_existing_named(&ctx, &linked_list_id) .unwrap() .into(); assert!(list_struct == list_struct_2); - assert!(StructType::get_existing_named(&ctx, "LinkedList2").is_none()); + assert!(StructType::get_existing_named(&ctx, &linked_list_2_id).is_none()); assert_eq!( list_struct.disp(&ctx).to_string(), diff --git a/src/basic_block.rs b/src/basic_block.rs index 5820b78..9778dcd 100644 --- a/src/basic_block.rs +++ b/src/basic_block.rs @@ -17,14 +17,14 @@ use crate::{ printers::{iter_with_sep, list_with_sep}, }, linked_list::{private, ContainsLinkedList, LinkedList}, - location::Located, + location::{Located, Location}, operation::Operation, parsable::{self, IntoParseResult, Parsable, ParseResult}, printable::{self, indented_nl, ListSeparator, Printable}, r#type::{TypeObj, Typed}, region::Region, result::Result, - use_def_lists::{DefNode, Value}, + use_def_lists::{DefNode, Use, Value}, vec_exns::VecExtns, }; @@ -105,6 +105,7 @@ pub struct BasicBlock { region_links: RegionLinks, /// A dictionary of attributes. pub attributes: AttributeDict, + loc: Location, } impl Named for BasicBlock { @@ -131,6 +132,7 @@ impl BasicBlock { preds: DefNode::new(), region_links: RegionLinks::default(), attributes: AttributeDict::default(), + loc: Location::Unknown, }; let newblock = Self::alloc(ctx, f); // Let's update the args of the new block. Easier to do it here than during creation. @@ -154,6 +156,11 @@ impl BasicBlock { self.args.get(arg_idx).map(|arg| arg.into()) } + /// Get an iterator over the arguments + pub fn arguments(&self) -> impl Iterator + '_ { + self.args.iter().map(Into::into) + } + /// Add a new argument with specified type. Returns idx at which it was added. pub fn add_argument(&mut self, ty: Ptr) -> usize { self.args.push_back_with(|arg_idx| BlockArgument { @@ -189,6 +196,63 @@ impl BasicBlock { self.preds.num_uses() } + /// Get all predecessors of this block. + pub fn preds(&self, ctx: &Context) -> Vec> { + self.preds + .uses() + .map(|r#use| { + r#use + .op + .deref(ctx) + .get_container() + .expect("Terminator branching to this block is not in any basic block") + }) + .collect() + } + + /// Checks whether self is a successor of `pred`. + /// O(n) in the number of successors of `pred`. + pub fn is_succ_of(&self, ctx: &Context, pred: Ptr) -> bool { + let self_ptr = self.get_self_ptr(ctx); + // We do not check [Self::get_defnode_ref].uses here because + // we'd have to go through them all. We do not have a Use<_> + // object to directly check membership. + pred.deref(ctx).get_tail().map_or(false, |pred_term| { + pred_term + .deref(ctx) + .successors() + .any(|succ| self_ptr == succ) + }) + } + + /// Retarget predecessors (that satisfy pred) to `other`. + pub fn retarget_some_preds_to) -> bool>( + &mut self, + ctx: &Context, + pred: P, + other: Ptr, + ) { + let predicate = |ctx: &Context, r#use: &Use>| { + let pred_block = r#use + .op + .deref(ctx) + .get_container() + .expect("Predecessor block must be in a Region"); + pred(ctx, pred_block) + }; + + self.preds.replace_some_uses_with(ctx, predicate, &other); + } + + /// Get all successors of this block. + pub fn succs(&self, ctx: &Context) -> Vec> { + self.get_tail() + .expect("A well formed BasicBlock must have a terminator") + .deref(ctx) + .successors() + .collect() + } + /// Drop all uses that this block holds. pub fn drop_all_uses(ptr: Ptr, ctx: &Context) { let ops: Vec<_> = ptr.deref(ctx).iter(ctx).collect(); @@ -216,6 +280,16 @@ impl BasicBlock { } } +impl Located for BasicBlock { + fn loc(&self) -> Location { + self.loc.clone() + } + + fn set_loc(&mut self, loc: Location) { + self.loc = loc; + } +} + impl private::ContainsLinkedList for BasicBlock { fn set_head(&mut self, head: Option>) { self.ops_list.first = head; diff --git a/src/builtin/ops.rs b/src/builtin/ops.rs index 11ae008..0984706 100644 --- a/src/builtin/ops.rs +++ b/src/builtin/ops.rs @@ -116,11 +116,6 @@ impl ModuleOp { opop } - - /// Add an [Operation] into this module. - pub fn add_operation(&self, ctx: &mut Context, op: Ptr) { - self.append_operation(ctx, op, 0) - } } impl_op_interface!(OneRegionInterface for ModuleOp {}); diff --git a/src/lib.rs b/src/lib.rs index 799c215..445eec2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ pub mod trait_cast; pub mod r#type; pub mod uniqued_any; pub mod use_def_lists; +pub mod utils; pub mod vec_exns; pub mod walkers; diff --git a/src/parsable.rs b/src/parsable.rs index 950eeb8..a21cde2 100644 --- a/src/parsable.rs +++ b/src/parsable.rs @@ -350,7 +350,8 @@ impl NameTracker { match scope.entry(id.0.clone()) { Entry::Occupied(mut occ) => match occ.get_mut() { LabelRef::ForwardRef(fref) => { - fref.retarget_some_preds_to(ctx, |_, _| true, block); + fref.deref_mut(ctx) + .retarget_some_preds_to(ctx, |_, _| true, block); BasicBlock::erase(*fref, ctx); occ.insert(LabelRef::Defined(block)); } diff --git a/src/use_def_lists.rs b/src/use_def_lists.rs index 2c98094..7bb7b2d 100644 --- a/src/use_def_lists.rs +++ b/src/use_def_lists.rs @@ -19,7 +19,6 @@ use crate::{ basic_block::BasicBlock, common_traits::Named, context::{Context, Ptr}, - linked_list::{ContainsLinkedList, LinkedList}, operation::Operation, printable::Printable, r#type::{TypeObj, Typed}, @@ -60,8 +59,8 @@ impl DefNode { } /// Get a reference to all [Use]es. - pub(crate) fn get_uses(&self) -> Vec> { - self.uses.iter().cloned().collect() + pub(crate) fn uses(&self) -> impl Iterator> + '_ { + self.uses.iter().cloned() } /// This definition has a new use. Track it and return a corresponding [Use]. @@ -84,7 +83,7 @@ impl DefNode { pub(crate) fn replace_some_uses_with) -> bool>( &mut self, ctx: &Context, - pred: P, + predicate: P, other: &T, ) where T: DefTrait + UseTrait, @@ -92,12 +91,12 @@ impl DefNode { if std::ptr::eq(self, &*other.get_defnode_ref(ctx)) { return; } - for r#use in self.uses.iter().filter(|r#use| pred(ctx, r#use)) { + for r#use in self.uses.iter().filter(|r#use| predicate(ctx, r#use)) { let mut use_mut = T::get_usenode_mut(r#use, ctx); *use_mut = other.get_defnode_mut(ctx).add_use(*other, *r#use); } // self will no longer have these uses. - self.uses.retain(|r#use| !pred(ctx, r#use)); + self.uses.retain(|r#use| !predicate(ctx, r#use)); } } @@ -136,12 +135,12 @@ impl Value { /// Get all uses of this value. pub fn get_uses(&self, ctx: &Context) -> Vec> { - self.get_defnode_ref(ctx).get_uses() + self.get_defnode_ref(ctx).uses().collect() } /// Does this definition have any [Use]? pub fn has_use(&self, ctx: &Context) -> bool { - !self.get_defnode_ref(ctx).has_use() + self.get_defnode_ref(ctx).has_use() } /// Replace uses of the underlying definition, that satisfy `pred`, with `other`. @@ -260,58 +259,6 @@ impl Named for Ptr { } } -impl Ptr { - /// How many predecessors does this block have? - pub fn num_preds(&self, ctx: &Context) -> usize { - self.get_defnode_ref(ctx).num_uses() - } - - /// Get all predecessors of this value. - pub fn get_preds(&self, ctx: &Context) -> Vec> { - self.get_defnode_ref(ctx) - .get_uses() - .iter() - .map(|r#use| r#use.op.deref(ctx).get_successor(r#use.opd_idx).unwrap()) - .collect() - } - - /// Does this [BasicBlock] have any predecessor? - pub fn has_pred(&self, ctx: &Context) -> bool { - !self.get_defnode_ref(ctx).has_use() - } - - /// Checks whether self is a successor of `pred`. - /// O(n) in the number of successors of `pred`. - pub fn is_succ_of(&self, ctx: &Context, pred: Ptr) -> bool { - // We do not check [Self::get_defnode_ref].uses here because - // we'd have to go through them all. We do not have a Use<_> - // object to directly check membership. - pred.deref(ctx).get_tail().map_or(false, |pred_term| { - pred_term.deref(ctx).successors().any(|succ| self == &succ) - }) - } - - /// Retarget predecessors (that satisfy pred) to `other`. - pub fn retarget_some_preds_to) -> bool>( - &self, - ctx: &Context, - pred: P, - other: Ptr, - ) { - let pred = |ctx: &Context, r#use: &Use>| { - let pred_block = r#use - .op - .deref(ctx) - .get_container() - .expect("Predecessor block must be in a Region"); - pred(ctx, pred_block) - }; - - self.get_defnode_mut(ctx) - .replace_some_uses_with(ctx, pred, &other); - } -} - impl DefTrait for Ptr { fn get_defnode_ref<'a>(&self, ctx: &'a Context) -> Ref<'a, DefNode> { let block = self.deref(ctx); diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..1b8073d --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod traversals; diff --git a/src/utils/traversals.rs b/src/utils/traversals.rs new file mode 100644 index 0000000..20884cd --- /dev/null +++ b/src/utils/traversals.rs @@ -0,0 +1,51 @@ +//! Utility functions for various graph traversals + +/// Region traversal utilities +pub mod region { + use rustc_hash::FxHashSet; + + use crate::{ + basic_block::BasicBlock, + context::{Context, Ptr}, + linked_list::ContainsLinkedList, + region::Region, + }; + + /// Compute post-order of the blocks in a region. + pub fn post_order(ctx: &Context, reg: Ptr) -> Vec> { + let on_stack = &mut FxHashSet::>::default(); + let mut po = Vec::>::new(); + + fn walk( + ctx: &Context, + block: Ptr, + on_stack: &mut FxHashSet>, + po: &mut Vec>, + ) { + if !on_stack.insert(block) { + // block already visited. + return; + } + // Visit successors before visiting self. + for succ in block.deref(ctx).succs(ctx) { + walk(ctx, succ, on_stack, po); + } + // Visit self. + po.push(block); + } + + // Walk every block (not just entry) since we may have unreachable blocks. + for block in reg.deref(ctx).iter(ctx) { + walk(ctx, block, on_stack, &mut po); + } + + po + } + + /// Compute reverse-post-order of the blocks in a region. + pub fn topological_order(ctx: &Context, reg: Ptr) -> Vec> { + let mut po = post_order(ctx, reg); + po.reverse(); + po + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d627996..902744a 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,8 +5,8 @@ use pliron::{ self, attributes::IntegerAttr, op_interfaces::{ - IsTerminatorInterface, OneResultInterface, OneResultVerifyErr, ZeroOpdInterface, - ZeroResultVerifyErr, + IsTerminatorInterface, OneResultInterface, OneResultVerifyErr, + SingleBlockRegionInterface, ZeroOpdInterface, ZeroResultVerifyErr, }, ops::{FuncOp, ModuleOp}, types::{FunctionType, IntegerType, Signedness}, @@ -180,7 +180,7 @@ pub fn const_ret_in_mod(ctx: &mut Context) -> Result<(ModuleOp, FuncOp, Constant // Our function is going to have type () -> (). let func_ty = FunctionType::get(ctx, vec![], vec![i64_ty.into()]); let func = FuncOp::new(ctx, "foo", func_ty); - module.add_operation(ctx, func.get_operation()); + module.append_operation(ctx, func.get_operation(), 0); let bb = func.get_entry_block(ctx); // Create a `const 0` op and add it to bb.