Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed Dec 6, 2023
1 parent 3c227e3 commit 8037059
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 78 deletions.
9 changes: 7 additions & 2 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,13 @@ fn generate_dialect_module(
.map(Operation::new)
.collect::<Result<Vec<_>, _>>()?
.iter()
.filter(|operation| operation.dialect_name() == dialect_name)
.map(|operation| generate_operation(operation))
.map(|operation| {
Ok::<_, Error>(if operation.dialect_name()? == dialect_name {
Some(generate_operation(operation)?)
} else {
None
})
})
.collect::<Result<Vec<_>, _>>()?;

let doc = format!(
Expand Down
114 changes: 63 additions & 51 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,8 @@ use tblgen::{error::WithLocation, record::Record};

#[derive(Clone, Debug)]
pub struct Operation<'a> {
dialect_name: &'a str,
short_name: &'a str,
full_name: String,
class_name: &'a str,
summary: String,
definition: Record<'a>,
can_infer_type: bool,
description: String,
regions: Vec<OperationField<'a>>,
successors: Vec<OperationField<'a>>,
results: Vec<OperationField<'a>>,
Expand All @@ -38,7 +33,6 @@ pub struct Operation<'a> {

impl<'a> Operation<'a> {
pub fn new(definition: Record<'a>) -> Result<Self, Error> {
let dialect = definition.def_value("opDialect")?;
let traits = Self::collect_traits(definition)?;
let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name));

Expand All @@ -50,30 +44,7 @@ impl<'a> Operation<'a> {
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
)?;

let name = definition.name()?;
let class_name = if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
};
let short_name = definition.str_value("opName")?;

Ok(Self {
dialect_name: dialect.name()?,
short_name,
full_name: {
let dialect_name = dialect.string_value("name")?;

if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
}
},
class_name,
successors: Self::collect_successors(definition)?,
operands: Self::collect_operands(
&arguments,
Expand All @@ -89,26 +60,13 @@ impl<'a> Operation<'a> {
&& unfixed_result_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
let summary = definition.str_value("summary")?;

[
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" ")
},
description: sanitize_documentation(definition.str_value("description")?)?,
regions,
definition,
})
}

pub fn dialect_name(&self) -> &str {
self.dialect_name
pub fn dialect_name(&self) -> Result<&str, Error> {
Ok(self.dialect()?.name()?)
}

pub fn fields(&self) -> impl Iterator<Item = &OperationField<'a>> + Clone {
Expand Down Expand Up @@ -334,21 +292,75 @@ impl<'a> Operation<'a> {
})
.collect()
}

pub fn dialect(&self) -> Result<Record, Error> {
Ok(self.definition.def_value("opDialect")?)
}

pub fn class_name(&self) -> Result<&str, Error> {
let name = self.definition.name()?;

Ok(if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
})
}

pub fn short_name(&self) -> Result<&str, Error> {
Ok(self.definition.str_value("opName")?)
}

pub fn full_name(&self) -> Result<String, Error> {
let dialect_name = self.dialect()?.string_value("name")?;
let short_name = self.short_name()?;

Ok(if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
})
}

pub fn summary(&self) -> Result<String, Error> {
let short_name = self.short_name()?;
let class_name = self.class_name()?;
let summary = self.definition.str_value("summary")?;

Ok([
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" "))
}

pub fn description(&self) -> Result<String, Error> {
Ok(sanitize_documentation(
self.definition.str_value("description")?,
)?)
}
}

