Skip to content

Commit

Permalink
Rename
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed Dec 5, 2023
1 parent 0fccd26 commit 8aaed8e
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 70 deletions.
104 changes: 104 additions & 0 deletions macro/src/dialect/dialect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
mod error;
mod input;
mod operation;
mod types;
mod utility;

use self::{
error::Error,
utility::{sanitize_documentation, sanitize_snake_case_name},
};
pub use input::DialectInput;
use operation::Operation;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use std::{env, fmt::Display, path::Path, process::Command, str};
use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser};

const LLVM_MAJOR_VERSION: usize = 17;

pub fn generate_dialect(input: DialectInput) -> Result<TokenStream, Box<dyn std::error::Error>> {
let mut parser = TableGenParser::new();

if let Some(source) = input.table_gen() {
parser = parser.add_source(source).map_err(create_syn_error)?;
}

if let Some(file) = input.td_file() {
parser = parser.add_source_file(file).map_err(create_syn_error)?;
}

// spell-checker: disable-next-line
for path in input.includes().chain([&*llvm_config("--includedir")?]) {
parser = parser.add_include_path(path);
}

let keeper = parser.parse().map_err(Error::Parse)?;

let dialect = generate_dialect_module(
input.name(),
keeper
.all_derived_definitions("Dialect")
.find(|definition| definition.str_value("name") == Ok(input.name()))
.ok_or_else(|| create_syn_error("dialect not found"))?,
&keeper,
)
.map_err(|error| error.add_source_info(keeper.source_info()))?;

Ok(quote! { #dialect }.into())
}

fn generate_dialect_module(
name: &str,
dialect: Record,
record_keeper: &RecordKeeper,
) -> Result<proc_macro2::TokenStream, Error> {
let dialect_name = dialect.name()?;
let operations = record_keeper
.all_derived_definitions("Op")
.map(Operation::new)
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter(|operation| operation.dialect_name() == dialect_name)
.collect::<Vec<_>>();

let doc = format!(
"`{name}` dialect.\n\n{}",
sanitize_documentation(dialect.str_value("description").unwrap_or(""),)?
);
let name = sanitize_snake_case_name(name)?;

Ok(quote! {
#[doc = #doc]
pub mod #name {
#(#operations)*
}
})
}

fn llvm_config(argument: &str) -> Result<String, Box<dyn std::error::Error>> {
let prefix = env::var(format!("MLIR_SYS_{}0_PREFIX", LLVM_MAJOR_VERSION))
.map(|path| Path::new(&path).join("bin"))
.unwrap_or_default();
let call = format!(
"{} --link-static {}",
prefix.join("llvm-config").display(),
argument
);

Ok(str::from_utf8(
&if cfg!(target_os = "windows") {
Command::new("cmd").args(["/C", &call]).output()?
} else {
Command::new("sh").arg("-c").arg(&call).output()?
}
.stdout,
)?
.trim()
.to_string())
}

fn create_syn_error(error: impl Display) -> syn::Error {
syn::Error::new(Span::call_site(), format!("{}", error))
}
140 changes: 70 additions & 70 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,76 @@ pub struct Operation<'a> {
}

impl<'a> Operation<'a> {
pub fn new(definition: Record<'a>) -> Result<Self, Error> {
let dialect = definition.def_value("opDialect")?;
let traits = Self::collect_traits(definition)?;
let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name));

let arguments = Self::dag_constraints(definition, "arguments")?;
let regions = Self::collect_regions(definition)?;
let (results, variable_length_results_count) = Self::collect_results(
definition,
has_trait("::mlir::OpTrait::SameVariadicResultSize"),
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
)?;

let name = definition.name()?;
let class_name = if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
};
let short_name = definition.str_value("opName")?;

Ok(Self {
dialect_name: dialect.name()?,
short_name,
full_name: {
let dialect_name = dialect.string_value("name")?;

if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
}
},
class_name,
successors: Self::collect_successors(definition)?,
operands: Self::collect_operands(
&arguments,
has_trait("::mlir::OpTrait::SameVariadicOperandSize"),
has_trait("::mlir::OpTrait::AttrSizedOperandSegments"),
)?,
results,
attributes: Self::collect_attributes(&arguments)?,
derived_attributes: Self::collect_derived_attributes(definition)?,
can_infer_type: traits.iter().any(|r#trait| {
(r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType")
|| r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType"))
&& variable_length_results_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
let summary = definition.str_value("summary")?;

[
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" ")
},
description: sanitize_documentation(definition.str_value("description")?)?,
regions,
})
}

pub fn dialect_name(&self) -> &str {
self.dialect_name
}
Expand Down Expand Up @@ -254,76 +324,6 @@ impl<'a> Operation<'a> {
})
.collect()
}

pub fn from_definition(definition: Record<'a>) -> Result<Self, Error> {
let dialect = definition.def_value("opDialect")?;
let traits = Self::collect_traits(definition)?;
let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name));

let arguments = Self::dag_constraints(definition, "arguments")?;
let regions = Self::collect_regions(definition)?;
let (results, variable_length_results_count) = Self::collect_results(
definition,
has_trait("::mlir::OpTrait::SameVariadicResultSize"),
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
)?;

let name = definition.name()?;
let class_name = if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
};
let short_name = definition.str_value("opName")?;

Ok(Self {
dialect_name: dialect.name()?,
short_name,
full_name: {
let dialect_name = dialect.string_value("name")?;

if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
}
},
class_name,
successors: Self::collect_successors(definition)?,
operands: Self::collect_operands(
&arguments,
has_trait("::mlir::OpTrait::SameVariadicOperandSize"),
has_trait("::mlir::OpTrait::AttrSizedOperandSegments"),
)?,
results,
attributes: Self::collect_attributes(&arguments)?,
derived_attributes: Self::collect_derived_attributes(definition)?,
can_infer_type: traits.iter().any(|r#trait| {
(r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType")
|| r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType"))
&& variable_length_results_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
let summary = definition.str_value("summary")?;

[
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" ")
},
description: sanitize_documentation(definition.str_value("description")?)?,
regions,
})
}
}

impl<'a> ToTokens for Operation<'a> {
Expand Down

0 comments on commit 8aaed8e

Please sign in to comment.