diff --git a/macro/src/dialect/dialect.rs b/macro/src/dialect/dialect.rs new file mode 100644 index 0000000000..67833e57fd --- /dev/null +++ b/macro/src/dialect/dialect.rs @@ -0,0 +1,104 @@ +mod error; +mod input; +mod operation; +mod types; +mod utility; + +use self::{ + error::Error, + utility::{sanitize_documentation, sanitize_snake_case_name}, +}; +pub use input::DialectInput; +use operation::Operation; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use std::{env, fmt::Display, path::Path, process::Command, str}; +use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser}; + +const LLVM_MAJOR_VERSION: usize = 17; + +pub fn generate_dialect(input: DialectInput) -> Result> { + let mut parser = TableGenParser::new(); + + if let Some(source) = input.table_gen() { + parser = parser.add_source(source).map_err(create_syn_error)?; + } + + if let Some(file) = input.td_file() { + parser = parser.add_source_file(file).map_err(create_syn_error)?; + } + + // spell-checker: disable-next-line + for path in input.includes().chain([&*llvm_config("--includedir")?]) { + parser = parser.add_include_path(path); + } + + let keeper = parser.parse().map_err(Error::Parse)?; + + let dialect = generate_dialect_module( + input.name(), + keeper + .all_derived_definitions("Dialect") + .find(|definition| definition.str_value("name") == Ok(input.name())) + .ok_or_else(|| create_syn_error("dialect not found"))?, + &keeper, + ) + .map_err(|error| error.add_source_info(keeper.source_info()))?; + + Ok(quote! { #dialect }.into()) +} + +fn generate_dialect_module( + name: &str, + dialect: Record, + record_keeper: &RecordKeeper, +) -> Result { + let dialect_name = dialect.name()?; + let operations = record_keeper + .all_derived_definitions("Op") + .map(Operation::new) + .collect::, _>>()? + .into_iter() + .filter(|operation| operation.dialect_name() == dialect_name) + .collect::>(); + + let doc = format!( + "`{name}` dialect.\n\n{}", + sanitize_documentation(dialect.str_value("description").unwrap_or(""),)? + ); + let name = sanitize_snake_case_name(name)?; + + Ok(quote! { + #[doc = #doc] + pub mod #name { + #(#operations)* + } + }) +} + +fn llvm_config(argument: &str) -> Result> { + let prefix = env::var(format!("MLIR_SYS_{}0_PREFIX", LLVM_MAJOR_VERSION)) + .map(|path| Path::new(&path).join("bin")) + .unwrap_or_default(); + let call = format!( + "{} --link-static {}", + prefix.join("llvm-config").display(), + argument + ); + + Ok(str::from_utf8( + &if cfg!(target_os = "windows") { + Command::new("cmd").args(["/C", &call]).output()? + } else { + Command::new("sh").arg("-c").arg(&call).output()? + } + .stdout, + )? + .trim() + .to_string()) +} + +fn create_syn_error(error: impl Display) -> syn::Error { + syn::Error::new(Span::call_site(), format!("{}", error)) +} diff --git a/macro/src/dialect/operation.rs b/macro/src/dialect/operation.rs index 480f7a4255..35b9f33285 100644 --- a/macro/src/dialect/operation.rs +++ b/macro/src/dialect/operation.rs @@ -37,6 +37,76 @@ pub struct Operation<'a> { } impl<'a> Operation<'a> { + pub fn new(definition: Record<'a>) -> Result { + let dialect = definition.def_value("opDialect")?; + let traits = Self::collect_traits(definition)?; + let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name)); + + let arguments = Self::dag_constraints(definition, "arguments")?; + let regions = Self::collect_regions(definition)?; + let (results, variable_length_results_count) = Self::collect_results( + definition, + has_trait("::mlir::OpTrait::SameVariadicResultSize"), + has_trait("::mlir::OpTrait::AttrSizedResultSegments"), + )?; + + let name = definition.name()?; + let class_name = if name.starts_with('_') { + name + } else if let Some(name) = name.split('_').nth(1) { + // Trim dialect prefix from name. + name + } else { + name + }; + let short_name = definition.str_value("opName")?; + + Ok(Self { + dialect_name: dialect.name()?, + short_name, + full_name: { + let dialect_name = dialect.string_value("name")?; + + if dialect_name.is_empty() { + short_name.into() + } else { + format!("{dialect_name}.{short_name}") + } + }, + class_name, + successors: Self::collect_successors(definition)?, + operands: Self::collect_operands( + &arguments, + has_trait("::mlir::OpTrait::SameVariadicOperandSize"), + has_trait("::mlir::OpTrait::AttrSizedOperandSegments"), + )?, + results, + attributes: Self::collect_attributes(&arguments)?, + derived_attributes: Self::collect_derived_attributes(definition)?, + can_infer_type: traits.iter().any(|r#trait| { + (r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") + || r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType")) + && variable_length_results_count == 0 + || r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty() + }), + summary: { + let summary = definition.str_value("summary")?; + + [ + format!("[`{short_name}`]({class_name}) operation."), + if summary.is_empty() { + Default::default() + } else { + summary[0..1].to_uppercase() + &summary[1..] + "." + }, + ] + .join(" ") + }, + description: sanitize_documentation(definition.str_value("description")?)?, + regions, + }) + } + pub fn dialect_name(&self) -> &str { self.dialect_name } @@ -254,76 +324,6 @@ impl<'a> Operation<'a> { }) .collect() } - - pub fn from_definition(definition: Record<'a>) -> Result { - let dialect = definition.def_value("opDialect")?; - let traits = Self::collect_traits(definition)?; - let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name)); - - let arguments = Self::dag_constraints(definition, "arguments")?; - let regions = Self::collect_regions(definition)?; - let (results, variable_length_results_count) = Self::collect_results( - definition, - has_trait("::mlir::OpTrait::SameVariadicResultSize"), - has_trait("::mlir::OpTrait::AttrSizedResultSegments"), - )?; - - let name = definition.name()?; - let class_name = if name.starts_with('_') { - name - } else if let Some(name) = name.split('_').nth(1) { - // Trim dialect prefix from name. - name - } else { - name - }; - let short_name = definition.str_value("opName")?; - - Ok(Self { - dialect_name: dialect.name()?, - short_name, - full_name: { - let dialect_name = dialect.string_value("name")?; - - if dialect_name.is_empty() { - short_name.into() - } else { - format!("{dialect_name}.{short_name}") - } - }, - class_name, - successors: Self::collect_successors(definition)?, - operands: Self::collect_operands( - &arguments, - has_trait("::mlir::OpTrait::SameVariadicOperandSize"), - has_trait("::mlir::OpTrait::AttrSizedOperandSegments"), - )?, - results, - attributes: Self::collect_attributes(&arguments)?, - derived_attributes: Self::collect_derived_attributes(definition)?, - can_infer_type: traits.iter().any(|r#trait| { - (r#trait.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") - || r#trait.has_name("::mlir::OpTrait::SameOperandsAndResultType")) - && variable_length_results_count == 0 - || r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty() - }), - summary: { - let summary = definition.str_value("summary")?; - - [ - format!("[`{short_name}`]({class_name}) operation."), - if summary.is_empty() { - Default::default() - } else { - summary[0..1].to_uppercase() + &summary[1..] + "." - }, - ] - .join(" ") - }, - description: sanitize_documentation(definition.str_value("description")?)?, - regions, - }) - } } impl<'a> ToTokens for Operation<'a> {