From 9c14df9a7ed05826922a957a100a5f54b166894b Mon Sep 17 00:00:00 2001 From: Vaivaswatha Nagaraj Date: Fri, 21 Jun 2024 18:56:42 +0530 Subject: [PATCH] Use `Identifier` as key to `AttributeDict` Also move `ConstantOp` to the LLVM, and test dialects separately. --- Cargo.toml | 1 - pliron-llvm/Cargo.toml | 4 +- pliron-llvm/src/bin/llvm-opt.rs | 54 +++++++++- pliron-llvm/src/from_inkwell.rs | 58 ++++++----- pliron-llvm/src/op_interfaces.rs | 10 +- pliron-llvm/src/ops.rs | 167 +++++++++++++++++++++++++------ src/attribute.rs | 99 ++++++++++++++++-- src/builtin/attributes.rs | 120 ++++++++++------------ src/builtin/mod.rs | 4 +- src/builtin/op_interfaces.rs | 17 ++-- src/builtin/ops.rs | 138 ++++--------------------- src/debug_info.rs | 63 +++++++++--- src/identifier.rs | 17 +++- src/op.rs | 91 +++++++++-------- tests/common/mod.rs | 98 ++++++++++++++++-- tests/ir_construct.rs | 57 ++++------- 16 files changed, 629 insertions(+), 369 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d9bb932..a289c71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,6 @@ downcast-rs = "1.2.1" rustc-hash.workspace = true thiserror.workspace = true apint = "0.2.0" -sorted_vector_map = "0.1.0" linkme.workspace = true once_cell = "1.19.0" paste = "1.0" diff --git a/pliron-llvm/Cargo.toml b/pliron-llvm/Cargo.toml index 126bde7..96c2879 100644 --- a/pliron-llvm/Cargo.toml +++ b/pliron-llvm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pliron-llvm" -description = "Derive macros for pliron" +description = "LLVM dialect for pliron" version.workspace = true edition.workspace = true repository.workspace = true @@ -14,7 +14,7 @@ license.workspace = true [dependencies] pliron-derive = { path = "../pliron-derive", version = "0" } pliron = { path = "../", version = "0" } -clap = "4.5" +clap = { version = "4.5", features = ["derive"] } combine.workspace = true thiserror.workspace = true linkme.workspace = true diff --git a/pliron-llvm/src/bin/llvm-opt.rs b/pliron-llvm/src/bin/llvm-opt.rs index cb9056f..35b545a 100644 --- a/pliron-llvm/src/bin/llvm-opt.rs +++ b/pliron-llvm/src/bin/llvm-opt.rs @@ -1,5 +1,53 @@ -use pliron::result::Result; +use std::{path::PathBuf, process::ExitCode}; -pub fn main() -> Result<()> { - todo!() +use inkwell::{context::Context as IWContext, module::Module as IWModule}; + +use clap::Parser; +use pliron::{arg_error_noloc, context::Context, printable::Printable, result::Result}; +use pliron_llvm::from_inkwell; + +#[derive(Parser)] +#[command(version, about="LLVM Optimizer", long_about = None)] +struct Cli { + /// Input LLVM file + #[arg(short, value_name = "FILE")] + input: PathBuf, + + /// Output LLVM file + #[arg(short, value_name = "FILE")] + output: PathBuf, + + /// Emit text LLVM-IR + #[arg(short = 'S', default_value_t = false)] + text_output: bool, +} + +fn run(cli: Cli, ctx: &mut Context) -> Result<()> { + let context = IWContext::create(); + let module = IWModule::parse_bitcode_from_path(cli.input, &context) + .map_err(|err| arg_error_noloc!("{}", err))?; + + let pliron_module = from_inkwell::convert_module(ctx, &module)?; + println!("{}", pliron_module.disp(ctx)); + + if cli.text_output { + module + .print_to_file(&cli.output) + .map_err(|err| arg_error_noloc!("{}", err.to_string()))?; + } else { + module.write_bitcode_to_path(&cli.output); + } + Ok(()) +} + +pub fn main() -> ExitCode { + let cli = Cli::parse(); + let ctx = &mut Context::default(); + match run(cli, ctx) { + Ok(_) => ExitCode::SUCCESS, + Err(e) => { + eprintln!("{}", e.disp(ctx)); + ExitCode::FAILURE + } + } } diff --git a/pliron-llvm/src/from_inkwell.rs b/pliron-llvm/src/from_inkwell.rs index 9cd7f2a..433e897 100644 --- a/pliron-llvm/src/from_inkwell.rs +++ b/pliron-llvm/src/from_inkwell.rs @@ -119,12 +119,23 @@ pub fn convert_ipredicate(ipred: IntPredicate) -> ICmpPredicateAttr { } /// Mapping from inkwell entities to pliron entities. -#[derive(Default)] -struct ConversionMaps<'ctx> { +struct ConversionContext<'ctx> { // A map from inkwell's Values to pliron's Values. value_map: FxHashMap, Value>, // A map from inkwell's basic blocks to plirons'. block_map: FxHashMap, Ptr>, + // Entry block of the function we're processing. + _entry_block: Ptr, +} + +impl<'ctx> ConversionContext<'ctx> { + fn new(entry_block: Ptr) -> Self { + Self { + value_map: FxHashMap::default(), + block_map: FxHashMap::default(), + _entry_block: entry_block, + } + } } /// Get the successors of an inkwell block. @@ -201,22 +212,23 @@ pub enum ConversionErr { } fn convert_operands( - cmap: &ConversionMaps, + cctx: &ConversionContext, inst: InstructionValue, ) -> Result<(Vec, Vec>)> { let mut opds = vec![]; let mut succs = vec![]; for opd in inst.get_operands().flatten() { if let Some(val) = opd.left() { - let Some(m_val) = cmap.value_map.get(&val.as_any_value_enum()) else { + if let Some(m_val) = cctx.value_map.get(&val.as_any_value_enum()) { + opds.push(*m_val); + } else { return input_err_noloc!(ConversionErr::UndefinedValue( val.as_any_value_enum().print_to_string().to_string() )); - }; - opds.push(*m_val); + } } else { let block = opd.right().unwrap(); - let Some(m_block) = cmap.block_map.get(&block) else { + let Some(m_block) = cctx.block_map.get(&block) else { return input_err_noloc!(ConversionErr::UndefinedBlock( block.get_name().to_str_res().unwrap().to_string() )); @@ -235,7 +247,7 @@ fn get_operand(opds: &[T], idx: usize) -> Result { /// Compute the arguments to be passed when branching from `src` to `dest`. fn convert_branch_args( - cmap: &ConversionMaps, + cctx: &ConversionContext, src_block: IWBasicBlock, dst_block: IWBasicBlock, ) -> Result> { @@ -249,7 +261,7 @@ fn convert_branch_args( src_block.get_name().to_str_res().unwrap().to_string() )); }; - let Some(m_incoming_val) = cmap.value_map.get(&incoming_val.as_any_value_enum()) else { + let Some(m_incoming_val) = cctx.value_map.get(&incoming_val.as_any_value_enum()) else { return input_err_noloc!(ConversionErr::UndefinedValue( incoming_val .as_any_value_enum() @@ -268,10 +280,10 @@ fn convert_branch_args( fn convert_instruction( ctx: &mut Context, - cmap: &ConversionMaps, + cctx: &ConversionContext, inst: InstructionValue, ) -> Result> { - let (ref opds, ref succs) = convert_operands(cmap, inst)?; + let (ref opds, ref succs) = convert_operands(cctx, inst)?; match inst.get_opcode() { InstructionOpcode::Add => { let (lhs, rhs) = (get_operand(opds, 0)?, get_operand(opds, 1)?); @@ -306,12 +318,12 @@ fn convert_instruction( "Conditional branch must have two successors" ); let true_dest_opds = convert_branch_args( - cmap, + cctx, inst.get_parent().unwrap(), inst.get_operand(1).unwrap().unwrap_right(), )?; let false_dest_opds = convert_branch_args( - cmap, + cctx, inst.get_parent().unwrap(), inst.get_operand(2).unwrap().unwrap_right(), )?; @@ -326,7 +338,7 @@ fn convert_instruction( .get_operation()) } else { let dest_opds = convert_branch_args( - cmap, + cctx, inst.get_parent().unwrap(), inst.get_operand(0).unwrap().unwrap_right(), )?; @@ -444,7 +456,7 @@ fn convert_instruction( // Convert inkwell `block` to pliron's `m_block`. fn convert_block<'ctx>( ctx: &mut Context, - cmap: &mut ConversionMaps<'ctx>, + cctx: &mut ConversionContext<'ctx>, block: IWBasicBlock<'ctx>, m_block: Ptr, ) -> Result<()> { @@ -453,14 +465,14 @@ fn convert_block<'ctx>( if inst_val.is_phi_value() { let ty = convert_type(ctx, &inst.get_type().as_any_type_enum())?; let arg_idx = m_block.deref_mut(ctx).add_argument(ty); - cmap.value_map + cctx.value_map .insert(inst_val, m_block.deref(ctx).get_argument(arg_idx).unwrap()); } else { - let m_inst = convert_instruction(ctx, cmap, inst)?; + let m_inst = convert_instruction(ctx, cctx, inst)?; m_inst.insert_at_back(m_block, ctx); // LLVM instructions have at most one result. if let Some(res) = m_inst.deref(ctx).get_result(0) { - cmap.value_map.insert(inst_val, res); + cctx.value_map.insert(inst_val, res); } } } @@ -475,7 +487,7 @@ fn convert_function(ctx: &mut Context, function: FunctionValue) -> Result Result = + pliron::Lazy::new(|| "llvm_integer_overflow_flags".try_into().unwrap()); #[derive(Error, Debug)] #[error("IntegerOverflowFlag missing on Op")] @@ -103,7 +105,7 @@ decl_op_interface! { self.get_operation() .deref(ctx) .attributes - .get::(ATTR_KEY_INTEGER_OVERFLOW_FLAGS) + .get::(&ATTR_KEY_INTEGER_OVERFLOW_FLAGS) .expect("Integer overflow flag missing or is of incorrect type") .clone() } @@ -116,7 +118,7 @@ decl_op_interface! { self.get_operation() .deref_mut(ctx) .attributes - .set(ATTR_KEY_INTEGER_OVERFLOW_FLAGS, flag); + .set(ATTR_KEY_INTEGER_OVERFLOW_FLAGS.clone(), flag); } fn verify(op: &dyn Op, ctx: &Context) -> Result<()> @@ -125,7 +127,7 @@ decl_op_interface! { { let op = op.get_operation().deref(ctx); if op.attributes.get:: - (ATTR_KEY_INTEGER_OVERFLOW_FLAGS).is_none() + (&ATTR_KEY_INTEGER_OVERFLOW_FLAGS).is_none() { return verify_err!(op.loc(), IntBinArithOpWithOverflowFlagErr); } diff --git a/pliron-llvm/src/ops.rs b/pliron-llvm/src/ops.rs index 68cd8b8..9a353fb 100644 --- a/pliron-llvm/src/ops.rs +++ b/pliron-llvm/src/ops.rs @@ -2,19 +2,21 @@ use pliron::{ arg_err_noloc, + attribute::{attr_cast, AttrObj}, basic_block::BasicBlock, builtin::{ attr_interfaces::TypedAttrInterface, - attributes::{StringAttr, TypeAttr}, + attributes::{FloatAttr, IntegerAttr, StringAttr, TypeAttr}, op_interfaces::{ BranchOpInterface, CallOpCallable, CallOpInterface, IsTerminatorInterface, OneOpdInterface, OneResultInterface, SameOperandsAndResultType, SameOperandsType, - SameResultsType, ZeroResultInterface, ATTR_KEY_CALLEE_TYPE, + SameResultsType, ZeroOpdInterface, ZeroResultInterface, ATTR_KEY_CALLEE_TYPE, }, types::{FunctionType, IntegerType, Signedness}, }, common_traits::Verify, context::{Context, Ptr}, + identifier::Identifier, impl_canonical_syntax, impl_op_interface, impl_verify_succ, location::Located, op::Op, @@ -212,12 +214,18 @@ pub enum ICmpOpVerifyErr { /// /// | key | value | via Interface | /// |-----|-------| --------------| -/// | [ATTR_KEY_PREDICATE](ICmpOp::ATTR_KEY_PREDICATE) | [ICmpPredicateAttr](ICmpPredicateAttr) | N/A | +/// | [ATTR_KEY_PREDICATE](icmp_op::ATTR_KEY_PREDICATE) | [ICmpPredicateAttr](ICmpPredicateAttr) | N/A | #[def_op("llvm.icmp")] pub struct ICmpOp {} +pub mod icmp_op { + use super::*; + + pub static ATTR_KEY_PREDICATE: pliron::Lazy = + pliron::Lazy::new(|| "llvm_icmp_predicate".try_into().unwrap()); +} + impl ICmpOp { - pub const ATTR_KEY_PREDICATE: &'static str = "llvm.icmp_predicate"; /// Create a new [ICmpOp] pub fn new(ctx: &mut Context, pred: ICmpPredicateAttr, lhs: Value, rhs: Value) -> Self { let bool_ty = IntegerType::get(ctx, 1, Signedness::Signless); @@ -231,7 +239,7 @@ impl ICmpOp { ); op.deref_mut(ctx) .attributes - .set(Self::ATTR_KEY_PREDICATE, pred); + .set(icmp_op::ATTR_KEY_PREDICATE.clone(), pred); ICmpOp { op } } } @@ -243,7 +251,7 @@ impl Verify for ICmpOp { if op .attributes - .get::(Self::ATTR_KEY_PREDICATE) + .get::(&icmp_op::ATTR_KEY_PREDICATE) .is_none() { verify_err!(op.loc(), ICmpOpVerifyErr::PredAttrErr)? @@ -295,7 +303,7 @@ pub enum AllocaOpVerifyErr { /// /// | key | value | via Interface | /// |-----|-------| --------------| -/// | [ATTR_KEY_ELEM_TYPE](AllocaOp::ATTR_KEY_ELEM_TYPE) | [TypeAttr](pliron::builtin::attributes::TypeAttr) | N/A | +/// | [ATTR_KEY_ELEM_TYPE](alloca_op::ATTR_KEY_ELEM_TYPE) | [TypeAttr](pliron::builtin::attributes::TypeAttr) | N/A | #[def_op("llvm.alloca")] pub struct AllocaOp {} impl_canonical_syntax!(AllocaOp); @@ -310,7 +318,7 @@ impl Verify for AllocaOp { // Ensure correctness of element type. if op .attributes - .get::(Self::ATTR_KEY_ELEM_TYPE) + .get::(&alloca_op::ATTR_KEY_ELEM_TYPE) .is_none() { verify_err!(op.loc(), AllocaOpVerifyErr::ElemTypeAttr)? @@ -326,15 +334,19 @@ impl_op_interface!(PointerTypeResult for AllocaOp { self.op .deref(ctx) .attributes - .get::(Self::ATTR_KEY_ELEM_TYPE) + .get::(&alloca_op::ATTR_KEY_ELEM_TYPE) .expect("AllocaOp missing or incorrect type for elem_type attribute") .get_type() } }); -impl AllocaOp { - pub const ATTR_KEY_ELEM_TYPE: &'static str = "llvm.element_type"; +pub mod alloca_op { + use super::*; + pub static ATTR_KEY_ELEM_TYPE: pliron::Lazy = + pliron::Lazy::new(|| "llvm_alloca_element_type".try_into().unwrap()); +} +impl AllocaOp { /// Create a new [AllocaOp] pub fn new(ctx: &mut Context, elem_type: Ptr, size: Value) -> Self { let ptr_ty = PointerType::get(ctx).into(); @@ -346,9 +358,10 @@ impl AllocaOp { vec![], 0, ); - op.deref_mut(ctx) - .attributes - .set(Self::ATTR_KEY_ELEM_TYPE, TypeAttr::new(elem_type)); + op.deref_mut(ctx).attributes.set( + alloca_op::ATTR_KEY_ELEM_TYPE.clone(), + TypeAttr::new(elem_type), + ); AllocaOp { op } } } @@ -503,8 +516,8 @@ pub enum GetElementPtrOpErr { /// /// | key | value | via Interface | /// |-----|-------| --------------| -/// | [ATTR_KEY_INDICES](GetElementPtrOp::ATTR_KEY_INDICES) | [GepIndicesAttr](super::attributes::GepIndicesAttr)> | N/A | -/// | [ATTR_KEY_SRC_ELEM_TYPE](GetElementPtrOp::ATTR_KEY_SRC_ELEM_TYPE) | [TypeAttr] | N/A | +/// | [ATTR_KEY_INDICES](gep_op::ATTR_KEY_INDICES) | [GepIndicesAttr](super::attributes::GepIndicesAttr)> | N/A | +/// | [ATTR_KEY_SRC_ELEM_TYPE](gep_op::ATTR_KEY_SRC_ELEM_TYPE) | [TypeAttr] | N/A | /// /// ### Operands /// | operand | description | @@ -532,7 +545,7 @@ impl Verify for GetElementPtrOp { // Ensure that we have the indices as an attribute. if op .attributes - .get::(Self::ATTR_KEY_INDICES) + .get::(&gep_op::ATTR_KEY_INDICES) .is_none() { verify_err!(op.loc(), GetElementPtrOpErr::IndicesAttrErr)? @@ -552,10 +565,17 @@ impl Verify for GetElementPtrOp { } } -impl GetElementPtrOp { +pub mod gep_op { + use super::*; /// [Attribute](pliron::attribute::Attribute) to get the indices vector. - pub const ATTR_KEY_INDICES: &'static str = "llvm.gep_indices"; - pub const ATTR_KEY_SRC_ELEM_TYPE: &'static str = "llvm.gep_src_elem_type"; + pub static ATTR_KEY_INDICES: pliron::Lazy = + pliron::Lazy::new(|| "llvm_gep_indices".try_into().unwrap()); + /// [Attribute](pliron::attribute::Attribute) to get the source element type. + pub static ATTR_KEY_SRC_ELEM_TYPE: pliron::Lazy = + pliron::Lazy::new(|| "llvm_gep_src_elem_type".try_into().unwrap()); +} + +impl GetElementPtrOp { /// Create a new [GetElementPtrOp] pub fn new( ctx: &mut Context, @@ -578,10 +598,10 @@ impl GetElementPtrOp { let op = Operation::new(ctx, Self::get_opid_static(), vec![], opds, vec![], 0); op.deref_mut(ctx) .attributes - .set(Self::ATTR_KEY_INDICES, GepIndicesAttr(attr)); + .set(gep_op::ATTR_KEY_INDICES.clone(), GepIndicesAttr(attr)); op.deref_mut(ctx) .attributes - .set(Self::ATTR_KEY_SRC_ELEM_TYPE, elem_type); + .set(gep_op::ATTR_KEY_SRC_ELEM_TYPE.clone(), elem_type); GetElementPtrOp { op } } @@ -590,7 +610,7 @@ impl GetElementPtrOp { self.op .deref(ctx) .attributes - .get::(Self::ATTR_KEY_SRC_ELEM_TYPE) + .get::(&gep_op::ATTR_KEY_SRC_ELEM_TYPE) .expect("GetElementPtrOp missing or has incorrect src_elem_type attribute type") .get_type() } @@ -604,7 +624,7 @@ impl GetElementPtrOp { pub fn indices(&self, ctx: &Context) -> Vec { let op = &*self.op.deref(ctx); op.attributes - .get::(Self::ATTR_KEY_INDICES) + .get::(&gep_op::ATTR_KEY_INDICES) .unwrap() .0 .iter() @@ -786,14 +806,20 @@ impl_op_interface!(ZeroResultInterface for LoadOp {}); /// ### Attributes: /// | key | value | via Interface | /// |-----|-------| --------------| -/// | [ATTR_KEY_CALLEE](Self::ATTR_KEY_CALLEE) | [StringAttr] | N/A | +/// | [ATTR_KEY_CALLEE](call_op::ATTR_KEY_CALLEE) | [StringAttr] | N/A | /// | [ATTR_KEY_CALLEE_TYPE](pliron::builtin::op_interfaces::ATTR_KEY_CALLEE_TYPE) | [TypeAttr] | [CallOpInterface] | /// #[def_op("llvm.call")] pub struct CallOp {} +pub mod call_op { + use super::*; + pub static ATTR_KEY_CALLEE: pliron::Lazy = + pliron::Lazy::new(|| "llvm_call_callee".try_into().unwrap()); +} + impl CallOp { - pub const ATTR_KEY_CALLEE: &'static str = "llvm.callee"; + /// Get a new [CallOp]. pub fn new( ctx: &mut Context, callee: CallOpCallable, @@ -807,7 +833,7 @@ impl CallOp { Operation::new(ctx, Self::get_opid_static(), vec![res_ty], args, vec![], 0); op.deref_mut(ctx) .attributes - .set(Self::ATTR_KEY_CALLEE, StringAttr::new(cval)); + .set(call_op::ATTR_KEY_CALLEE.clone(), StringAttr::new(cval)); op } CallOpCallable::Indirect(csym) => { @@ -815,16 +841,17 @@ impl CallOp { Operation::new(ctx, Self::get_opid_static(), vec![res_ty], args, vec![], 0) } }; - op.deref_mut(ctx) - .attributes - .set(ATTR_KEY_CALLEE_TYPE, TypeAttr::new(callee_ty.into())); + op.deref_mut(ctx).attributes.set( + ATTR_KEY_CALLEE_TYPE.clone(), + TypeAttr::new(callee_ty.into()), + ); CallOp { op } } } impl CallOpInterface for CallOp { fn callee(&self, ctx: &Context) -> CallOpCallable { let op = self.op.deref(ctx); - if let Some(callee_sym) = op.attributes.get::(Self::ATTR_KEY_CALLEE) { + if let Some(callee_sym) = op.attributes.get::(&call_op::ATTR_KEY_CALLEE) { CallOpCallable::Direct(callee_sym.clone().into()) } else { CallOpCallable::Indirect( @@ -837,7 +864,7 @@ impl CallOpInterface for CallOp { let op = self.op.deref(ctx); let skip = if op .attributes - .get::(Self::ATTR_KEY_CALLEE) + .get::(&call_op::ATTR_KEY_CALLEE) .is_some() { 0 @@ -850,6 +877,81 @@ impl CallOpInterface for CallOp { impl_canonical_syntax!(CallOp); impl_verify_succ!(CallOp); +/// Numeric constant. +/// See MLIR's [llvm.mlir.constant](https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirconstant-llvmconstantop). +/// +/// Attributes: +/// +/// | key | value | +/// |-----|-------| +/// |[ATTR_KEY_VALUE](constant_op::ATTR_KEY_VALUE) | [IntegerAttr] or [FloatAttr] | +/// +/// Results: +/// +/// | result | description | +/// |-----|-------| +/// | `result` | any type | +#[def_op("llvm.constant")] +pub struct ConstantOp {} + +pub mod constant_op { + use super::*; + /// Attribute key for the constant value. + pub static ATTR_KEY_VALUE: pliron::Lazy = + pliron::Lazy::new(|| "llvm_constant_value".try_into().unwrap()); +} + +impl ConstantOp { + /// Get the constant value that this Op defines. + pub fn get_value(&self, ctx: &Context) -> AttrObj { + let op = self.get_operation().deref(ctx); + op.attributes + .0 + .get(&constant_op::ATTR_KEY_VALUE) + .unwrap() + .clone() + } + + /// Create a new [ConstantOp]. + pub fn new(ctx: &mut Context, value: AttrObj) -> Self { + let result_type = attr_cast::(&*value) + .expect("ConstantOp const value must provide TypedAttrInterface") + .get_type(); + let op = Operation::new( + ctx, + Self::get_opid_static(), + vec![result_type], + vec![], + vec![], + 0, + ); + op.deref_mut(ctx) + .attributes + .0 + .insert(constant_op::ATTR_KEY_VALUE.clone(), value); + ConstantOp { op } + } +} + +#[derive(Error, Debug)] +#[error("{}: Unexpected type", ConstantOp::get_opid_static())] +pub struct ConstantOpVerifyErr; + +impl Verify for ConstantOp { + fn verify(&self, ctx: &Context) -> Result<()> { + let loc = self.get_operation().deref(ctx).loc(); + let value = self.get_value(ctx); + if !(value.is::() || value.is::()) { + return verify_err!(loc, ConstantOpVerifyErr); + } + Ok(()) + } +} + +impl_canonical_syntax!(ConstantOp); +impl_op_interface! (ZeroOpdInterface for ConstantOp {}); +impl_op_interface! (OneResultInterface for ConstantOp {}); + /// Register ops in the LLVM dialect. pub fn register(ctx: &mut Context) { AddOp::register(ctx, AddOp::parser_fn); @@ -874,5 +976,6 @@ pub fn register(ctx: &mut Context) { LoadOp::register(ctx, LoadOp::parser_fn); StoreOp::register(ctx, StoreOp::parser_fn); CallOp::register(ctx, CallOp::parser_fn); + ConstantOp::register(ctx, ConstantOp::parser_fn); ReturnOp::register(ctx, ReturnOp::parser_fn); } diff --git a/src/attribute.rs b/src/attribute.rs index d5271df..ea48374 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -34,7 +34,7 @@ use std::{ ops::Deref, }; -use combine::{parser, Parser}; +use combine::{between, parser, token, Parser}; use downcast_rs::{impl_downcast, Downcast}; use dyn_clone::DynClone; use linkme::distributed_slice; @@ -46,35 +46,118 @@ use crate::{ dialect::DialectName, identifier::Identifier, input_err, - irfmt::parsers::spaced, + irfmt::{ + parsers::{attr_parser, delimited_list_parser, spaced}, + printers::iter_with_sep, + }, location::Located, parsable::{Parsable, ParseResult, ParserFn, StateStream}, printable::{self, Printable}, result::Result, }; +#[derive(Clone)] +struct AttributeDictKeyVal { + key: Identifier, + val: AttrObj, +} +impl Printable for AttributeDictKeyVal { + fn fmt( + &self, + ctx: &Context, + _state: &printable::State, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "({}: {})", self.key, self.val.disp(ctx)) + } +} + +impl Parsable for AttributeDictKeyVal { + type Arg = (); + + type Parsed = Self; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + _arg: Self::Arg, + ) -> ParseResult<'a, Self::Parsed> { + between( + token('('), + token(')'), + (Identifier::parser(()), spaced(token(':')), attr_parser()), + ) + .map(|(key, _, val)| AttributeDictKeyVal { key, val }) + .parse_stream(state_stream) + .into_result() + } +} + +impl Printable for AttributeDict { + fn fmt( + &self, + ctx: &Context, + _state: &printable::State, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!( + f, + "[{}]", + iter_with_sep( + self.0.iter().map(|(key, val)| AttributeDictKeyVal { + key: key.clone(), + val: val.clone() + }), + printable::ListSeparator::CharSpace(','), + ) + .disp(ctx) + ) + } +} + +impl Parsable for AttributeDict { + type Arg = (); + type Parsed = Self; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + _arg: Self::Arg, + ) -> ParseResult<'a, Self::Parsed> { + delimited_list_parser('[', ']', ',', AttributeDictKeyVal::parser(())) + .map(|key_vals| { + AttributeDict( + key_vals + .into_iter() + .map(|key_val| (key_val.key, key_val.val)) + .collect(), + ) + }) + .parse_stream(state_stream) + .into_result() + } +} + /// A dictionary of attributes, mapping keys to attribute objects. -#[derive(Default)] -pub struct AttributeDict(pub FxHashMap<&'static str, AttrObj>); +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct AttributeDict(pub FxHashMap); impl AttributeDict { /// Get reference to attribute value that is mapped to key `k`. - pub fn get(&self, k: &'static str) -> Option<&T> { + pub fn get(&self, k: &Identifier) -> Option<&T> { self.0.get(k).and_then(|ao| ao.downcast_ref::()) } /// Get mutable reference to attribute value that is mapped to key `k`. - pub fn get_mut(&mut self, k: &'static str) -> Option<&mut T> { + pub fn get_mut(&mut self, k: &Identifier) -> Option<&mut T> { self.0.get_mut(k).and_then(|ao| ao.downcast_mut::()) } /// Reference to the attribute value (that is mapped to key `k`) as an interface reference. - pub fn get_as(&self, k: &'static str) -> Option<&T> { + pub fn get_as(&self, k: &Identifier) -> Option<&T> { self.0.get(k).and_then(|ao| attr_cast::(&**ao)) } /// Set the attribute value for key `k`. - pub fn set(&mut self, k: &'static str, v: T) { + pub fn set(&mut self, k: Identifier, v: T) { self.0.insert(k, Box::new(v)); } } diff --git a/src/builtin/attributes.rs b/src/builtin/attributes.rs index 1e45c82..cb7992c 100644 --- a/src/builtin/attributes.rs +++ b/src/builtin/attributes.rs @@ -1,17 +1,16 @@ use apint::ApInt; use combine::{ any, between, many, many1, none_of, - parser::char::{hex_digit, string}, + parser::char::{self, hex_digit, string}, token, Parser, }; use pliron_derive::def_attribute; -use sorted_vector_map::SortedVectorMap; -use thiserror::Error; use crate::{ - attribute::{AttrObj, Attribute}, + attribute::{AttrObj, Attribute, AttributeDict}, common_traits::Verify, context::{Context, Ptr}, + identifier::Identifier, impl_attr_interface, impl_verify_succ, input_err, irfmt::{ parsers::{delimited_list_parser, spaced, type_parser}, @@ -22,7 +21,6 @@ use crate::{ printable::{self, Printable}, r#type::{TypeObj, TypePtr, Typed}, result::Result, - verify_err_noloc, }; use super::{attr_interfaces::TypedAttrInterface, types::IntegerType}; @@ -242,45 +240,26 @@ impl Parsable for FloatAttr { } } -/// An attribute that is a small dictionary of other attributes. -/// Implemented as a key-sorted list of key value pairs. -/// Efficient only for small number of keys. +/// An attribute that is a dictionary of other attributes. /// Similar to MLIR's [DictionaryAttr](https://mlir.llvm.org/docs/Dialects/Builtin/#dictionaryattr), -#[def_attribute("builtin.small_dict")] -#[derive(PartialEq, Eq, Clone, Debug)] -pub struct SmallDictAttr(SortedVectorMap<&'static str, AttrObj>); +#[def_attribute("builtin.dict")] +#[derive(PartialEq, Clone, Eq, Debug)] +pub struct DictAttr(AttributeDict); -impl Printable for SmallDictAttr { +impl Printable for DictAttr { fn fmt( &self, - _ctx: &Context, + ctx: &Context, _state: &printable::State, - _f: &mut core::fmt::Formatter<'_>, + f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - todo!() + write!(f, "{}", self.0.disp(ctx)) } } -#[derive(Error, Debug)] -#[error("SmallDictAttr keys are not sorted")] -struct SmallDictAttrVerifyErr; -impl Verify for SmallDictAttr { - fn verify(&self, _ctx: &Context) -> Result<()> { - for (str1, str2) in self - .0 - .iter() - .map(|(&key, _)| key) - .zip(self.0.iter().skip(1).map(|(&key, _)| key)) - { - if str1 > str2 { - return verify_err_noloc!(SmallDictAttrVerifyErr); - } - } - Ok(()) - } -} +impl_verify_succ!(DictAttr); -impl Parsable for SmallDictAttr { +impl Parsable for DictAttr { type Arg = (); type Parsed = Self; @@ -292,34 +271,34 @@ impl Parsable for SmallDictAttr { } } -impl SmallDictAttr { - /// Create a new [SmallDictAttr]. - pub fn new(value: Vec<(&'static str, AttrObj)>) -> Self { - let mut dict = SortedVectorMap::with_capacity(value.len()); +impl DictAttr { + /// Create a new [DictAttr]. + pub fn new(value: Vec<(Identifier, AttrObj)>) -> Self { + let mut dict = AttributeDict::default(); for (name, val) in value { - dict.insert(name, val); + dict.0.insert(name, val); } - SmallDictAttr(dict) + DictAttr(dict) } /// Add an entry to the dictionary. - pub fn insert(&mut self, key: &'static str, val: AttrObj) { - self.0.insert(key, val); + pub fn insert(&mut self, key: &Identifier, val: AttrObj) { + self.0 .0.insert(key.clone(), val); } /// Remove an entry from the dictionary. - pub fn remove(&mut self, key: &'static str) { - self.0.remove(key); + pub fn remove(&mut self, key: &Identifier) { + self.0 .0.remove(key); } /// Lookup a name in the dictionary. - pub fn lookup<'a>(&'a self, key: &'static str) -> Option<&'a AttrObj> { - self.0.get(key) + pub fn lookup<'a>(&'a self, key: &Identifier) -> Option<&'a AttrObj> { + self.0 .0.get(key) } /// Lookup a name in the dictionary, get a mutable reference. - pub fn lookup_mut<'a>(&'a mut self, key: &'static str) -> Option<&'a mut AttrObj> { - self.0.get_mut(key) + pub fn lookup_mut<'a>(&'a mut self, key: &Identifier) -> Option<&'a mut AttrObj> { + self.0 .0.get_mut(key) } } @@ -472,7 +451,7 @@ impl_attr_interface!( pub fn register(ctx: &mut Context) { StringAttr::register_attr_in_dialect(ctx, StringAttr::parser_fn); IntegerAttr::register_attr_in_dialect(ctx, IntegerAttr::parser_fn); - SmallDictAttr::register_attr_in_dialect(ctx, SmallDictAttr::parser_fn); + DictAttr::register_attr_in_dialect(ctx, DictAttr::parser_fn); VecAttr::register_attr_in_dialect(ctx, VecAttr::parser_fn); UnitAttr::register_attr_in_dialect(ctx, UnitAttr::parser_fn); TypeAttr::register_attr_in_dialect(ctx, TypeAttr::parser_fn); @@ -492,13 +471,14 @@ mod tests { types::{IntegerType, Signedness}, }, context::Context, + identifier::Identifier, irfmt::parsers::attr_parser, location, parsable::{self, state_stream_from_iterator}, printable::Printable, }; - use super::{SmallDictAttr, TypeAttr, VecAttr}; + use super::{DictAttr, TypeAttr, VecAttr}; #[test] fn test_integer_attributes() { let mut ctx = Context::new(); @@ -599,31 +579,39 @@ mod tests { let hello_attr: AttrObj = StringAttr::new("hello".to_string()).into(); let world_attr: AttrObj = StringAttr::new("world".to_string()).into(); - let mut dict1: AttrObj = SmallDictAttr::new(vec![ - ("hello", hello_attr.clone()), - ("world", world_attr.clone()), + let hello_id: Identifier = "hello".try_into().unwrap(); + let world_id: Identifier = "world".try_into().unwrap(); + + let mut dict1: AttrObj = DictAttr::new(vec![ + (hello_id.clone(), hello_attr.clone()), + (world_id.clone(), world_attr.clone()), ]) .into(); - let mut dict2 = - SmallDictAttr::new(vec![("hello", StringAttr::new("hello".to_string()).into())]).into(); - let dict1_rev = SmallDictAttr::new(vec![ - ("world", world_attr.clone()), - ("hello", hello_attr.clone()), + let mut dict2 = DictAttr::new(vec![( + hello_id.clone(), + StringAttr::new("hello".to_string()).into(), + )]) + .into(); + let dict1_rev = DictAttr::new(vec![ + (world_id.clone(), world_attr.clone()), + (hello_id.clone(), hello_attr.clone()), ]) .into(); assert!(&dict1 != &dict2); assert!(dict1 == dict1_rev); - let dict1_attr = dict1.as_mut().downcast_mut::().unwrap(); - let dict2_attr = dict2.as_mut().downcast_mut::().unwrap(); - assert!(dict1_attr.lookup("hello").unwrap() == &hello_attr); - assert!(dict1_attr.lookup("world").unwrap() == &world_attr); - assert!(dict1_attr.lookup("hello world").is_none()); - dict2_attr.insert("world", world_attr); + let dict1_attr = dict1.as_mut().downcast_mut::().unwrap(); + let dict2_attr = dict2.as_mut().downcast_mut::().unwrap(); + assert!(dict1_attr.lookup(&hello_id).unwrap() == &hello_attr); + assert!(dict1_attr.lookup(&world_id).unwrap() == &world_attr); + assert!(dict1_attr + .lookup(&"hello_world".try_into().unwrap()) + .is_none()); + dict2_attr.insert(&world_id, world_attr); assert!(dict1_attr == dict2_attr); - dict1_attr.remove("hello"); - dict2_attr.remove("hello"); + dict1_attr.remove(&hello_id); + dict2_attr.remove(&hello_id); assert!(&dict1 == &dict2); } diff --git a/src/builtin/mod.rs b/src/builtin/mod.rs index 57293d4..e8d2e19 100644 --- a/src/builtin/mod.rs +++ b/src/builtin/mod.rs @@ -7,6 +7,7 @@ pub mod types; use crate::{ context::Context, dialect::{Dialect, DialectName}, + identifier::Identifier, }; pub fn register(ctx: &mut Context) { @@ -18,4 +19,5 @@ pub fn register(ctx: &mut Context) { } /// Key for debug info related attributes. -pub const ATTR_KEY_DEBUG_INFO: &str = "builtin.debug_info"; +pub static ATTR_KEY_DEBUG_INFO: crate::Lazy = + crate::Lazy::new(|| "builtin_debug_info".try_into().unwrap()); diff --git a/src/builtin/op_interfaces.rs b/src/builtin/op_interfaces.rs index 4684d62..52afc4f 100644 --- a/src/builtin/op_interfaces.rs +++ b/src/builtin/op_interfaces.rs @@ -8,6 +8,7 @@ use crate::{ builtin::attributes::TypeAttr, context::{Context, Ptr}, decl_op_interface, + identifier::Identifier, linked_list::ContainsLinkedList, location::{Located, Location}, op::{op_cast, Op}, @@ -194,7 +195,8 @@ decl_op_interface! { } /// Key for symbol name attribute when the operation defines a symbol. -pub const ATTR_KEY_SYM_NAME: &str = "builtin.sym_name"; +pub static ATTR_KEY_SYM_NAME: crate::Lazy = + crate::Lazy::new(|| "builtin_sym_name".try_into().unwrap()); #[derive(Error, Debug)] #[error("Op implementing SymbolOpInterface does not have a symbol defined")] @@ -206,7 +208,7 @@ decl_op_interface! { // Get the name of the symbol defined by this operation. fn get_symbol_name(&self, ctx: &Context) -> String { let self_op = self.get_operation().deref(ctx); - let s_attr = self_op.attributes.get::(ATTR_KEY_SYM_NAME).unwrap(); + let s_attr = self_op.attributes.get::(&ATTR_KEY_SYM_NAME).unwrap(); String::from(s_attr.clone()) } @@ -214,7 +216,7 @@ decl_op_interface! { fn set_symbol_name(&self, ctx: &mut Context, name: &str) { let name_attr = StringAttr::new(name.to_string()); let mut self_op = self.get_operation().deref_mut(ctx); - self_op.attributes.set(ATTR_KEY_SYM_NAME, name_attr); + self_op.attributes.set(ATTR_KEY_SYM_NAME.clone(), name_attr); } fn verify(op: &dyn Op, ctx: &Context) -> Result<()> @@ -222,7 +224,7 @@ decl_op_interface! { Self: Sized, { let self_op = op.get_operation().deref(ctx); - if self_op.attributes.get::(ATTR_KEY_SYM_NAME).is_none() { + if self_op.attributes.get::(&ATTR_KEY_SYM_NAME).is_none() { return verify_err!(op.get_operation().deref(ctx).loc(), SymbolOpInterfaceErr); } Ok(()) @@ -539,7 +541,8 @@ pub enum CallOpInterfaceErr { CalleeTypeAttrIncorrectTypeErr, } -pub const ATTR_KEY_CALLEE_TYPE: &str = "llvm.callee_type"; +pub static ATTR_KEY_CALLEE_TYPE: crate::Lazy = + crate::Lazy::new(|| "builtin_callee_type".try_into().unwrap()); decl_op_interface! { /// A call-like op: Transfers control from one function to another. @@ -551,7 +554,7 @@ decl_op_interface! { { let op = op.get_operation().deref(ctx); let Some(callee_type_attr) = - op.attributes.get::(ATTR_KEY_CALLEE_TYPE) + op.attributes.get::(&ATTR_KEY_CALLEE_TYPE) else { return verify_err!(op.loc(), CallOpInterfaceErr::CalleeTypeAttrNotFoundErr); }; @@ -572,7 +575,7 @@ decl_op_interface! { /// Type of the callee fn callee_type(&self, ctx: &Context) -> TypePtr { let self_op = self.get_operation().deref(ctx); - let ty_attr = self_op.attributes.get::(ATTR_KEY_CALLEE_TYPE).unwrap(); + let ty_attr = self_op.attributes.get::(&ATTR_KEY_CALLEE_TYPE).unwrap(); TypePtr::from_ptr (ty_attr.get_type(ctx), ctx).expect("Incorrect callee type, not a FunctionType") } diff --git a/src/builtin/ops.rs b/src/builtin/ops.rs index 6d268c0..11ae008 100644 --- a/src/builtin/ops.rs +++ b/src/builtin/ops.rs @@ -3,7 +3,6 @@ use pliron_derive::def_op; use thiserror::Error; use crate::{ - attribute::{attr_cast, AttrObj}, basic_block::BasicBlock, builtin::op_interfaces::ZeroResultInterface, common_traits::{Named, Verify}, @@ -11,14 +10,14 @@ use crate::{ identifier::Identifier, impl_op_interface, impl_verify_succ, input_err, irfmt::{ - parsers::{attr_parser, process_parsed_ssa_defs, spaced, type_parser}, + parsers::{spaced, type_parser}, printers::op::{region, symb_op_header, typed_symb_op_header}, }, linked_list::ContainsLinkedList, location::{Located, Location}, op::{Op, OpObj}, operation::Operation, - parsable::{IntoParseResult, Parsable, ParseResult, StateStream}, + parsable::{Parsable, ParseResult, StateStream}, printable::{self, Printable}, r#type::{TypeObj, TypePtr, Typed}, region::Region, @@ -28,11 +27,10 @@ use crate::{ use super::{ attr_interfaces::TypedAttrInterface, - attributes::{FloatAttr, IntegerAttr, TypeAttr}, + attributes::TypeAttr, op_interfaces::{ self, IsolatedFromAboveInterface, OneRegionInterface, OneResultInterface, - OneResultVerifyErr, SingleBlockRegionInterface, SymbolOpInterface, SymbolTableInterface, - ZeroOpdInterface, + SingleBlockRegionInterface, SymbolOpInterface, SymbolTableInterface, ZeroOpdInterface, }, types::{FunctionType, UnitType}, }; @@ -141,14 +139,18 @@ impl_op_interface!(ZeroResultInterface for ModuleOp {}); /// | key | value | via Interface | /// |-----|-------|-----| /// | [ATTR_KEY_SYM_NAME](super::op_interfaces::ATTR_KEY_SYM_NAME) | [StringAttr](super::attributes::StringAttr) | [SymbolOpInterface] | -/// | [ATTR_KEY_FUNC_TYPE](FuncOp::ATTR_KEY_FUNC_TYPE) | [TypeAttr](super::attributes::TypeAttr) | N/A | +/// | [ATTR_KEY_FUNC_TYPE](func_op::ATTR_KEY_FUNC_TYPE) | [TypeAttr](super::attributes::TypeAttr) | N/A | #[def_op("builtin.func")] pub struct FuncOp {} -impl FuncOp { - /// Attribute key for the constant value. - pub const ATTR_KEY_FUNC_TYPE: &'static str = "func.type"; +pub mod func_op { + use super::*; + /// Attribute key for the function type. + pub static ATTR_KEY_FUNC_TYPE: crate::Lazy = + crate::Lazy::new(|| "builtin_func_type".try_into().unwrap()); +} +impl FuncOp { /// Create a new [FuncOp]. /// The returned function has a single region with an empty `entry` block. pub fn new(ctx: &mut Context, name: &str, ty: TypePtr) -> Self { @@ -163,7 +165,9 @@ impl FuncOp { { let opref = &mut *op.deref_mut(ctx); // Set function type attributes. - opref.attributes.set(Self::ATTR_KEY_FUNC_TYPE, ty_attr); + opref + .attributes + .set(func_op::ATTR_KEY_FUNC_TYPE.clone(), ty_attr); } let opop = FuncOp { op }; opop.set_symbol_name(ctx, name); @@ -176,7 +180,7 @@ impl FuncOp { let opref = self.get_operation().deref(ctx); opref .attributes - .get_as::(Self::ATTR_KEY_FUNC_TYPE) + .get_as::(&func_op::ATTR_KEY_FUNC_TYPE) .unwrap() .get_type() } @@ -257,7 +261,9 @@ impl Parsable for FuncOp { let ty_attr = TypeAttr::new(fty); let opref = &mut *op.deref_mut(ctx); // Set function type attributes. - opref.attributes.set(Self::ATTR_KEY_FUNC_TYPE, ty_attr); + opref + .attributes + .set(func_op::ATTR_KEY_FUNC_TYPE.clone(), ty_attr); } let opop = Box::new(FuncOp { op }); opop.set_symbol_name(ctx, &fname); @@ -285,111 +291,6 @@ impl Verify for FuncOp { impl_op_interface!(ZeroOpdInterface for FuncOp {}); impl_op_interface!(ZeroResultInterface for FuncOp {}); -/// Numeric constant. -/// See MLIR's [arith.constant](https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop). -/// -/// Attributes: -/// -/// | key | value | -/// |-----|-------| -/// |[ATTR_KEY_VALUE](ConstantOp::ATTR_KEY_VALUE) | [IntegerAttr] or [FloatAttr] | -/// -/// Results: -/// -/// | result | description | -/// |-----|-------| -/// | `result` | any type | -#[def_op("builtin.constant")] -pub struct ConstantOp {} - -impl ConstantOp { - /// Attribute key for the constant value. - pub const ATTR_KEY_VALUE: &'static str = "constant.value"; - /// Get the constant value that this Op defines. - pub fn get_value(&self, ctx: &Context) -> AttrObj { - let op = self.get_operation().deref(ctx); - op.attributes.0.get(Self::ATTR_KEY_VALUE).unwrap().clone() - } - - /// Create a new [ConstantOp]. - pub fn new(ctx: &mut Context, value: AttrObj) -> Self { - let result_type = attr_cast::(&*value) - .expect("ConstantOp const value must provide TypedAttrInterface") - .get_type(); - let op = Operation::new( - ctx, - Self::get_opid_static(), - vec![result_type], - vec![], - vec![], - 0, - ); - op.deref_mut(ctx) - .attributes - .0 - .insert(Self::ATTR_KEY_VALUE, value); - ConstantOp { op } - } -} - -impl Printable for ConstantOp { - fn fmt( - &self, - ctx: &Context, - _state: &printable::State, - f: &mut core::fmt::Formatter<'_>, - ) -> core::fmt::Result { - write!( - f, - "{} = {} {}", - self.get_result(ctx).disp(ctx), - self.get_opid().disp(ctx), - self.get_value(ctx).disp(ctx) - ) - } -} - -impl Parsable for ConstantOp { - type Arg = Vec<(Identifier, Location)>; - type Parsed = OpObj; - - fn parse<'a>( - state_stream: &mut StateStream<'a>, - results: Self::Arg, - ) -> ParseResult<'a, Self::Parsed> { - let loc = state_stream.loc(); - - if results.len() != 1 { - input_err!(loc, OneResultVerifyErr(Self::get_opid_static().to_string()))? - } - - let attr = attr_parser().parse_stream(state_stream).into_result()?.0; - - let op = Box::new(Self::new(state_stream.state.ctx, attr)); - process_parsed_ssa_defs(state_stream, &results, op.get_operation())?; - - Ok(op as OpObj).into_parse_result() - } -} - -#[derive(Error, Debug)] -#[error("{}: Unexpected type", ConstantOp::get_opid_static())] -pub struct ConstantOpVerifyErr; - -impl Verify for ConstantOp { - fn verify(&self, ctx: &Context) -> Result<()> { - let loc = self.get_operation().deref(ctx).loc(); - let value = self.get_value(ctx); - if !(value.is::() || value.is::()) { - return verify_err!(loc, ConstantOpVerifyErr); - } - Ok(()) - } -} - -impl_op_interface! (ZeroOpdInterface for ConstantOp {}); -impl_op_interface! (OneResultInterface for ConstantOp {}); - /// A placeholder during parsing to refer to yet undefined operations. /// MLIR [uses](https://github.com/llvm/llvm-project/blob/185b81e034ba60081023b6e59504dfffb560f3e3/mlir/lib/AsmParser/Parser.cpp#L1075) /// [UnrealizedConversionCastOp](https://mlir.llvm.org/docs/Dialects/Builtin/#builtinunrealized_conversion_cast-unrealizedconversioncastop) @@ -460,6 +361,5 @@ impl ForwardRefOp { pub fn register(ctx: &mut Context) { ModuleOp::register(ctx, ModuleOp::parser_fn); FuncOp::register(ctx, FuncOp::parser_fn); - ConstantOp::register(ctx, ConstantOp::parser_fn); ForwardRefOp::register(ctx, ForwardRefOp::parser_fn); } diff --git a/src/debug_info.rs b/src/debug_info.rs index 2861eca..e26e46d 100644 --- a/src/debug_info.rs +++ b/src/debug_info.rs @@ -6,16 +6,18 @@ use crate::{ attribute::{AttrObj, AttributeDict}, basic_block::BasicBlock, builtin::{ - attributes::{SmallDictAttr, StringAttr, UnitAttr, VecAttr}, + attributes::{DictAttr, StringAttr, UnitAttr, VecAttr}, ATTR_KEY_DEBUG_INFO, }, context::{Context, Ptr}, + identifier::Identifier, operation::Operation, vec_exns::VecExtns, }; /// Key into a debug info's variable name. -const DEBUG_INFO_KEY_NAME: &str = "debug_info.name"; +pub static DEBUG_INFO_KEY_NAME: pliron::Lazy = + pliron::Lazy::new(|| "debug_info_name".try_into().unwrap()); fn set_name_from_attr_map( attributes: &mut AttributeDict, @@ -24,12 +26,12 @@ fn set_name_from_attr_map( name: String, ) { let name_attr: AttrObj = StringAttr::new(name).into(); - match attributes.0.entry(ATTR_KEY_DEBUG_INFO) { + match attributes.0.entry(ATTR_KEY_DEBUG_INFO.clone()) { hash_map::Entry::Occupied(mut occupied) => { - let di_dict = occupied.get_mut().downcast_mut::().unwrap(); + let di_dict = occupied.get_mut().downcast_mut::().unwrap(); let expect_msg = "Existing attribute entry for result names incorrect"; let names = di_dict - .lookup_mut(DEBUG_INFO_KEY_NAME) + .lookup_mut(&DEBUG_INFO_KEY_NAME) .expect(expect_msg) .downcast_mut::() .expect(expect_msg); @@ -39,7 +41,11 @@ fn set_name_from_attr_map( let mut names = Vec::new_init(max_idx, |_idx| UnitAttr::new().into()); names[idx] = name_attr; vacant.insert( - SmallDictAttr::new(vec![(DEBUG_INFO_KEY_NAME, VecAttr::new(names).into())]).into(), + DictAttr::new(vec![( + DEBUG_INFO_KEY_NAME.clone(), + VecAttr::new(names).into(), + )]) + .into(), ); } } @@ -51,9 +57,9 @@ fn get_name_from_attr_map( panic_msg: &str, ) -> Option { attributes - .get::(ATTR_KEY_DEBUG_INFO) + .get::(&ATTR_KEY_DEBUG_INFO) .and_then(|di_dict| { - di_dict.lookup(DEBUG_INFO_KEY_NAME).and_then(|names| { + di_dict.lookup(&DEBUG_INFO_KEY_NAME).and_then(|names| { let names = names.downcast_ref::().expect(panic_msg); names.0.get(idx).and_then(|name| { name.downcast_ref::() @@ -116,34 +122,61 @@ pub fn get_block_arg_name(ctx: &Context, block: Ptr, arg_idx: usize) #[cfg(test)] mod tests { + use pliron_derive::def_op; + use crate::{ basic_block::BasicBlock, builtin::{ self, - attributes::IntegerAttr, - ops::ConstantOp, + op_interfaces::{OneResultInterface, ZeroOpdInterface}, types::{IntegerType, Signedness}, }, common_traits::Verify, context::Context, debug_info::{get_block_arg_name, set_block_arg_name}, + dialect::{Dialect, DialectName}, + impl_canonical_syntax, impl_op_interface, impl_verify_succ, op::Op, + operation::Operation, + parsable::Parsable, result::Result, }; - use apint::ApInt; + + #[def_op("test.zero")] + struct ZeroOp; + impl_canonical_syntax!(ZeroOp); + impl_verify_succ!(ZeroOp); + impl_op_interface! (ZeroOpdInterface for ZeroOp {}); + impl_op_interface! (OneResultInterface for ZeroOp {}); + impl ZeroOp { + pub fn new(ctx: &mut Context) -> Self { + let i64_ty = IntegerType::get(ctx, 64, Signedness::Signed); + ZeroOp { + op: Operation::new( + ctx, + Self::get_opid_static(), + vec![i64_ty.into()], + vec![], + vec![], + 0, + ), + } + } + } use super::{get_operation_result_name, set_operation_result_name}; #[test] fn test_op_result_name() -> Result<()> { let mut ctx = Context::new(); - builtin::register(&mut ctx); + let test_dialect = Dialect::new(DialectName::new("test")); + test_dialect.register(&mut ctx); + ZeroOp::register(&mut ctx, ZeroOp::parser_fn); - let i64_ty = IntegerType::get(&mut ctx, 64, Signedness::Signed); - let cop = ConstantOp::new(&mut ctx, IntegerAttr::new(i64_ty, ApInt::from(0)).into()); + let cop = ZeroOp::new(&mut ctx); let op = cop.get_operation(); set_operation_result_name(&ctx, op, 0, "foo".to_string()); - assert!(get_operation_result_name(&ctx, op, 0).unwrap() == "foo"); + assert_eq!(get_operation_result_name(&ctx, op, 0).unwrap(), "foo"); op.deref(&ctx).verify(&ctx)?; Ok(()) } diff --git a/src/identifier.rs b/src/identifier.rs index 984b4b3..da517cb 100644 --- a/src/identifier.rs +++ b/src/identifier.rs @@ -13,16 +13,27 @@ use crate::{ verify_err_noloc, }; -#[derive(Clone, Hash, PartialEq, Eq, Debug)] +#[derive(Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)] /// An [Identifier] must satisfy the regex `[a-zA-Z_][a-zA-Z0-9_]*`. /// Also see [module description](module@crate::identifier). pub struct Identifier(String); +static IDENTIFIER_REGEX: crate::Lazy = + crate::Lazy::new(|| regex::Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap()); + impl Identifier { /// Attempt to construct a new [Identifier] from a [String]. + /// Examples: + /// ``` + /// use pliron::identifier::Identifier; + /// let _: Identifier = "hi12".try_into().expect("Identifier creation error"); + /// let _: Identifier = "A12ab".try_into().expect("Identifier creation error"); + /// TryInto::::try_into("hi12.").expect_err("Malformed identifier not caught"); + /// TryInto::::try_into("12ab").expect_err("Malformed identifier not caught"); + /// TryInto::::try_into(".a12ab").expect_err("Malformed identifier not caught"); + /// ``` pub fn try_new(value: String) -> Result { - let re = regex::Regex::new(r"[a-zA-Z_][a-zA-Z0-9_]*").unwrap(); - if !(re.is_match(&value)) { + if !(IDENTIFIER_REGEX.is_match(&value)) { return verify_err_noloc!(MalformedIdentifierErr(value.clone())); } Ok(Identifier(value)) diff --git a/src/op.rs b/src/op.rs index ef310c4..fd5138a 100644 --- a/src/op.rs +++ b/src/op.rs @@ -39,6 +39,7 @@ use std::{ use thiserror::Error; use crate::{ + attribute::AttributeDict, builtin::types::FunctionType, common_traits::Verify, context::{Context, Ptr}, @@ -413,8 +414,9 @@ macro_rules! decl_op_interface { } /// Printer for an [Op] in canonical syntax. -/// `res_1: type_1, res_2: type_2, ... res_n: type_n = op_id (opd_1, opd_2, ... opd_n) [succ_1, succ_2, ... succ_n] : function-type` -/// TODO: Handle operations with regions, attributes. +/// `res_1: type_1, res_2: type_2, ... res_n: type_n = +/// op_id (opd_1, opd_2, ... opd_n) [succ_1, succ_2, ... succ_n] [attr-dic]: function-type` +/// TODO: Handle operations with regions. pub fn canonical_syntax_fmt( op: OpObj, ctx: &Context, @@ -436,10 +438,11 @@ pub fn canonical_syntax_fmt( } let ret = write!( f, - "{} ({}) [{}] : {}", + "{} ({}) [{}] {}: {}", op.get_opid().disp(ctx), operands.disp(ctx), successors.disp(ctx), + op.attributes.disp(ctx), op_type.disp(ctx), ); ret @@ -463,48 +466,52 @@ pub fn canonical_syntax_parse<'a>( // Results and opid have already been parsed. Continue after that. delimited_list_parser('(', ')', ',', ssa_opd_parser()) .and(spaces().with(delimited_list_parser('[', ']', ',', block_opd_parser()))) + .and(spaces().with(AttributeDict::parser(()))) .skip(spaced(token(':'))) .and((location(), FunctionType::parser(()))) - .then(move |((operands, successors), (fty_loc, fty))| { - let opid = opid.clone(); - let results = results.clone(); - let fty_loc = fty_loc.clone(); - combine::parser(move |parsable_state: &mut StateStream<'a>| { + .then( + move |(((operands, successors), attr_dict), (fty_loc, fty))| { + let opid = opid.clone(); let results = results.clone(); - let ctx = &mut parsable_state.state.ctx; - let results_types = fty.deref(ctx).get_results().to_vec(); - let operands_types = fty.deref(ctx).get_inputs().to_vec(); - if results_types.len() != results.len() { - input_err!( - fty_loc.clone(), - CanonicalSyntaxParseError::ResultsMismatch { - num_res_ty: results_types.len(), - num_res: results.len() - } - )? - } - if operands.len() != operands_types.len() { - input_err!( - fty_loc.clone(), - CanonicalSyntaxParseError::OperandsMismatch { - num_opd_ty: operands_types.len(), - num_opd: operands.len() - } - )? - } - let opr = Operation::new( - ctx, - opid.clone(), - results_types, - operands.clone(), - successors.clone(), - 0, - ); - let op = from_operation(ctx, opr); - process_parsed_ssa_defs(parsable_state, &results, opr)?; - Ok(op).into_parse_result() - }) - }) + let fty_loc = fty_loc.clone(); + combine::parser(move |parsable_state: &mut StateStream<'a>| { + let results = results.clone(); + let ctx = &mut parsable_state.state.ctx; + let results_types = fty.deref(ctx).get_results().to_vec(); + let operands_types = fty.deref(ctx).get_inputs().to_vec(); + if results_types.len() != results.len() { + input_err!( + fty_loc.clone(), + CanonicalSyntaxParseError::ResultsMismatch { + num_res_ty: results_types.len(), + num_res: results.len() + } + )? + } + if operands.len() != operands_types.len() { + input_err!( + fty_loc.clone(), + CanonicalSyntaxParseError::OperandsMismatch { + num_opd_ty: operands_types.len(), + num_opd: operands.len() + } + )? + } + let opr = Operation::new( + ctx, + opid.clone(), + results_types, + operands.clone(), + successors.clone(), + 0, + ); + opr.deref_mut(ctx).attributes = attr_dict.clone(); + let op = from_operation(ctx, opr); + process_parsed_ssa_defs(parsable_state, &results, opr)?; + Ok(op).into_parse_result() + }) + }, + ) .parse_stream(state_stream) .into() } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 878470b..d627996 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,10 +1,14 @@ use apint::ApInt; use pliron::{ + attribute::AttrObj, builtin::{ self, attributes::IntegerAttr, - op_interfaces::{IsTerminatorInterface, OneResultInterface, ZeroResultVerifyErr}, - ops::{ConstantOp, FuncOp, ModuleOp}, + op_interfaces::{ + IsTerminatorInterface, OneResultInterface, OneResultVerifyErr, ZeroOpdInterface, + ZeroResultVerifyErr, + }, + ops::{FuncOp, ModuleOp}, types::{FunctionType, IntegerType, Signedness}, }, common_traits::Verify, @@ -13,11 +17,11 @@ use pliron::{ dialect::{Dialect, DialectName}, identifier::Identifier, impl_op_interface, impl_verify_succ, input_err, - irfmt::parsers::ssa_opd_parser, + irfmt::parsers::{attr_parser, process_parsed_ssa_defs, ssa_opd_parser}, location::{Located, Location}, op::{Op, OpObj}, operation::Operation, - parsable::{self, Parsable, ParseResult}, + parsable::{self, IntoParseResult, Parsable, ParseResult, StateStream}, printable::{self, Printable}, result::Result, use_def_lists::Value, @@ -75,6 +79,88 @@ impl Parsable for ReturnOp { impl_op_interface!(IsTerminatorInterface for ReturnOp {}); impl_verify_succ!(ReturnOp); +#[def_op("test.constant")] +pub struct ConstantOp; +impl_verify_succ!(ConstantOp); +impl_op_interface! (ZeroOpdInterface for ConstantOp {}); +impl_op_interface! (OneResultInterface for ConstantOp {}); +impl ConstantOp { + pub const ATTR_KEY_VALUE: pliron::Lazy = + pliron::Lazy::new(|| "constant_value".try_into().unwrap()); + + pub fn new(ctx: &mut Context, value: u64) -> Self { + let i64_ty = IntegerType::get(ctx, 64, Signedness::Signed); + let int_attr = IntegerAttr::new(i64_ty, ApInt::from_u64(value)); + let op = Operation::new( + ctx, + Self::get_opid_static(), + vec![i64_ty.into()], + vec![], + vec![], + 0, + ); + op.deref_mut(ctx) + .attributes + .0 + .insert(Self::ATTR_KEY_VALUE.clone(), Box::new(int_attr)); + ConstantOp { op } + } + + pub fn get_value(&self, ctx: &Context) -> AttrObj { + let op = self.get_operation().deref(ctx); + op.attributes.0.get(&Self::ATTR_KEY_VALUE).unwrap().clone() + } +} +impl Printable for ConstantOp { + fn fmt( + &self, + ctx: &Context, + _state: &printable::State, + f: &mut core::fmt::Formatter<'_>, + ) -> core::fmt::Result { + write!( + f, + "{} = {} {}", + self.get_result(ctx).disp(ctx), + self.get_opid().disp(ctx), + self.get_value(ctx).disp(ctx) + ) + } +} +impl Parsable for ConstantOp { + type Arg = Vec<(Identifier, Location)>; + type Parsed = OpObj; + + fn parse<'a>( + state_stream: &mut StateStream<'a>, + results: Self::Arg, + ) -> ParseResult<'a, Self::Parsed> { + let loc = state_stream.loc(); + + if results.len() != 1 { + input_err!( + loc.clone(), + OneResultVerifyErr(Self::get_opid_static().to_string()) + )? + } + + let attr = attr_parser().parse_stream(state_stream).into_result()?.0; + let int_attr = match attr.downcast::() { + Ok(int_attr) => int_attr, + Err(attr) => input_err!( + loc, + "Expected integer attribute, but found {}", + attr.disp(state_stream.state.ctx) + )?, + }; + let int_val: u64 = Into::::into(*int_attr).try_to_u64().unwrap(); + let op = Box::new(Self::new(state_stream.state.ctx, int_val)); + process_parsed_ssa_defs(state_stream, &results, op.get_operation())?; + + Ok(op as OpObj).into_parse_result() + } +} + pub fn setup_context_dialects() -> Context { let mut ctx = Context::new(); builtin::register(&mut ctx); @@ -82,6 +168,7 @@ pub fn setup_context_dialects() -> Context { let test_dialect = Dialect::new(DialectName::new("test")); test_dialect.register(&mut ctx); ReturnOp::register(&mut ctx, ReturnOp::parser_fn); + ConstantOp::register(&mut ctx, ConstantOp::parser_fn); ctx } @@ -97,8 +184,7 @@ pub fn const_ret_in_mod(ctx: &mut Context) -> Result<(ModuleOp, FuncOp, Constant let bb = func.get_entry_block(ctx); // Create a `const 0` op and add it to bb. - let zero_const = IntegerAttr::new(i64_ty, ApInt::from(0)); - let const_op = ConstantOp::new(ctx, zero_const.into()); + let const_op = ConstantOp::new(ctx, 0); const_op.get_operation().insert_at_front(bb, ctx); set_operation_result_name(ctx, const_op.get_operation(), 0, "c0".to_string()); diff --git a/tests/ir_construct.rs b/tests/ir_construct.rs index a04a930..83d8339 100644 --- a/tests/ir_construct.rs +++ b/tests/ir_construct.rs @@ -1,7 +1,7 @@ -use apint::ApInt; +use common::ConstantOp; use expect_test::{expect, Expect}; use pliron::{ - builtin::{attributes::IntegerAttr, op_interfaces::OneResultInterface, ops::ConstantOp}, + builtin::op_interfaces::OneResultInterface, common_traits::Verify, context::Context, debug_info::set_operation_result_name, @@ -11,7 +11,6 @@ use pliron::{ operation::Operation, parsable::{self, state_stream_from_iterator, Parsable}, printable::Printable, - r#type::TypePtr, result::Result, walkers::{ self, @@ -57,13 +56,7 @@ fn replace_c0_with_c1() -> Result<()> { // const_ret_in_mod builds a module with a function. let (module_op, _, const_op, _) = const_ret_in_mod(ctx).unwrap(); - // Insert a new constant. - let one_const = IntegerAttr::new( - TypePtr::from_ptr(const_op.result_type(ctx), ctx) - .expect("Expected const_op to have integer type"), - ApInt::from(1), - ); - let const1_op = ConstantOp::new(ctx, one_const.into()); + let const1_op = ConstantOp::new(ctx, 1); const1_op .get_operation() .insert_after(ctx, const_op.get_operation()); @@ -87,12 +80,7 @@ fn replace_c0_with_c1_operand() -> Result<()> { // const_ret_in_mod builds a module with a function. let (module_op, _, const_op, ret_op) = const_ret_in_mod(ctx).unwrap(); - // Insert a new constant. - let one_const = IntegerAttr::new( - TypePtr::from_ptr(const_op.result_type(ctx), ctx).unwrap(), - ApInt::from(1), - ); - let const1_op = ConstantOp::new(ctx, one_const.into()); + let const1_op = ConstantOp::new(ctx, 1); const1_op .get_operation() .insert_after(ctx, const_op.get_operation()); @@ -104,8 +92,8 @@ fn replace_c0_with_c1_operand() -> Result<()> { ^block_1v1(): builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; - c1_op_5v1_res0 = builtin.constant builtin.integer <0x1: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; + c1_op_5v1_res0 = test.constant builtin.integer <0x1: builtin.int>; test.return c0_op_3v1_res0 } }"#]] @@ -123,7 +111,7 @@ fn replace_c0_with_c1_operand() -> Result<()> { ^block_1v1(): builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c1_op_5v1_res0 = builtin.constant builtin.integer <0x1: builtin.int>; + c1_op_5v1_res0 = test.constant builtin.integer <0x1: builtin.int>; test.return c1_op_5v1_res0 } }"#]] @@ -145,7 +133,7 @@ fn print_simple() -> Result<()> { ^block_1v1(): builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; test.return c0_op_3v1_res0 } }"#]] @@ -161,7 +149,7 @@ fn parse_simple() -> Result<()> { ^block_0_0(): builtin.func @foo: builtin.function <() -> (builtin.int )> { ^entry_block_1_0(): - c0_op_2_0_res0 = builtin.constant builtin.integer <0x0: builtin.int >; + c0_op_2_0_res0 = test.constant builtin.integer <0x0: builtin.int >; test.return c0_op_2_0_res0 ^exit(a : builtin.int ): } @@ -200,8 +188,8 @@ fn parse_err_multiple_def() { ^block_0_0(): builtin.func @foo: builtin.function <() -> (builtin.int )> { ^entry_block_1_0(): - c0_op_2_0_res0 = builtin.constant builtin.integer <0x0: builtin.int >; - c0_op_2_0_res0 = builtin.constant builtin.integer <0x0: builtin.int >; + c0_op_2_0_res0 = test.constant builtin.integer <0x0: builtin.int >; + c0_op_2_0_res0 = test.constant builtin.integer <0x0: builtin.int >; test.return c0_op_2_0_res0 ^exit(): } @@ -218,7 +206,7 @@ fn parse_err_multiple_def() { ^block_0_0(): builtin.func @foo: builtin.function <() -> (builtin.int )> { ^entry_block_1_0(): - c0_op_2_0_res0 = builtin.constant builtin.integer <0x0: builtin.int >; + c0_op_2_0_res0 = test.constant builtin.integer <0x0: builtin.int >; test.return c0_op_2_0_res0 ^entry_block_1_0(): } @@ -256,7 +244,7 @@ fn parse_err_block_label_colon() { ^block_0_0(): builtin.func @foo: builtin.function <() -> (builtin.int )> { ^entry_block_1_0(): - c0_op_2_0_res0 = builtin.constant builtin.integer <0x0: builtin.int >; + c0_op_2_0_res0 = test.constant builtin.integer <0x0: builtin.int >; test.return c0_op_2_0_res0 ^exit() } @@ -314,16 +302,16 @@ fn test_preorder_forward_walk() { ^block_1v1(): builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; test.return c0_op_3v1_res0 } } builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; test.return c0_op_3v1_res0 } - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int> + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int> test.return c0_op_3v1_res0 "#]] .assert_eq(&ops); @@ -352,18 +340,18 @@ fn test_postorder_forward_walk() { accum + &op.disp(ctx).to_string() + "\n" }); expect![[r#" - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int> + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int> test.return c0_op_3v1_res0 builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; test.return c0_op_3v1_res0 } builtin.module @bar { ^block_1v1(): builtin.func @foo: builtin.function<() -> (builtin.int)> { ^entry_block_2v1(): - c0_op_3v1_res0 = builtin.constant builtin.integer <0x0: builtin.int>; + c0_op_3v1_res0 = test.constant builtin.integer <0x0: builtin.int>; test.return c0_op_3v1_res0 } } @@ -376,12 +364,7 @@ fn test_walker_find_op() { let ctx = &mut setup_context_dialects(); let (module_op, _, const_op, _) = const_ret_in_mod(ctx).unwrap(); - // Insert a new constant after `const_op`. - let one_const = IntegerAttr::new( - TypePtr::from_ptr(const_op.result_type(ctx), ctx).unwrap(), - ApInt::from(1), - ); - let const1_op = ConstantOp::new(ctx, one_const.into()); + let const1_op = ConstantOp::new(ctx, 1); const1_op .get_operation() .insert_after(ctx, const_op.get_operation());