diff --git a/macro/src/dialect/operation.rs b/macro/src/dialect/operation.rs index 0987db02ba..0cd32865d0 100644 --- a/macro/src/dialect/operation.rs +++ b/macro/src/dialect/operation.rs @@ -76,10 +76,11 @@ impl<'a> OperationField<'a> { let (param_type, return_type) = { if ac.is_unit() { (quote! { bool }, quote! { bool }) - } else if ac.is_optional() { - (quote! { #kind_type<'c> }, quote! { Option<#kind_type<'c>> }) } else { - (quote! { #kind_type<'c> }, quote! { #kind_type<'c> }) + ( + quote! { #kind_type<'c> }, + quote! { Result<#kind_type<'c>, ::melior::Error> }, + ) } }; let sanitized = sanitize_name_snake(name); @@ -108,7 +109,7 @@ impl<'a> OperationField<'a> { } else { ( quote! { ::melior::ir::Region<'c> }, - quote! { ::melior::ir::RegionRef<'c, '_> }, + quote! { Result<::melior::ir::RegionRef<'c, '_>, ::melior::Error> }, ) } }; @@ -142,7 +143,7 @@ impl<'a> OperationField<'a> { } else { ( quote! { &::melior::ir::Block<'c> }, - quote! { ::melior::ir::BlockRef<'c, '_> }, + quote! { Result<::melior::ir::BlockRef<'c, '_>, ::melior::Error> }, ) } }; @@ -201,16 +202,23 @@ impl<'a> OperationField<'a> { if tc.is_optional() { ( quote! { #param_kind_type }, - quote! { Option<#return_kind_type> }, + quote! { Result<#return_kind_type, ::melior::Error> }, ) } else { ( quote! { &[#param_kind_type] }, - quote! { impl Iterator }, + if let VariadicKind::AttrSized {} = variadic_info { + quote! { Result, ::melior::Error> } + } else { + quote! { impl Iterator } + }, ) } } else { - (param_kind_type, return_kind_type) + ( + param_kind_type, + quote!(Result<#return_kind_type, ::melior::Error>), + ) } }; diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs index 74e9f93389..de6525a040 100644 --- a/macro/src/dialect/operation/accessors.rs +++ b/macro/src/dialect/operation/accessors.rs @@ -19,6 +19,12 @@ impl<'a> OperationField<'a> { .variadic_info .as_ref() .expect("operands and results need variadic info"); + let error_variant = match &self.kind { + FieldKind::Operand(_) => quote!(OperandNotFound), + FieldKind::Result(_) => quote!(ResultNotFound), + _ => unreachable!(), + }; + let name = self.name; Some(match variadic_kind { VariadicKind::Simple { @@ -32,9 +38,9 @@ impl<'a> OperationField<'a> { // elements. quote! { if self.operation.#count() < #len { - None + Err(::melior::Error::#error_variant(#name)) } else { - self.operation.#kind_ident(#index).ok() + self.operation.#kind_ident(#index) } } } else { @@ -49,16 +55,14 @@ impl<'a> OperationField<'a> { } else if *seen_variable_length { // Single element after variable length group // Compute the length of that variable group and take the next element - let error = format!("operation should have this {}", kind); quote! { let group_length = self.operation.#count() - #len + 1; - self.operation.#kind_ident(#index + group_length - 1).expect(#error) + self.operation.#kind_ident(#index + group_length - 1) } } else { // All elements so far are singular - let error = format!("operation should have this {}", kind); quote! { - self.operation.#kind_ident(#index).expect(#error) + self.operation.#kind_ident(#index) } } } @@ -67,7 +71,6 @@ impl<'a> OperationField<'a> { num_preceding_simple, num_preceding_variadic, } => { - let error = format!("operation should have this {}", kind); let compute_start_length = quote! { let total_var_len = self.operation.#count() - #num_variable_length + 1; let group_len = total_var_len / #num_variable_length; @@ -79,47 +82,42 @@ impl<'a> OperationField<'a> { } } else { quote! { - self.operation.#kind_ident(start).expect(#error) + self.operation.#kind_ident(start) } }; quote! { #compute_start_length #get_elements } } VariadicKind::AttrSized {} => { - let error = format!("operation should have this {}", kind); let attribute_name = format!("{}_segment_sizes", kind); - let attribute_missing_error = - format!("operation has {} attribute", attribute_name); let compute_start_length = quote! { let attribute = ::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from( self.operation - .attribute(#attribute_name) - .expect(#attribute_missing_error) - ).expect("is a DenseI32ArrayAttribute"); + .attribute(#attribute_name)? + )?; let start = (0..#index) - .map(|index| attribute.element(index) - .expect("has segment size")) + .map(|index| attribute.element(index)) + .collect::, _>>()? + .into_iter() .sum::() as usize; - let group_len = attribute - .element(#index) - .expect("has segment size") as usize; + let group_len = attribute.element(#index)? as usize; }; let get_elements = if !constraint.is_variable_length() { quote! { - self.operation.#kind_ident(start).expect(#error) + self.operation.#kind_ident(start) } } else if constraint.is_optional() { quote! { if group_len == 0 { - None + Err(::melior::Error::#error_variant(#name)) } else { - self.operation.#kind_ident(start).ok() + self.operation.#kind_ident(start) } } } else { quote! { - self.operation.#plural().skip(start).take(group_len) + Ok(self.operation.#plural().skip(start).take(group_len)) } }; @@ -140,7 +138,7 @@ impl<'a> OperationField<'a> { } } else { quote! { - self.operation.successor(#index).expect("operation should have this successor") + self.operation.successor(#index) } }) } @@ -155,30 +153,21 @@ impl<'a> OperationField<'a> { } } else { quote! { - self.operation.region(#index).expect("operation should have this region") + self.operation.region(#index) } }) } FieldKind::Attribute(constraint) => { let name = &self.name; - let attribute_error = format!("operation should have attribute {}", name); - let type_error = format!("{} should be a {}", name, constraint.storage_type()); Some(if constraint.is_unit() { quote! { self.operation.attribute(#name).is_some() } - } else if constraint.is_optional() { - quote! { - self.operation - .attribute(#name) - .map(|attribute| attribute.try_into().expect(#type_error)) - } } else { quote! { self.operation - .attribute(#name) - .expect(#attribute_error) + .attribute(#name)? .try_into() - .expect(#type_error) + .map_err(::melior::Error::from) } }) } @@ -192,7 +181,7 @@ impl<'a> OperationField<'a> { if constraint.is_unit() || constraint.is_optional() { Some(quote! { - let _ = self.operation.remove_attribute(#name); + self.operation.remove_attribute(#name) }) } else { None @@ -239,7 +228,7 @@ impl<'a> OperationField<'a> { let ident = sanitize_name_snake(&format!("remove_{}", self.name)); self.remover_impl().map_or(quote!(), |body| { quote! { - pub fn #ident(&mut self) { + pub fn #ident(&mut self) -> Result<(), ::melior::Error> { #body } } diff --git a/macro/tests/operand.rs b/macro/tests/operand.rs index 55babcf1aa..27991a61fc 100644 --- a/macro/tests/operand.rs +++ b/macro/tests/operand.rs @@ -24,8 +24,8 @@ fn simple() { location, ); - assert_eq!(op.lhs(), block.argument(0).unwrap().into()); - assert_eq!(op.rhs(), block.argument(1).unwrap().into()); + assert_eq!(op.lhs().unwrap(), block.argument(0).unwrap().into()); + assert_eq!(op.rhs().unwrap(), block.argument(1).unwrap().into()); assert_eq!(op.operation().operand_count(), 2); } @@ -48,7 +48,7 @@ fn variadic_after_single() { location, ); - assert_eq!(op.first(), block.argument(0).unwrap().into()); + assert_eq!(op.first().unwrap(), block.argument(0).unwrap().into()); assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into())); assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into())); assert_eq!(op.operation().operand_count(), 3); diff --git a/macro/tests/region.rs b/macro/tests/region.rs index 626b65268c..a2bfd9b090 100644 --- a/macro/tests/region.rs +++ b/macro/tests/region.rs @@ -22,7 +22,7 @@ fn single() { region_test::single(r1, location) }; - assert!(op.default_region().first_block().is_some()); + assert!(op.default_region().unwrap().first_block().is_some()); } #[test] @@ -51,7 +51,7 @@ fn variadic_after_single() { assert_eq!(op.operation().to_string(), op2.operation().to_string()); - assert!(op.default_region().first_block().is_none()); + assert!(op.default_region().unwrap().first_block().is_none()); assert_eq!(op.other_regions().count(), 2); assert!(op.other_regions().next().unwrap().first_block().is_some()); assert!(op.other_regions().nth(1).unwrap().first_block().is_none()); diff --git a/melior/src/error.rs b/melior/src/error.rs index 7efc475576..543fdf21b8 100644 --- a/melior/src/error.rs +++ b/melior/src/error.rs @@ -1,4 +1,5 @@ use std::{ + convert::Infallible, error, fmt::{self, Display, Formatter}, str::Utf8Error, @@ -15,6 +16,7 @@ pub enum Error { value: String, }, InvokeFunction, + OperandNotFound(&'static str), OperationResultExpected(String), PositionOutOfBounds { name: &'static str, @@ -22,6 +24,7 @@ pub enum Error { index: usize, }, ParsePassPipeline(String), + ResultNotFound(&'static str), RunPass, TypeExpected(&'static str, String), UnknownDiagnosticSeverity(u32), @@ -44,6 +47,9 @@ impl Display for Error { write!(formatter, "element of {type} type expected: {value}") } Self::InvokeFunction => write!(formatter, "failed to invoke JIT-compiled function"), + Self::OperandNotFound(name) => { + write!(formatter, "operand {name} not found") + } Self::OperationResultExpected(value) => { write!(formatter, "operation result expected: {value}") } @@ -53,6 +59,9 @@ impl Display for Error { Self::PositionOutOfBounds { name, value, index } => { write!(formatter, "{name} position {index} out of bounds: {value}") } + Self::ResultNotFound(name) => { + write!(formatter, "result {name} not found") + } Self::RunPass => write!(formatter, "failed to run pass"), Self::TypeExpected(r#type, actual) => { write!(formatter, "{type} type expected: {actual}") @@ -74,3 +83,9 @@ impl From for Error { Self::Utf8(error) } } + +impl From for Error { + fn from(_: Infallible) -> Self { + unreachable!() + } +} diff --git a/melior/src/ir/operation.rs b/melior/src/ir/operation.rs index 0c18c612a4..3d37f6809e 100644 --- a/melior/src/ir/operation.rs +++ b/melior/src/ir/operation.rs @@ -184,18 +184,19 @@ impl<'c> Operation<'c> { } /// Gets a attribute with the given name. - pub fn attribute(&self, name: &str) -> Option> { + pub fn attribute(&self, name: &str) -> Result, Error> { unsafe { Attribute::from_option_raw(mlirOperationGetAttributeByName( self.raw, StringRef::from(name).to_raw(), )) } + .ok_or(Error::AttributeNotFound(name.into())) } /// Checks if the operation has a attribute with the given name. pub fn has_attribute(&self, name: &str) -> bool { - self.attribute(name).is_some() + self.attribute(name).is_ok() } /// Sets the attribute with the given name to the given attribute. @@ -547,14 +548,14 @@ mod tests { assert!(operation.has_attribute("foo")); assert_eq!( operation.attribute("foo").map(|a| a.to_string()), - Some("\"bar\"".into()) + Ok("\"bar\"".into()) ); assert!(operation.remove_attribute("foo").is_ok()); assert!(operation.remove_attribute("foo").is_err()); operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into()); assert_eq!( operation.attribute("foo").map(|a| a.to_string()), - Some("\"foo\"".into()) + Ok("\"foo\"".into()) ); assert_eq!( operation.attributes().next(),