Skip to content

Commit

Permalink
Return Result in accessors instead of panic (#286)
Browse files Browse the repository at this point in the history
As previously suggested in #274, I replaced the `.expect` with returning
a `Result` in the operation accessors of ODS generated dialects.

I also added some variants to the `Error` enum to support this. The
`Infallible` variant was added to allow using `TryInto` to convert
`Attribute` into itself, such that we avoid needing to handle
`Attribute` differently from `StringAttribute`, `IntAttribute`, etc. But
if you prefer, I can change the macro code to deal with this case
instead, I'm honestly not a big fan of my `Infallible` hack either.

---------

Co-authored-by: Yota Toyama <[email protected]>
  • Loading branch information
Danacus and raviqqe authored Aug 18, 2023
1 parent 555ba08 commit a5f68f8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 55 deletions.
24 changes: 16 additions & 8 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ impl<'a> OperationField<'a> {
let (param_type, return_type) = {
if ac.is_unit() {
(quote! { bool }, quote! { bool })
} else if ac.is_optional() {
(quote! { #kind_type<'c> }, quote! { Option<#kind_type<'c>> })
} else {
(quote! { #kind_type<'c> }, quote! { #kind_type<'c> })
(
quote! { #kind_type<'c> },
quote! { Result<#kind_type<'c>, ::melior::Error> },
)
}
};
let sanitized = sanitize_name_snake(name);
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<'a> OperationField<'a> {
} else {
(
quote! { ::melior::ir::Region<'c> },
quote! { ::melior::ir::RegionRef<'c, '_> },
quote! { Result<::melior::ir::RegionRef<'c, '_>, ::melior::Error> },
)
}
};
Expand Down Expand Up @@ -142,7 +143,7 @@ impl<'a> OperationField<'a> {
} else {
(
quote! { &::melior::ir::Block<'c> },
quote! { ::melior::ir::BlockRef<'c, '_> },
quote! { Result<::melior::ir::BlockRef<'c, '_>, ::melior::Error> },
)
}
};
Expand Down Expand Up @@ -201,16 +202,23 @@ impl<'a> OperationField<'a> {
if tc.is_optional() {
(
quote! { #param_kind_type },
quote! { Option<#return_kind_type> },
quote! { Result<#return_kind_type, ::melior::Error> },
)
} else {
(
quote! { &[#param_kind_type] },
quote! { impl Iterator<Item = #return_kind_type> },
if let VariadicKind::AttrSized {} = variadic_info {
quote! { Result<impl Iterator<Item = #return_kind_type>, ::melior::Error> }
} else {
quote! { impl Iterator<Item = #return_kind_type> }
},
)
}
} else {
(param_kind_type, return_kind_type)
(
param_kind_type,
quote!(Result<#return_kind_type, ::melior::Error>),
)
}
};

Expand Down
65 changes: 27 additions & 38 deletions macro/src/dialect/operation/accessors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ impl<'a> OperationField<'a> {
.variadic_info
.as_ref()
.expect("operands and results need variadic info");
let error_variant = match &self.kind {
FieldKind::Operand(_) => quote!(OperandNotFound),
FieldKind::Result(_) => quote!(ResultNotFound),
_ => unreachable!(),
};
let name = self.name;

Some(match variadic_kind {
VariadicKind::Simple {
Expand All @@ -32,9 +38,9 @@ impl<'a> OperationField<'a> {
// elements.
quote! {
if self.operation.#count() < #len {
None
Err(::melior::Error::#error_variant(#name))
} else {
self.operation.#kind_ident(#index).ok()
self.operation.#kind_ident(#index)
}
}
} else {
Expand All @@ -49,16 +55,14 @@ impl<'a> OperationField<'a> {
} else if *seen_variable_length {
// Single element after variable length group
// Compute the length of that variable group and take the next element
let error = format!("operation should have this {}", kind);
quote! {
let group_length = self.operation.#count() - #len + 1;
self.operation.#kind_ident(#index + group_length - 1).expect(#error)
self.operation.#kind_ident(#index + group_length - 1)
}
} else {
// All elements so far are singular
let error = format!("operation should have this {}", kind);
quote! {
self.operation.#kind_ident(#index).expect(#error)
self.operation.#kind_ident(#index)
}
}
}
Expand All @@ -67,7 +71,6 @@ impl<'a> OperationField<'a> {
num_preceding_simple,
num_preceding_variadic,
} => {
let error = format!("operation should have this {}", kind);
let compute_start_length = quote! {
let total_var_len = self.operation.#count() - #num_variable_length + 1;
let group_len = total_var_len / #num_variable_length;
Expand All @@ -79,47 +82,42 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.#kind_ident(start).expect(#error)
self.operation.#kind_ident(start)
}
};

quote! { #compute_start_length #get_elements }
}
VariadicKind::AttrSized {} => {
let error = format!("operation should have this {}", kind);
let attribute_name = format!("{}_segment_sizes", kind);
let attribute_missing_error =
format!("operation has {} attribute", attribute_name);
let compute_start_length = quote! {
let attribute =
::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from(
self.operation
.attribute(#attribute_name)
.expect(#attribute_missing_error)
).expect("is a DenseI32ArrayAttribute");
.attribute(#attribute_name)?
)?;
let start = (0..#index)
.map(|index| attribute.element(index)
.expect("has segment size"))
.map(|index| attribute.element(index))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.sum::<i32>() as usize;
let group_len = attribute
.element(#index)
.expect("has segment size") as usize;
let group_len = attribute.element(#index)? as usize;
};
let get_elements = if !constraint.is_variable_length() {
quote! {
self.operation.#kind_ident(start).expect(#error)
self.operation.#kind_ident(start)
}
} else if constraint.is_optional() {
quote! {
if group_len == 0 {
None
Err(::melior::Error::#error_variant(#name))
} else {
self.operation.#kind_ident(start).ok()
self.operation.#kind_ident(start)
}
}
} else {
quote! {
self.operation.#plural().skip(start).take(group_len)
Ok(self.operation.#plural().skip(start).take(group_len))
}
};

Expand All @@ -140,7 +138,7 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.successor(#index).expect("operation should have this successor")
self.operation.successor(#index)
}
})
}
Expand All @@ -155,30 +153,21 @@ impl<'a> OperationField<'a> {
}
} else {
quote! {
self.operation.region(#index).expect("operation should have this region")
self.operation.region(#index)
}
})
}
FieldKind::Attribute(constraint) => {
let name = &self.name;
let attribute_error = format!("operation should have attribute {}", name);
let type_error = format!("{} should be a {}", name, constraint.storage_type());

Some(if constraint.is_unit() {
quote! { self.operation.attribute(#name).is_some() }
} else if constraint.is_optional() {
quote! {
self.operation
.attribute(#name)
.map(|attribute| attribute.try_into().expect(#type_error))
}
} else {
quote! {
self.operation
.attribute(#name)
.expect(#attribute_error)
.attribute(#name)?
.try_into()
.expect(#type_error)
.map_err(::melior::Error::from)
}
})
}
Expand All @@ -192,7 +181,7 @@ impl<'a> OperationField<'a> {

if constraint.is_unit() || constraint.is_optional() {
Some(quote! {
let _ = self.operation.remove_attribute(#name);
self.operation.remove_attribute(#name)
})
} else {
None
Expand Down Expand Up @@ -239,7 +228,7 @@ impl<'a> OperationField<'a> {
let ident = sanitize_name_snake(&format!("remove_{}", self.name));
self.remover_impl().map_or(quote!(), |body| {
quote! {
pub fn #ident(&mut self) {
pub fn #ident(&mut self) -> Result<(), ::melior::Error> {
#body
}
}
Expand Down
6 changes: 3 additions & 3 deletions macro/tests/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ fn simple() {
location,
);

assert_eq!(op.lhs(), block.argument(0).unwrap().into());
assert_eq!(op.rhs(), block.argument(1).unwrap().into());
assert_eq!(op.lhs().unwrap(), block.argument(0).unwrap().into());
assert_eq!(op.rhs().unwrap(), block.argument(1).unwrap().into());
assert_eq!(op.operation().operand_count(), 2);
}

Expand All @@ -48,7 +48,7 @@ fn variadic_after_single() {
location,
);

assert_eq!(op.first(), block.argument(0).unwrap().into());
assert_eq!(op.first().unwrap(), block.argument(0).unwrap().into());
assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into()));
assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into()));
assert_eq!(op.operation().operand_count(), 3);
Expand Down
4 changes: 2 additions & 2 deletions macro/tests/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn single() {
region_test::single(r1, location)
};

assert!(op.default_region().first_block().is_some());
assert!(op.default_region().unwrap().first_block().is_some());
}

#[test]
Expand Down Expand Up @@ -51,7 +51,7 @@ fn variadic_after_single() {

assert_eq!(op.operation().to_string(), op2.operation().to_string());

assert!(op.default_region().first_block().is_none());
assert!(op.default_region().unwrap().first_block().is_none());
assert_eq!(op.other_regions().count(), 2);
assert!(op.other_regions().next().unwrap().first_block().is_some());
assert!(op.other_regions().nth(1).unwrap().first_block().is_none());
Expand Down
15 changes: 15 additions & 0 deletions melior/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
convert::Infallible,
error,
fmt::{self, Display, Formatter},
str::Utf8Error,
Expand All @@ -15,13 +16,15 @@ pub enum Error {
value: String,
},
InvokeFunction,
OperandNotFound(&'static str),
OperationResultExpected(String),
PositionOutOfBounds {
name: &'static str,
value: String,
index: usize,
},
ParsePassPipeline(String),
ResultNotFound(&'static str),
RunPass,
TypeExpected(&'static str, String),
UnknownDiagnosticSeverity(u32),
Expand All @@ -44,6 +47,9 @@ impl Display for Error {
write!(formatter, "element of {type} type expected: {value}")
}
Self::InvokeFunction => write!(formatter, "failed to invoke JIT-compiled function"),
Self::OperandNotFound(name) => {
write!(formatter, "operand {name} not found")
}
Self::OperationResultExpected(value) => {
write!(formatter, "operation result expected: {value}")
}
Expand All @@ -53,6 +59,9 @@ impl Display for Error {
Self::PositionOutOfBounds { name, value, index } => {
write!(formatter, "{name} position {index} out of bounds: {value}")
}
Self::ResultNotFound(name) => {
write!(formatter, "result {name} not found")
}
Self::RunPass => write!(formatter, "failed to run pass"),
Self::TypeExpected(r#type, actual) => {
write!(formatter, "{type} type expected: {actual}")
Expand All @@ -74,3 +83,9 @@ impl From<Utf8Error> for Error {
Self::Utf8(error)
}
}

impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
unreachable!()
}
}
9 changes: 5 additions & 4 deletions melior/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,19 @@ impl<'c> Operation<'c> {
}

/// Gets a attribute with the given name.
pub fn attribute(&self, name: &str) -> Option<Attribute<'c>> {
pub fn attribute(&self, name: &str) -> Result<Attribute<'c>, Error> {
unsafe {
Attribute::from_option_raw(mlirOperationGetAttributeByName(
self.raw,
StringRef::from(name).to_raw(),
))
}
.ok_or(Error::AttributeNotFound(name.into()))
}

/// Checks if the operation has a attribute with the given name.
pub fn has_attribute(&self, name: &str) -> bool {
self.attribute(name).is_some()
self.attribute(name).is_ok()
}

/// Sets the attribute with the given name to the given attribute.
Expand Down Expand Up @@ -547,14 +548,14 @@ mod tests {
assert!(operation.has_attribute("foo"));
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"bar\"".into())
Ok("\"bar\"".into())
);
assert!(operation.remove_attribute("foo").is_ok());
assert!(operation.remove_attribute("foo").is_err());
operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into());
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"foo\"".into())
Ok("\"foo\"".into())
);
assert_eq!(
operation.attributes().next(),
Expand Down

0 comments on commit a5f68f8

Please sign in to comment.