diff --git a/Cargo.toml b/Cargo.toml index 9160ebd..57debda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,7 @@ keywords = ["pliron", "llvm", "mlir", "compiler"] generational-arena = "0.2" downcast-rs = "1.2.0" rustc-hash = "1.1.0" -# anyerror = "1.0.75" -# thiserror = "1.0.24" +thiserror = "1.0.49" # clap = "4.1.6" apint = "0.2.0" sorted_vector_map = "0.1.0" diff --git a/src/attribute.rs b/src/attribute.rs index a1f1810..38db3eb 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -33,7 +33,7 @@ use crate::{ common_traits::Verify, context::Context, dialect::{Dialect, DialectName}, - error::CompilerError, + error::Result, printable::{self, Printable}, }; @@ -54,7 +54,7 @@ pub trait Attribute: Printable + Verify + Downcast + CastFrom + Sync { Self: Sized; /// Verify all interfaces implemented by this attribute. - fn verify_interfaces(&self, ctx: &Context) -> Result<(), CompilerError>; + fn verify_interfaces(&self, ctx: &Context) -> Result<()>; /// Register this attribute's [AttrId] in the dialect it belongs to. /// **Warning**: No check is made as to whether this attr is already registered @@ -163,7 +163,7 @@ impl Printable for AttrId { } /// Every attribute interface must have a function named `verify` with this type. -pub type AttrInterfaceVerifier = fn(&dyn Attribute, &Context) -> Result<(), CompilerError>; +pub type AttrInterfaceVerifier = fn(&dyn Attribute, &Context) -> Result<()>; /// impl [Attribute] for a rust type. /// @@ -179,7 +179,7 @@ pub type AttrInterfaceVerifier = fn(&dyn Attribute, &Context) -> Result<(), Comp /// ); /// # use pliron::{ /// # impl_attr, printable::{self, Printable}, -/// # context::Context, error::CompilerError, common_traits::Verify, +/// # context::Context, error::Result, common_traits::Verify, /// # attribute::Attribute, /// # }; /// # impl Printable for MyAttr { @@ -194,7 +194,7 @@ pub type AttrInterfaceVerifier = fn(&dyn Attribute, &Context) -> Result<(), Comp /// # } /// /// # impl Verify for MyAttr { -/// # fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { +/// # fn verify(&self, _ctx: &Context) -> Result<()> { /// # todo!() /// # } /// # } @@ -234,7 +234,7 @@ macro_rules! impl_attr { dialect: $crate::dialect::DialectName::new($dialect_name), } } - fn verify_interfaces(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify_interfaces(&self, ctx: &Context) -> $crate::error::Result<()> { let interface_verifiers = paste::paste!{ inventory::iter::<[]> }; @@ -262,7 +262,7 @@ macro_rules! impl_attr { /// ); /// trait MyAttrInterface: Attribute { /// fn monu(&self); -/// fn verify(attr: &dyn Attribute, ctx: &Context) -> Result<(), CompilerError> +/// fn verify(attr: &dyn Attribute, ctx: &Context) -> Result<()> /// where /// Self: Sized, /// { @@ -277,7 +277,7 @@ macro_rules! impl_attr { /// ); /// # use pliron::{ /// # impl_attr, printable::{self, Printable}, -/// # context::Context, error::CompilerError, common_traits::Verify, +/// # context::Context, error::Result, common_traits::Verify, /// # attribute::Attribute, impl_attr_interface /// # }; /// # impl Printable for MyAttr { @@ -287,7 +287,7 @@ macro_rules! impl_attr { /// # } /// /// # impl Verify for MyAttr { -/// # fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { +/// # fn verify(&self, _ctx: &Context) -> Result<()> { /// # todo!() /// # } /// # } diff --git a/src/basic_block.rs b/src/basic_block.rs index 2e7308f..31dd9e7 100644 --- a/src/basic_block.rs +++ b/src/basic_block.rs @@ -7,7 +7,7 @@ use crate::{ common_traits::{Named, Verify}, context::{private::ArenaObj, ArenaCell, Context, Ptr}, debug_info::get_block_arg_name, - error::CompilerError, + error::Result, indented_block, linked_list::{private, ContainsLinkedList, LinkedList}, operation::Operation, @@ -294,7 +294,7 @@ impl ArenaObj for BasicBlock { } impl Verify for BasicBlock { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { self.iter(ctx).try_for_each(|op| op.deref(ctx).verify(ctx)) } } diff --git a/src/common_traits.rs b/src/common_traits.rs index 61b352f..177d6b0 100644 --- a/src/common_traits.rs +++ b/src/common_traits.rs @@ -1,10 +1,10 @@ //! Utility traits such as [Named], [Verify] etc. -use crate::{context::Context, error::CompilerError}; +use crate::{context::Context, error::Result}; /// Check and ensure correctness. pub trait Verify { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError>; + fn verify(&self, ctx: &Context) -> Result<()>; } /// Anything that has a name. diff --git a/src/context.rs b/src/context.rs index d3761ff..aadb79b 100644 --- a/src/context.rs +++ b/src/context.rs @@ -4,6 +4,7 @@ use crate::{ basic_block::BasicBlock, common_traits::Verify, dialect::{Dialect, DialectName}, + error::Result, op::{OpCreator, OpId}, operation::Operation, printable::{self, Printable}, @@ -175,7 +176,7 @@ impl Printable for Ptr { } impl Verify for Ptr { - fn verify(&self, ctx: &Context) -> Result<(), crate::error::CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { self.deref(ctx).verify(ctx) } } diff --git a/src/debug_info.rs b/src/debug_info.rs index fc10356..cdcf2e6 100644 --- a/src/debug_info.rs +++ b/src/debug_info.rs @@ -136,7 +136,7 @@ mod tests { types::{IntegerType, Signedness}, }, }, - error::CompilerError, + error::Result, op::Op, }; use apint::ApInt; @@ -144,7 +144,7 @@ mod tests { use super::{get_operation_result_name, set_operation_result_name}; #[test] - fn test_op_result_name() -> Result<(), CompilerError> { + fn test_op_result_name() -> Result<()> { let mut ctx = Context::new(); dialects::builtin::register(&mut ctx); @@ -158,7 +158,7 @@ mod tests { } #[test] - fn test_block_arg_name() -> Result<(), CompilerError> { + fn test_block_arg_name() -> Result<()> { let mut ctx = Context::new(); dialects::builtin::register(&mut ctx); diff --git a/src/dialect.rs b/src/dialect.rs index 5e11547..60ba638 100644 --- a/src/dialect.rs +++ b/src/dialect.rs @@ -1,6 +1,6 @@ //! [Dialect]s are a mechanism to group related [Op](crate::op::Op)s, [Type](crate::type::Type)s //! and [Attribute](crate::attribute::Attribute)s. -use std::ops::Deref; +use std::{fmt::Display, ops::Deref}; use combine::{easy, ParseResult, Parser}; use rustc_hash::FxHashMap; @@ -32,6 +32,12 @@ impl Printable for DialectName { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { + ::fmt(self, f) + } +} + +impl Display for DialectName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } diff --git a/src/dialects/builtin/attr_interfaces.rs b/src/dialects/builtin/attr_interfaces.rs index d40fdf1..14087b8 100644 --- a/src/dialects/builtin/attr_interfaces.rs +++ b/src/dialects/builtin/attr_interfaces.rs @@ -1,7 +1,7 @@ use crate::{ attribute::Attribute, context::{Context, Ptr}, - error::CompilerError, + error::Result, r#type::TypeObj, }; @@ -11,7 +11,7 @@ pub trait TypedAttrInterface: Attribute { /// Get this attribute's type. fn get_type(&self) -> Ptr; - fn verify(_attr: &dyn Attribute, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_attr: &dyn Attribute, _ctx: &Context) -> Result<()> where Self: Sized, { diff --git a/src/dialects/builtin/attributes.rs b/src/dialects/builtin/attributes.rs index c887deb..7cd34fd 100644 --- a/src/dialects/builtin/attributes.rs +++ b/src/dialects/builtin/attributes.rs @@ -7,7 +7,7 @@ use crate::{ common_traits::Verify, context::{Context, Ptr}, dialect::Dialect, - error::CompilerError, + error::Result, impl_attr, impl_attr_interface, printable::{self, Printable}, r#type::TypeObj, @@ -46,7 +46,7 @@ impl Printable for StringAttr { } impl Verify for StringAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -72,7 +72,7 @@ impl Printable for IntegerAttr { } impl Verify for IntegerAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -119,7 +119,7 @@ impl Printable for FloatAttr { } impl Verify for FloatAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -164,7 +164,7 @@ impl Printable for SmallDictAttr { } impl Verify for SmallDictAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -222,7 +222,7 @@ impl Printable for VecAttr { } impl Verify for VecAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -251,7 +251,7 @@ impl Printable for UnitAttr { } impl Verify for UnitAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } @@ -280,7 +280,7 @@ impl Printable for TypeAttr { } impl Verify for TypeAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } diff --git a/src/dialects/builtin/op_interfaces.rs b/src/dialects/builtin/op_interfaces.rs index 6606d27..7d9eb8f 100644 --- a/src/dialects/builtin/op_interfaces.rs +++ b/src/dialects/builtin/op_interfaces.rs @@ -1,7 +1,9 @@ +use thiserror::Error; + use crate::{ basic_block::BasicBlock, context::{Context, Ptr}, - error::CompilerError, + error::Result, linked_list::ContainsLinkedList, op::{op_cast, Op}, operation::Operation, @@ -9,13 +11,14 @@ use crate::{ r#type::TypeObj, region::Region, use_def_lists::Value, + verify_err, }; use super::attributes::StringAttr; /// An [Op] implementing this interface is a block terminator. pub trait IsTerminatorInterface: Op { - fn verify(_op: &dyn Op, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -42,7 +45,7 @@ pub trait RegionKindInterface: Op { /// must require dominance to hold. fn has_ssa_dominance(idx: usize) -> bool; - fn verify(_op: &dyn Op, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -50,6 +53,10 @@ pub trait RegionKindInterface: Op { } } +#[derive(Error, Debug)] +#[error("Op {0} must have a single region")] +pub struct OneRegionVerifyErr(String); + /// [Op]s that have exactly one region. pub trait OneRegionInterface: Op { fn get_region(&self, ctx: &Context) -> Ptr { @@ -60,20 +67,22 @@ pub trait OneRegionInterface: Op { } /// Checks that the operation has exactly one region. - fn verify(op: &dyn Op, ctx: &Context) -> Result<(), CompilerError> + fn verify(op: &dyn Op, ctx: &Context) -> Result<()> where Self: Sized, { let self_op = op.get_operation().deref(ctx); if self_op.regions.len() != 1 { - return Err(CompilerError::VerificationError { - msg: format!("Op {} must have single region.", op.get_opid().disp(ctx)), - }); + return verify_err!(OneRegionVerifyErr(op.get_opid().disp(ctx).to_string())); } Ok(()) } } +#[derive(Error, Debug)] +#[error("Op {0} must only have regions with single block")] +pub struct SingleBlockRegionVerifyErr(String); + /// [Op]s with regions that have a single block. pub trait SingleBlockRegionInterface: Op { /// Get the single body block in `region_idx`. @@ -93,19 +102,16 @@ pub trait SingleBlockRegionInterface: Op { } /// Checks that the operation has regions with single block. - fn verify(op: &dyn Op, ctx: &Context) -> Result<(), CompilerError> + fn verify(op: &dyn Op, ctx: &Context) -> Result<()> where Self: Sized, { let self_op = op.get_operation().deref(ctx); for region in &self_op.regions { if region.deref(ctx).iter(ctx).count() != 1 { - return Err(CompilerError::VerificationError { - msg: format!( - "SingleBlockRegion Op {} must have single region.", - self_op.get_opid().disp(ctx) - ), - }); + return verify_err!(SingleBlockRegionVerifyErr( + self_op.get_opid().disp(ctx).to_string() + )); } } Ok(()) @@ -130,7 +136,7 @@ pub trait SymbolOpInterface: Op { .insert(super::ATTR_KEY_SYM_NAME, name_attr); } - fn verify(_op: &dyn Op, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -138,6 +144,10 @@ pub trait SymbolOpInterface: Op { } } +#[derive(Error, Debug)] +#[error("Op {0} must have single result")] +pub struct OneResultVerifyErr(String); + /// An [Op] having exactly one result. pub trait OneResultInterface: Op { /// Get the single result defined by this Op. @@ -156,18 +166,13 @@ pub trait OneResultInterface: Op { .unwrap_or_else(|| panic!("{} must have exactly one result", self.get_opid().disp(ctx))) } - fn verify(op: &dyn Op, ctx: &Context) -> Result<(), CompilerError> + fn verify(op: &dyn Op, ctx: &Context) -> Result<()> where Self: Sized, { let op = &*op.get_operation().deref(ctx); if op.get_num_results() != 1 { - return Err(CompilerError::VerificationError { - msg: format!( - "Expected exactly one result on operation {}", - op.get_opid().disp(ctx) - ), - }); + return verify_err!(OneResultVerifyErr(op.get_opid().disp(ctx).to_string())); } Ok(()) } @@ -178,7 +183,7 @@ pub trait CallOpInterface: Op { /// Returns the symbol of the callee of this call operation. fn get_callee_sym(&self, ctx: &Context) -> String; - fn verify(_op: &dyn Op, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -203,3 +208,21 @@ pub fn get_callees_syms(ctx: &Context, op: Ptr) -> Vec { } callees } + +#[derive(Error, Debug)] +#[error("Op {0} must not have any operand")] +pub struct ZeroOpdVerifyErr(String); + +/// An [Op] having no operands. +pub trait ZeroOpdInterface: Op { + fn verify(op: &dyn Op, ctx: &Context) -> Result<()> + where + Self: Sized, + { + let op = &*op.get_operation().deref(ctx); + if op.get_num_operands() != 0 { + return verify_err!(ZeroOpdVerifyErr(op.get_opid().disp(ctx).to_string())); + } + Ok(()) + } +} diff --git a/src/dialects/builtin/ops.rs b/src/dialects/builtin/ops.rs index faa080d..275ff64 100644 --- a/src/dialects/builtin/ops.rs +++ b/src/dialects/builtin/ops.rs @@ -1,3 +1,5 @@ +use thiserror::Error; + use crate::{ attribute::{self, attr_cast, AttrObj}, basic_block::BasicBlock, @@ -5,13 +7,14 @@ use crate::{ context::{Context, Ptr}, declare_op, dialect::Dialect, - error::CompilerError, + error::Result, impl_op_interface, linked_list::ContainsLinkedList, op::Op, operation::Operation, printable::{self, Printable}, r#type::TypeObj, + verify_err, }; use super::{ @@ -19,6 +22,7 @@ use super::{ attributes::{FloatAttr, IntegerAttr, TypeAttr}, op_interfaces::{ OneRegionInterface, OneResultInterface, SingleBlockRegionInterface, SymbolOpInterface, + ZeroOpdInterface, }, types::FunctionType, }; @@ -60,7 +64,7 @@ impl Printable for ModuleOp { } impl Verify for ModuleOp { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } @@ -178,19 +182,23 @@ impl Printable for FuncOp { } } +#[derive(Error, Debug)] +pub enum FuncOpVerifyErr { + #[error("function does not have function type")] + NotFuncType, + #[error("incorrect number of results or operands")] + IncorrectNumResultsOpds, +} + impl Verify for FuncOp { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { let ty = self.get_type(ctx); if !(ty.deref(ctx).is::()) { - return Err(CompilerError::VerificationError { - msg: "Unexpected Func type".to_string(), - }); + return verify_err!(FuncOpVerifyErr::NotFuncType); } let op = &*self.get_operation().deref(ctx); if op.get_num_results() != 0 || op.get_num_operands() != 0 { - return Err(CompilerError::VerificationError { - msg: "Incorrect number of results or operands".to_string(), - }); + return verify_err!(FuncOpVerifyErr::IncorrectNumResultsOpds); } Ok(()) } @@ -260,24 +268,21 @@ impl Printable for ConstantOp { } } +#[derive(Error, Debug)] +#[error("{}: Unexpected type", ConstantOp::get_opid_static())] +pub struct ConstantOpVerifyErr; + impl Verify for ConstantOp { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { let value = self.get_value(ctx); if !(value.is::() || value.is::()) { - return Err(CompilerError::VerificationError { - msg: "Unexpected constant type".to_string(), - }); - } - let op = &*self.get_operation().deref(ctx); - if op.get_num_operands() != 0 { - return Err(CompilerError::VerificationError { - msg: "Incorrect number of results or operands".to_string(), - }); + return verify_err!(ConstantOpVerifyErr); } Ok(()) } } +impl_op_interface! (ZeroOpdInterface for ConstantOp {}); impl_op_interface! (OneResultInterface for ConstantOp {}); pub fn register(ctx: &mut Context, dialect: &mut Dialect) { diff --git a/src/dialects/builtin/types.rs b/src/dialects/builtin/types.rs index 268bde5..442a9a6 100644 --- a/src/dialects/builtin/types.rs +++ b/src/dialects/builtin/types.rs @@ -8,7 +8,7 @@ use crate::{ common_traits::Verify, context::{Context, Ptr}, dialect::Dialect, - error::CompilerError, + error::Result, impl_type, parsable::{spaced, Parsable, StateStream}, printable::{self, ListSeparator, Printable, PrintableIter}, @@ -95,7 +95,7 @@ impl Printable for IntegerType { } impl Verify for IntegerType { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } @@ -174,7 +174,7 @@ impl Parsable for FunctionType { } impl Verify for FunctionType { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { todo!() } } diff --git a/src/dialects/llvm/ops.rs b/src/dialects/llvm/ops.rs index 2320600..0febf23 100644 --- a/src/dialects/llvm/ops.rs +++ b/src/dialects/llvm/ops.rs @@ -4,7 +4,7 @@ use crate::{ declare_op, dialect::Dialect, dialects::builtin::op_interfaces::IsTerminatorInterface, - error::CompilerError, + error::Result, impl_op_interface, op::Op, operation::Operation, @@ -53,7 +53,7 @@ impl Printable for ReturnOp { } impl Verify for ReturnOp { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } diff --git a/src/dialects/llvm/types.rs b/src/dialects/llvm/types.rs index 7ff103d..879f4c7 100644 --- a/src/dialects/llvm/types.rs +++ b/src/dialects/llvm/types.rs @@ -2,14 +2,16 @@ use crate::{ common_traits::Verify, context::{Context, Ptr}, dialect::Dialect, - error::CompilerError, - impl_type, + error::Result, + impl_type, input_err, parsable::{identifier, spaced, to_parse_result, Parsable, StateStream}, printable::{self, Printable, PrintableIter}, r#type::{type_parser, Type, TypeObj}, storage_uniquer::TypeValueHash, + verify_err, }; use combine::{between, easy, optional, sep_by, token, ParseResult, Parser}; +use thiserror::Error; use std::hash::Hash; @@ -77,7 +79,7 @@ impl StructType { ctx: &mut Context, name: &str, fields: Option>, - ) -> Result, CompilerError> { + ) -> Result> { let self_ptr = Type::register_instance( StructType { name: Some(name.to_string()), @@ -95,9 +97,7 @@ impl StructType { self_ref.fields = fields; self_ref.finalized = true; } else if self_ref.fields != fields { - return Err(CompilerError::BadInput { - msg: format!("Struct {name} already exists and is different"), - }); + return input_err!(StructErr::ExistingMismatch(name.into())); } } Ok(self_ptr) @@ -150,12 +150,20 @@ impl StructType { } } +#[derive(Debug, Error)] +pub enum StructErr { + #[error("struct {0} is not finalized")] + NotFinalized(String), + #[error("struct {0} already exists and is different")] + ExistingMismatch(String), +} + impl Verify for StructType { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { if !self.finalized { - return Err(CompilerError::VerificationError { - msg: "Struct not finalized".to_string(), - }); + return verify_err!(StructErr::NotFinalized( + self.name.clone().unwrap_or("".into()) + )); } Ok(()) } @@ -344,8 +352,8 @@ impl Parsable for PointerType { } impl Verify for PointerType { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { - todo!() + fn verify(&self, _ctx: &Context) -> Result<()> { + Ok(()) } } @@ -365,16 +373,16 @@ mod tests { dialects::{ self, builtin::types::{IntegerType, Signedness}, - llvm::types::{PointerType, StructField, StructType}, + llvm::types::{PointerType, StructErr, StructField, StructType}, }, - error::CompilerError, + error::{Error, ErrorKind, Result}, parsable::{self, state_stream_from_iterator}, printable::Printable, r#type::{type_parser, Type}, }; #[test] - fn test_struct() -> Result<(), CompilerError> { + fn test_struct() -> Result<()> { let mut ctx = Context::new(); let int64_ptr = IntegerType::get(&mut ctx, 64, Signedness::Signless); @@ -532,8 +540,8 @@ mod tests { let expected_err_msg = expect![[r#" Parse error at line: 1, column: 15 - Compilation failed. - Struct My1 already exists and is different + Compilation error: invalid input. + struct My1 already exists and is different "#]]; expected_err_msg.assert_eq(&err_msg); @@ -542,9 +550,10 @@ mod tests { parsable::State { ctx: &mut ctx }, ); let res = type_parser().parse(state_stream).unwrap().0; - let expected_err_msg = expect![[r#" - Internal compiler error. Verification failed. - Struct not finalized"#]]; - expected_err_msg.assert_eq(&res.verify(&ctx).unwrap_err().to_string()) + matches!( + &res.verify(&ctx), + Err (Error { kind: ErrorKind::VerificationFailed, err }) + if err.is::() + ); } } diff --git a/src/error.rs b/src/error.rs index 4f6f018..c8bc087 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,23 +1,109 @@ //! Utilities for error handling -use std::fmt::Display; +use thiserror::Error; /// The kinds of errors we have during compilation. -#[derive(Debug)] -pub enum CompilerError { - BadInput { msg: String }, - VerificationError { msg: String }, +#[derive(Debug, Error)] +pub enum ErrorKind { + #[error("invalid input")] + InvalidInput, + #[error("verification failed")] + VerificationFailed, } -impl Display for CompilerError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CompilerError::BadInput { msg } => { - write!(f, "Compilation failed.\n{}", msg) - } - CompilerError::VerificationError { msg } => { - write!(f, "Internal compiler error. Verification failed.\n{}", msg) - } - } +/// An error object that can hold any [std::error::Error]. +#[derive(Debug, Error)] +#[error("Compilation error: {kind}.\n{err}")] +pub struct Error { + pub kind: ErrorKind, + pub err: Box, +} + +/// Type alias for [std::result::Result] with the error type set to [struct@Error] +pub type Result = std::result::Result; + +#[doc(hidden)] +#[derive(Debug, Error)] +#[error("{0}")] +pub struct StringError(pub String); + +/// Specify [ErrorKind] and create an error from any [std::error::Error] object. +/// The macro also accepts [format!] like arguments to create one-off errors. +/// It may be shorter to just use [verify_err!](crate::verify_err) or +/// [input_err!](crate::verify_err) instead. +#[macro_export] +macro_rules! create_err { + ($kind: expr, $str: literal $($t:tt)*) => { + $crate::create_err!($kind, $crate::error::StringError(format!($str $($t)*))) + }; + ($kind: expr, $err: expr) => { + Err($crate::error::Error { + kind: $kind, + err: Box::new($err), + }) + }; +} + +/// Create an [ErrorKind::VerificationFailed] error from any [std::error::Error] object. +/// The macro also accepts [format!] like arguments to create one-off errors. +/// ```rust +/// use thiserror::Error; +/// use pliron::{verify_err, error::{Result, ErrorKind, Error}}; +/// +/// #[derive(Error, Debug)] +/// #[error("sample error")] +/// pub struct SampleErr; +/// +/// assert!( +/// matches!( +/// verify_err!(SampleErr), +/// Result::<()>::Err(Error { +/// kind: ErrorKind::VerificationFailed, +/// err +/// }) if err.is::() +/// )); +/// +/// let res_msg: Result<()> = verify_err!("Some formatted {}", 0); +/// assert_eq!( +/// res_msg.unwrap_err().err.to_string(), +/// "Some formatted 0" +/// ); +/// ``` +#[macro_export] +macro_rules! verify_err { + ($($t:tt)*) => { + $crate::create_err!($crate::error::ErrorKind::VerificationFailed, $($t)*) + } +} + +/// Create an [ErrorKind::InvalidInput] error from any [std::error::Error] object. +/// The macro also accepts [format!] like arguments to create one-off errors. +/// ```rust +/// use thiserror::Error; +/// use pliron::{input_err, error::{Result, ErrorKind, Error}}; +/// +/// #[derive(Error, Debug)] +/// #[error("sample error")] +/// pub struct SampleErr; +/// +/// assert!( +/// matches!( +/// input_err!(SampleErr), +/// Result::<()>::Err(Error { +/// kind: ErrorKind::InvalidInput, +/// err +/// }) if err.is::() +/// )); +/// +/// let res_msg: Result<()> = input_err!("Some formatted {}", 0); +/// assert_eq!( +/// res_msg.unwrap_err().err.to_string(), +/// "Some formatted 0" +/// ); +/// ``` +#[macro_export] +macro_rules! input_err { + ($($t:tt)*) => { + $crate::create_err!($crate::error::ErrorKind::InvalidInput, $($t)*) } } diff --git a/src/op.rs b/src/op.rs index d0dae14..07129a8 100644 --- a/src/op.rs +++ b/src/op.rs @@ -22,7 +22,7 @@ //! [OpObj]s can be downcasted to their concrete types using //! [downcast_rs](https://docs.rs/downcast-rs/1.2.0/downcast_rs/index.html#example-without-generics). -use std::ops::Deref; +use std::{fmt::Display, ops::Deref}; use downcast_rs::{impl_downcast, Downcast}; use intertrait::{cast::CastRef, CastFrom}; @@ -31,7 +31,7 @@ use crate::{ common_traits::Verify, context::{Context, Ptr}, dialect::{Dialect, DialectName}, - error::CompilerError, + error::Result, operation::Operation, printable::{self, Printable}, }; @@ -62,6 +62,12 @@ impl Printable for OpName { _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { + ::fmt(self, f) + } +} + +impl Display for OpName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } @@ -76,11 +82,17 @@ pub struct OpId { impl Printable for OpId { fn fmt( &self, - ctx: &Context, + _ctx: &Context, _state: &printable::State, f: &mut core::fmt::Formatter<'_>, ) -> core::fmt::Result { - write!(f, "{}.{}", self.dialect.disp(ctx), self.name.disp(ctx)) + ::fmt(self, f) + } +} + +impl Display for OpId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.dialect, self.name) } } @@ -105,7 +117,7 @@ pub trait Op: Downcast + Verify + Printable + CastFrom { Self: Sized; /// Verify all interfaces implemented by this op. - fn verify_interfaces(&self, ctx: &Context) -> Result<(), CompilerError>; + fn verify_interfaces(&self, ctx: &Context) -> Result<()>; /// Register Op in Context and add it to dialect. fn register(ctx: &mut Context, dialect: &mut Dialect) @@ -145,7 +157,7 @@ pub fn op_impls(op: &dyn Op) -> bool { } /// Every op interface must have a function named `verify` with this type. -pub type OpInterfaceVerifier = fn(&dyn Op, &Context) -> Result<(), CompilerError>; +pub type OpInterfaceVerifier = fn(&dyn Op, &Context) -> Result<()>; /// Declare an [Op] /// @@ -159,7 +171,7 @@ pub type OpInterfaceVerifier = fn(&dyn Op, &Context) -> Result<(), CompilerError /// ); /// # use pliron::{ /// # op::Op, declare_op, printable::{self, Printable}, -/// # context::Context, error::CompilerError, common_traits::Verify +/// # context::Context, error::Result, common_traits::Verify /// # }; /// # impl Printable for MyOp { /// # fn fmt(&self, _ctx: &Context, _state: &printable::State, _f: &mut core::fmt::Formatter<'_>) @@ -170,7 +182,7 @@ pub type OpInterfaceVerifier = fn(&dyn Op, &Context) -> Result<(), CompilerError /// # } /// /// # impl Verify for MyOp { -/// # fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { +/// # fn verify(&self, _ctx: &Context) -> Result<()> { /// # todo!() /// # } /// # } @@ -216,7 +228,7 @@ macro_rules! declare_op { } } - fn verify_interfaces(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify_interfaces(&self, ctx: &Context) -> $crate::error::Result<()> { let interface_verifiers = paste::paste!{ inventory::iter::<[]> }; @@ -242,7 +254,7 @@ macro_rules! declare_op { /// ); /// trait MyOpInterface: Op { /// fn gubbi(&self); -/// fn verify(op: &dyn Op, ctx: &Context) -> Result<(), CompilerError> +/// fn verify(op: &dyn Op, ctx: &Context) -> Result<()> /// where /// Self: Sized, /// { @@ -257,7 +269,7 @@ macro_rules! declare_op { /// ); /// # use pliron::{ /// # op::Op, declare_op, impl_op_interface, -/// # printable::{self, Printable}, context::Context, error::CompilerError, +/// # printable::{self, Printable}, context::Context, error::Result, /// # common_traits::Verify /// # }; /// # impl Printable for MyOp { @@ -267,7 +279,7 @@ macro_rules! declare_op { /// # } /// /// # impl Verify for MyOp { -/// # fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { +/// # fn verify(&self, _ctx: &Context) -> Result<()> { /// # todo!() /// # } /// # } diff --git a/src/operation.rs b/src/operation.rs index 031f9f9..008a1c1 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -5,6 +5,7 @@ use std::marker::PhantomData; use rustc_hash::FxHashMap; +use thiserror::Error; use crate::{ attribute::AttrObj, @@ -12,7 +13,7 @@ use crate::{ common_traits::{Named, Verify}, context::{private::ArenaObj, ArenaCell, Context, Ptr}, debug_info, - error::CompilerError, + error::Result, linked_list::{private, LinkedList}, op::{self, OpId, OpObj}, printable::{self, Printable}, @@ -20,6 +21,7 @@ use crate::{ region::Region, use_def_lists::{DefNode, DefTrait, DefUseParticipant, Use, UseNode, Value}, vec_exns::VecExtns, + verify_err, }; /// Represents the result of an [Operation]. @@ -417,17 +419,19 @@ impl Printable for Operand { } } +#[derive(Error, Debug)] +#[error("operand is not a use of its def")] +pub struct DefUseVerifyErr; + impl Verify for Operand { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { if !self .r#use .get_def() .get_defnode_ref(ctx) .has_use_of(&self.into()) { - Err(CompilerError::VerificationError { - msg: "Operand is not a use of its def".to_string(), - }) + verify_err!(DefUseVerifyErr) } else { Ok(()) } @@ -435,7 +439,7 @@ impl Verify for Operand { } impl Verify for Operation { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { for _attr in self.attributes.values() { // TODO. // attr.verify(ctx)?; diff --git a/src/parsable.rs b/src/parsable.rs index 827339f..cd582d6 100644 --- a/src/parsable.rs +++ b/src/parsable.rs @@ -1,6 +1,6 @@ //! IR objects that can be parsed from their text representation. -use crate::{context::Context, error::CompilerError}; +use crate::{context::Context, error::Result}; use combine::{ easy, parser::char::spaces, @@ -143,10 +143,10 @@ pub fn spaced, Output>( combine::between(spaces(), spaces(), parser) } -/// Convert `Result<_, CompilerError>` into [ParseResult], +/// Convert [Result] into [ParseResult], /// Helps in returning errors when writing a parser. pub fn to_parse_result<'a, T>( - result: Result, + result: Result, position: SourcePosition, ) -> ParseResult>> { match result { diff --git a/src/region.rs b/src/region.rs index 9a11687..374b4c1 100644 --- a/src/region.rs +++ b/src/region.rs @@ -3,7 +3,7 @@ use crate::{ basic_block::BasicBlock, common_traits::Verify, context::{private::ArenaObj, Context, Ptr}, - error::CompilerError, + error::Result, indented_block, linked_list::{private, ContainsLinkedList}, operation::Operation, @@ -101,7 +101,7 @@ impl ArenaObj for Region { } impl Verify for Region { - fn verify(&self, ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { self.iter(ctx).try_for_each(|op| op.deref(ctx).verify(ctx)) } } diff --git a/src/type.rs b/src/type.rs index b0101c1..a79d0c7 100644 --- a/src/type.rs +++ b/src/type.rs @@ -11,7 +11,8 @@ use crate::common_traits::Verify; use crate::context::{private::ArenaObj, ArenaCell, Context, Ptr}; use crate::dialect::{Dialect, DialectName}; -use crate::error::CompilerError; +use crate::error::Result; +use crate::input_err; use crate::parsable::{identifier, spaced, to_parse_result, Parsable, ParserFn, StateStream}; use crate::printable::{self, Printable}; use crate::storage_uniquer::TypeValueHash; @@ -246,7 +247,7 @@ impl Printable for TypeObj { } impl Verify for TypeObj { - fn verify(&self, ctx: &Context) -> Result<(), crate::error::CompilerError> { + fn verify(&self, ctx: &Context) -> Result<()> { self.as_ref().verify(ctx) } } @@ -265,7 +266,7 @@ impl Verify for TypeObj { /// ); /// # use pliron::{ /// # impl_type, printable::{self, Printable}, -/// # context::Context, error::CompilerError, common_traits::Verify, +/// # context::Context, error::Result, common_traits::Verify, /// # storage_uniquer::TypeValueHash, r#type::Type, /// # }; /// # impl Printable for MyType { @@ -280,7 +281,7 @@ impl Verify for TypeObj { /// # } /// # /// # impl Verify for MyType { -/// # fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { +/// # fn verify(&self, _ctx: &Context) -> Result<()> { /// # todo!() /// # } /// # } @@ -334,9 +335,7 @@ pub fn type_parse<'a>( .expect("Dialect name parsed but dialect isn't registered"); let Some(type_parser) = dialect.types.get(&type_id) else { return to_parse_result( - Err(CompilerError::BadInput { - msg: format!("Unregistered type {}.", type_id.disp(state.ctx)), - }), + input_err!("Unregistered type {}", type_id.disp(state.ctx)), position, ) .into_result(); @@ -381,8 +380,8 @@ mod test { let expected_err_msg = expect![[r#" Parse error at line: 1, column: 1 - Compilation failed. - Unregistered type builtin.some. + Compilation error: invalid input. + Unregistered type builtin.some "#]]; expected_err_msg.assert_eq(&err_msg); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 2814798..dc2fb7d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,7 +14,7 @@ use pliron::{ }, llvm::ops::ReturnOp, }, - error::CompilerError, + error::Result, op::Op, printable::Printable, }; @@ -28,9 +28,7 @@ pub fn setup_context_dialects() -> Context { // Create a print a module "bar", with a function "foo" // containing a single `return 0`. -pub fn const_ret_in_mod( - ctx: &mut Context, -) -> Result<(ModuleOp, FuncOp, ConstantOp, ReturnOp), CompilerError> { +pub fn const_ret_in_mod(ctx: &mut Context) -> Result<(ModuleOp, FuncOp, ConstantOp, ReturnOp)> { let i64_ty = IntegerType::get(ctx, 64, Signedness::Signed); let module = ModuleOp::new(ctx, "bar"); // Our function is going to have type () -> (). diff --git a/tests/interfaces.rs b/tests/interfaces.rs index 0569614..21ed1db 100644 --- a/tests/interfaces.rs +++ b/tests/interfaces.rs @@ -8,12 +8,13 @@ use pliron::{ dialect::{Dialect, DialectName}, dialects::{ builtin::{ - attr_interfaces::TypedAttrInterface, attributes::StringAttr, - op_interfaces::OneResultInterface, + attr_interfaces::TypedAttrInterface, + attributes::StringAttr, + op_interfaces::{OneResultInterface, OneResultVerifyErr}, }, llvm::ops::ReturnOp, }, - error::CompilerError, + error::{Error, ErrorKind, Result}, impl_attr, impl_attr_interface, impl_op_interface, op::Op, operation::Operation, @@ -39,7 +40,7 @@ impl Printable for ZeroResultOp { } impl Verify for ZeroResultOp { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } @@ -67,12 +68,16 @@ fn check_intrf_verfiy_errs() { assert!(matches!( module_op.get_operation().verify(ctx), - Err(CompilerError::VerificationError { msg }) - if msg == "Expected exactly one result on operation test.zero_results")); + Err(Error { + kind: ErrorKind::VerificationFailed, + err + }) + if err.is::() + )) } pub trait TestOpInterface: Op { - fn verify(_op: &dyn Op, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -84,7 +89,7 @@ impl_op_interface!(TestOpInterface for ReturnOp {}); impl_op_interface!(TestOpInterface for pliron::dialects::builtin::ops::ModuleOp {}); pub trait TestAttrInterface: Attribute { - fn verify(_op: &dyn Attribute, _ctx: &Context) -> Result<(), CompilerError> + fn verify(_op: &dyn Attribute, _ctx: &Context) -> Result<()> where Self: Sized, { @@ -100,7 +105,7 @@ struct MyAttr { ty: Ptr, } impl Verify for MyAttr { - fn verify(&self, _ctx: &Context) -> Result<(), CompilerError> { + fn verify(&self, _ctx: &Context) -> Result<()> { Ok(()) } } diff --git a/tests/ir_construct.rs b/tests/ir_construct.rs index 00a9391..ad8760e 100644 --- a/tests/ir_construct.rs +++ b/tests/ir_construct.rs @@ -6,7 +6,7 @@ use pliron::{ dialects::builtin::{ attributes::IntegerAttr, op_interfaces::OneResultInterface, ops::ConstantOp, }, - error::CompilerError, + error::Result, op::Op, operation::Operation, printable::Printable, @@ -18,7 +18,7 @@ mod common; // Test erasing the entire top module. #[test] -fn construct_and_erase() -> Result<(), CompilerError> { +fn construct_and_erase() -> Result<()> { let ctx = &mut setup_context_dialects(); let module_op = const_ret_in_mod(ctx)?.0.get_operation(); Operation::erase(module_op, ctx); @@ -41,7 +41,7 @@ fn removed_used_op() { // Testing replacing all uses of c0 with c1. #[test] -fn replace_c0_with_c1() -> Result<(), CompilerError> { +fn replace_c0_with_c1() -> Result<()> { let ctx = &mut setup_context_dialects(); // const_ret_in_mod builds a module with a function. @@ -67,7 +67,7 @@ fn replace_c0_with_c1() -> Result<(), CompilerError> { // Replace ret_op's first operand (which is c0) with c1. // Erase c0. Verify. #[test] -fn replace_c0_with_c1_operand() -> Result<(), CompilerError> { +fn replace_c0_with_c1_operand() -> Result<()> { let ctx = &mut setup_context_dialects(); // const_ret_in_mod builds a module with a function.