pub fn generate_operation(operation: &Operation) -> Result<TokenStream, Error> {
let class_name = format_ident!("{}", &operation.class_name);
let name = &operation.full_name;
let class_name = format_ident!("{}", &operation.class_name()?);
let name = &operation.full_name()?;
let accessors = operation
.fields()
.map(|field| field.accessors())
.collect::<Result<Vec<_>, _>>()?;
let builder = OperationBuilder::new(operation)?;
let builder_tokens = builder.to_tokens()?;
let builder_fn = builder.create_op_builder_fn();
let builder_fn = builder.create_op_builder_fn()?;
let default_constructor = builder.create_default_constructor()?;
let summary = &operation.summary;
let description = &operation.description;
let summary = &operation.summary()?;
let description = &operation.description()?;

Ok(quote! {
#[doc = #summary]
Expand Down
50 changes: 25 additions & 25 deletions macro/src/dialect/operation/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ impl<'o> OperationBuilder<'o> {
field_names: &'a [Ident],
phantoms: &'a [TokenStream],
) -> impl Iterator<Item = Result<TokenStream, Error>> + 'a {
let builder_ident = self.builder_identifier();

self.operation.fields().map(move |field| {
// TODO Initialize a builder identifier out of this closure.
let builder_ident = self.builder_identifier()?;
let name = sanitize_snake_case_name(field.name)?;
let parameter_type = field.kind.parameter_type()?;
let argument = quote! { #name: #parameter_type };
Expand Down Expand Up @@ -132,11 +132,11 @@ impl<'o> OperationBuilder<'o> {
.create_builder_fns(&field_names, phantom_arguments.as_slice())
.collect::<Result<Vec<_>, _>>()?;

let new = self.create_new_fn(phantom_arguments.as_slice());
let build = self.create_build_fn();
let new = self.create_new_fn(phantom_arguments.as_slice())?;
let build = self.create_build_fn()?;

let builder_ident = self.builder_identifier();
let doc = format!("Builder for {}", self.operation.summary);
let builder_ident = self.builder_identifier()?;
let doc = format!("Builder for {}", self.operation.summary()?);
let iter_arguments = self.type_state.parameters();

Ok(quote! {
Expand All @@ -155,31 +155,31 @@ impl<'o> OperationBuilder<'o> {
})
}

fn create_build_fn(&self) -> TokenStream {
let builder_ident = self.builder_identifier();
fn create_build_fn(&self) -> Result<TokenStream, Error> {
let builder_ident = self.builder_identifier()?;
let arguments = self.type_state.arguments_all_set(true);
let class_name = format_ident!("{}", &self.operation.class_name);
let class_name = format_ident!("{}", &self.operation.class_name()?);
let error = format!("should be a valid {class_name}");
let maybe_infer = self
.operation
.can_infer_type
.then_some(quote! { .enable_result_type_inference() });

quote! {
Ok(quote! {
impl<'c> #builder_ident<'c, #(#arguments),*> {
pub fn build(self) -> #class_name<'c> {
self.builder #maybe_infer.build().expect("valid operation").try_into().expect(#error)
}
}
}
})
}

fn create_new_fn(&self, phantoms: &[TokenStream]) -> TokenStream {
let builder_ident = self.builder_identifier();
let name = &self.operation.full_name;
fn create_new_fn(&self, phantoms: &[TokenStream]) -> Result<TokenStream, Error> {
let builder_ident = self.builder_identifier()?;
let name = &self.operation.full_name()?;
let arguments = self.type_state.arguments_all_set(false);

quote! {
Ok(quote! {
impl<'c> #builder_ident<'c, #(#arguments),*> {
pub fn new(context: &'c ::melior::Context, location: ::melior::ir::Location<'c>) -> Self {
Self {
Expand All @@ -189,26 +189,26 @@ impl<'o> OperationBuilder<'o> {
}
}
}
}
})
}

pub fn create_op_builder_fn(&self) -> TokenStream {
let builder_ident = self.builder_identifier();
pub fn create_op_builder_fn(&self) -> Result<TokenStream, Error> {
let builder_ident = self.builder_identifier()?;
let arguments = self.type_state.arguments_all_set(false);

quote! {
Ok(quote! {
pub fn builder(
context: &'c ::melior::Context,
location: ::melior::ir::Location<'c>
) -> #builder_ident<'c, #(#arguments),*> {
#builder_ident::new(context, location)
}
}
})
}

pub fn create_default_constructor(&self) -> Result<TokenStream, Error> {
let class_name = format_ident!("{}", &self.operation.class_name);
let name = sanitize_snake_case_name(self.operation.short_name)?;
let class_name = format_ident!("{}", &self.operation.class_name()?);
let name = sanitize_snake_case_name(self.operation.short_name()?)?;
let arguments = Self::required_fields(self.operation)
.map(|field| {
let field = field?;
Expand All @@ -227,7 +227,7 @@ impl<'o> OperationBuilder<'o> {
})
.collect::<Result<Vec<_>, Error>>()?;

let doc = format!("Creates a new {}", self.operation.summary);
let doc = format!("Creates a new {}", self.operation.summary()?);

Ok(quote! {
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -259,7 +259,7 @@ impl<'o> OperationBuilder<'o> {
))
}

fn builder_identifier(&self) -> Ident {
format_ident!("{}Builder", self.operation.class_name)
fn builder_identifier(&self) -> Result<Ident, Error> {
Ok(format_ident!("{}Builder", self.operation.class_name()?))
}
}

0 comments on commit 8037059

Please sign in to comment.