Skip to content

Commit

Permalink
Refactor macro (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Dec 6, 2023
1 parent 9f10c9e commit c18b6b1
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 221 deletions.
12 changes: 9 additions & 3 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod utility;

use self::{
error::Error,
operation::generate_operation,
utility::{sanitize_documentation, sanitize_snake_case_name},
};
pub use input::DialectInput;
Expand Down Expand Up @@ -64,9 +65,14 @@ fn generate_dialect_module(
.all_derived_definitions("Op")
.map(Operation::new)
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter(|operation| operation.dialect_name() == dialect_name)
.map(|operation| operation.to_tokens())
.iter()
.map(|operation| {
Ok::<_, Error>(if operation.dialect_name()? == dialect_name {
Some(generate_operation(operation)?)
} else {
None
})
})
.collect::<Result<Vec<_>, _>>()?;

let doc = format!(
Expand Down
26 changes: 3 additions & 23 deletions macro/src/dialect/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod ods;

pub use self::ods::OdsError;
use std::{
error,
fmt::{self, Display, Formatter},
Expand Down Expand Up @@ -78,26 +81,3 @@ impl From<FromUtf8Error> for Error {
Self::Utf8(error)
}
}

#[derive(Debug)]
pub enum OdsError {
ExpectedSuperClass(&'static str),
InvalidTrait,
UnexpectedSuperClass(&'static str),
}

impl Display for OdsError {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
match self {
Self::ExpectedSuperClass(class) => {
write!(formatter, "record should be a sub-class of {class}",)
}
Self::InvalidTrait => write!(formatter, "record is not a supported trait"),
Self::UnexpectedSuperClass(class) => {
write!(formatter, "record should not be a sub-class of {class}",)
}
}
}
}

impl error::Error for OdsError {}
27 changes: 27 additions & 0 deletions macro/src/dialect/error/ods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use std::{
error::Error,
fmt::{self, Display, Formatter},
};

#[derive(Debug)]
pub enum OdsError {
ExpectedSuperClass(&'static str),
InvalidTrait,
UnexpectedSuperClass(&'static str),
}

impl Display for OdsError {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
match self {
Self::ExpectedSuperClass(class) => {
write!(formatter, "record should be a sub-class of {class}",)
}
Self::InvalidTrait => write!(formatter, "record is not a supported trait"),
Self::UnexpectedSuperClass(class) => {
write!(formatter, "record should not be a sub-class of {class}",)
}
}
}
}

impl Error for OdsError {}
230 changes: 123 additions & 107 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ mod sequence_info;
mod variadic_kind;

use self::{
builder::OperationBuilder, element_kind::ElementKind, field_kind::FieldKind,
operation_field::OperationField, sequence_info::SequenceInfo, variadic_kind::VariadicKind,
builder::{generate_operation_builder, OperationBuilder},
element_kind::ElementKind,
field_kind::FieldKind,
operation_field::OperationField,
sequence_info::SequenceInfo,
variadic_kind::VariadicKind,
};
use super::utility::sanitize_documentation;
use crate::dialect::{
Expand All @@ -19,15 +23,70 @@ use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use tblgen::{error::WithLocation, record::Record};

#[derive(Clone, Debug)]
pub fn generate_operation(operation: &Operation) -> Result<TokenStream, Error> {
let summary = operation.summary()?;
let description = operation.description()?;
let class_name = format_ident!("{}", operation.class_name()?);
let name = &operation.full_name()?;
let accessors = operation
.fields()
.map(|field| field.accessors())
.collect::<Result<Vec<_>, _>>()?;

let builder = OperationBuilder::new(operation)?;
let builder_tokens = generate_operation_builder(&builder)?;
let builder_fn = builder.create_op_builder_fn()?;
let default_constructor = builder.create_default_constructor()?;

Ok(quote! {
#[doc = #summary]
#[doc = "\n\n"]
#[doc = #description]
pub struct #class_name<'c> {
operation: ::melior::ir::operation::Operation<'c>,
}

impl<'c> #class_name<'c> {
pub fn name() -> &'static str {
#name
}

pub fn operation(&self) -> &::melior::ir::operation::Operation<'c> {
&self.operation
}

#builder_fn

#(#accessors)*
}

#builder_tokens

#default_constructor

impl<'c> TryFrom<::melior::ir::operation::Operation<'c>> for #class_name<'c> {
type Error = ::melior::Error;

fn try_from(
operation: ::melior::ir::operation::Operation<'c>,
) -> Result<Self, Self::Error> {
// TODO Check an operation name.
Ok(Self { operation })
}
}

impl<'c> From<#class_name<'c>> for ::melior::ir::operation::Operation<'c> {
fn from(operation: #class_name<'c>) -> ::melior::ir::operation::Operation<'c> {
operation.operation
}
}
})
}

#[derive(Debug)]
pub struct Operation<'a> {
dialect_name: &'a str,
short_name: &'a str,
full_name: String,
class_name: &'a str,
summary: String,
definition: Record<'a>,
can_infer_type: bool,
description: String,
regions: Vec<OperationField<'a>>,
successors: Vec<OperationField<'a>>,
results: Vec<OperationField<'a>>,
Expand All @@ -38,7 +97,6 @@ pub struct Operation<'a> {

impl<'a> Operation<'a> {
pub fn new(definition: Record<'a>) -> Result<Self, Error> {
let dialect = definition.def_value("opDialect")?;
let traits = Self::collect_traits(definition)?;
let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name));

Expand All @@ -50,30 +108,7 @@ impl<'a> Operation<'a> {
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
)?;

let name = definition.name()?;
let class_name = if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
};
let short_name = definition.str_value("opName")?;

Ok(Self {
dialect_name: dialect.name()?,
short_name,
full_name: {
let dialect_name = dialect.string_value("name")?;

if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
}
},
class_name,
successors: Self::collect_successors(definition)?,
operands: Self::collect_operands(
&arguments,
Expand All @@ -89,26 +124,65 @@ impl<'a> Operation<'a> {
&& unfixed_result_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
let summary = definition.str_value("summary")?;

[
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" ")
},
description: sanitize_documentation(definition.str_value("description")?)?,
regions,
definition,
})
}

fn dialect(&self) -> Result<Record, Error> {
Ok(self.definition.def_value("opDialect")?)
}

pub fn dialect_name(&self) -> Result<&str, Error> {
Ok(self.dialect()?.name()?)
}

pub fn class_name(&self) -> Result<&str, Error> {
let name = self.definition.name()?;

Ok(if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
})
}

pub fn short_name(&self) -> Result<&str, Error> {
Ok(self.definition.str_value("opName")?)
}

pub fn full_name(&self) -> Result<String, Error> {
let dialect_name = self.dialect()?.string_value("name")?;
let short_name = self.short_name()?;

Ok(if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
})
}

pub fn dialect_name(&self) -> &str {
self.dialect_name
pub fn summary(&self) -> Result<String, Error> {
let short_name = self.short_name()?;
let class_name = self.class_name()?;
let summary = self.definition.str_value("summary")?;

Ok([
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" "))
}

pub fn description(&self) -> Result<String, Error> {
sanitize_documentation(self.definition.str_value("description")?)
}

pub fn fields(&self) -> impl Iterator<Item = &OperationField<'a>> + Clone {
Expand Down Expand Up @@ -334,62 +408,4 @@ impl<'a> Operation<'a> {
})
.collect()
}

pub fn to_tokens(&self) -> Result<TokenStream, Error> {
let class_name = format_ident!("{}", &self.class_name);
let name = &self.full_name;
let accessors = self
.fields()
.map(|field| field.accessors())
.collect::<Result<Vec<_>, _>>()?;
let builder = OperationBuilder::new(self)?;
let builder_tokens = builder.to_tokens()?;
let builder_fn = builder.create_op_builder_fn();
let default_constructor = builder.create_default_constructor()?;
let summary = &self.summary;
let description = &self.description;

Ok(quote! {
#[doc = #summary]
#[doc = "\n\n"]
#[doc = #description]
pub struct #class_name<'c> {
operation: ::melior::ir::operation::Operation<'c>,
}

impl<'c> #class_name<'c> {
pub fn name() -> &'static str {
#name
}

pub fn operation(&self) -> &::melior::ir::operation::Operation<'c> {
&self.operation
}

#builder_fn

#(#accessors)*
}

#builder_tokens

#default_constructor

impl<'c> TryFrom<::melior::ir::operation::Operation<'c>> for #class_name<'c> {
type Error = ::melior::Error;

fn try_from(
operation: ::melior::ir::operation::Operation<'c>,
) -> Result<Self, Self::Error> {
Ok(Self { operation })
}
}

impl<'c> From<#class_name<'c>> for ::melior::ir::operation::Operation<'c> {
fn from(operation: #class_name<'c>) -> ::melior::ir::operation::Operation<'c> {
operation.operation
}
}
})
}
}
Loading

0 comments on commit c18b6b1

Please sign in to comment.