diff --git a/Cargo.lock b/Cargo.lock index a2a8d74b10..c3a242cf8f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -35,7 +35,30 @@ version = "0.65.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfdf7b466f9a4903edc73f95d6d2bcd5baf8ae620638762244d3f60143643cc5" dependencies = [ - "bitflags", + "bitflags 1.3.2", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", + "which", +] + +[[package]] +name = "bindgen" +version = "0.66.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b84e06fc203107bfbad243f4aba2af864eb7db3b1cf46ea0a023b0b433d2a7" +dependencies = [ + "bitflags 2.4.0", "cexpr", "clang-sys", "lazy_static", @@ -58,6 +81,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" + [[package]] name = "bumpalo" version = "3.12.2" @@ -145,7 +174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f423e341edefb78c9caba2d9c7f7687d0e72e89df3ce3394554754393ac3990" dependencies = [ "anstyle", - "bitflags", + "bitflags 1.3.2", "clap_lex", ] @@ -484,6 +513,7 @@ name = "melior-macro" version = "0.4.2" dependencies = [ "convert_case", + "lazy_static", "melior", "mlir-sys", "once_cell", @@ -491,6 +521,8 @@ dependencies = [ "quote", "regex", "syn", + "tblgen", + "unindent", ] [[package]] @@ -520,7 +552,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1262be288d5f59eaa5a6367722e4fd2eb2f668229d2e3e3680530f266a193b3" dependencies = [ - "bindgen", + "bindgen 0.65.1", ] [[package]] @@ -577,6 +609,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "peeking_take_while" version = "0.1.2" @@ -677,7 +715,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -721,7 +759,7 @@ version = "0.37.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -810,6 +848,38 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tblgen" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d19c09266feb8b16718d1183044d14703a0b4b59e55ce8beb4d6e21dd066b1b" +dependencies = [ + "bindgen 0.66.1", + "cc", + "paste", + "thiserror", +] + +[[package]] +name = "thiserror" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -832,6 +902,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "unindent" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f86d931b9d0b666761dcfcbac3ec5e9daff8a2becfff93a8fce2591ae297b95" + [[package]] name = "walkdir" version = "2.3.3" diff --git a/macro/Cargo.toml b/macro/Cargo.toml index dfc678a7af..35ac01a3db 100644 --- a/macro/Cargo.toml +++ b/macro/Cargo.toml @@ -14,11 +14,14 @@ proc-macro = true [dependencies] convert_case = "0.6.0" +lazy_static = "1.4.0" once_cell = "1.18.0" proc-macro2 = "1" quote = "1" regex = "1.9.3" syn = { version = "2", features = ["full"] } +tblgen = { version = "0.3.0", features = ["llvm16-0"] } +unindent = "0.2.2" [dev-dependencies] melior = { path = "../melior" } diff --git a/macro/src/dialect/error.rs b/macro/src/dialect/error.rs new file mode 100644 index 0000000000..7935bfebe6 --- /dev/null +++ b/macro/src/dialect/error.rs @@ -0,0 +1,76 @@ +use std::fmt::Display; + +use proc_macro2::Span; +use tblgen::{ + error::{SourceError, TableGenError}, + SourceInfo, +}; + +#[derive(Debug)] +pub enum Error { + Syn(syn::Error), + TableGen(tblgen::Error), + ExpectedSuperClass(SourceError), + ParseError, +} + +impl Error { + pub fn add_source_info(self, info: SourceInfo) -> Self { + match self { + Self::TableGen(e) => e.add_source_info(info).into(), + Self::ExpectedSuperClass(e) => e.add_source_info(info).into(), + _ => self, + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Syn(e) => write!(f, "failed to parse macro input: {e}"), + Error::TableGen(e) => write!(f, "invalid ODS input: {e}"), + Error::ExpectedSuperClass(e) => write!(f, "invalid ODS input: {e}"), + Error::ParseError => write!(f, "error parsing TableGen source"), + } + } +} + +impl std::error::Error for Error {} + +impl From> for Error { + fn from(value: SourceError) -> Self { + Self::ExpectedSuperClass(value) + } +} + +impl From> for Error { + fn from(value: SourceError) -> Self { + Self::TableGen(value) + } +} + +impl From for Error { + fn from(value: syn::Error) -> Self { + Self::Syn(value) + } +} + +impl From for syn::Error { + fn from(value: Error) -> Self { + match value { + Error::Syn(e) => e, + _ => syn::Error::new(Span::call_site(), format!("{}", value)), + } + } +} + +#[derive(Debug)] +pub struct ExpectedSuperClassError(pub String); + +impl Display for ExpectedSuperClassError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "expected this record to be a subclass of {}", self.0) + } +} + +impl std::error::Error for ExpectedSuperClassError {} diff --git a/macro/src/dialect/mod.rs b/macro/src/dialect/mod.rs new file mode 100644 index 0000000000..96cfdba0c1 --- /dev/null +++ b/macro/src/dialect/mod.rs @@ -0,0 +1,204 @@ +extern crate proc_macro; +mod error; +mod operation; +mod types; + +use std::io::Write; +use std::{env, error::Error, fs::OpenOptions, path::Path, process::Command}; + +use crate::utility::sanitize_name_snake; +use operation::Operation; +use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; +use quote::{format_ident, quote}; +use syn::{bracketed, parse::Parse, punctuated::Punctuated, LitStr, Token}; +use tblgen::{record::Record, record_keeper::RecordKeeper, TableGenParser}; + +const LLVM_MAJOR_VERSION: usize = 16; + +fn dialect_module<'a>( + name: &str, + dialect: Record<'a>, + record_keeper: &'a RecordKeeper, +) -> Result { + let operations = record_keeper + .all_derived_definitions("Op") + .map(Operation::from_def) + .filter_map(|o: Result| match o { + Ok(o) => (o.dialect.name() == dialect.name()).then_some(Ok(o)), + Err(e) => Some(Err(e)), + }) + .collect::, _>>()?; + + let mut doc = format!("`{}` dialect.\n\n", name); + doc.push_str(&unindent::unindent( + dialect.str_value("description").unwrap_or(""), + )); + let name = sanitize_name_snake(name); + Ok(quote! { + #[doc = #doc] + pub mod #name { + #(#operations)* + } + }) +} + +enum InputField { + Name(LitStr), + TableGen(LitStr), + TdFile(LitStr), + Includes(Punctuated), +} + +impl Parse for InputField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let ident: Ident = input.parse()?; + let _: Token![:] = input.parse()?; + if ident == format_ident!("name") { + return Ok(Self::Name(input.parse()?)); + } + if ident == format_ident!("tablegen") { + return Ok(Self::TableGen(input.parse()?)); + } + if ident == format_ident!("td_file") { + return Ok(Self::TdFile(input.parse()?)); + } + if ident == format_ident!("include_dirs") { + let content; + bracketed!(content in input); + return Ok(Self::Includes( + Punctuated::::parse_terminated(&content)?, + )); + } + + Err(input.error(format!("invalid field {}", ident))) + } +} + +pub struct DialectMacroInput { + name: String, + tablegen: Option, + td_file: Option, + includes: Vec, +} + +impl Parse for DialectMacroInput { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let list = Punctuated::::parse_terminated(input)?; + let mut name = None; + let mut tablegen = None; + let mut td_file = None; + let mut includes = None; + + for item in list { + match item { + InputField::Name(n) => name = Some(n.value()), + InputField::TableGen(td) => tablegen = Some(td.value()), + InputField::TdFile(f) => td_file = Some(f.value()), + InputField::Includes(inc) => { + includes = Some(inc.into_iter().map(|l| l.value()).collect()) + } + } + } + + Ok(Self { + name: name.ok_or(input.error("dialect name required"))?, + tablegen, + td_file, + includes: includes.unwrap_or(Vec::new()), + }) + } +} + +// Writes `tablegen_compile_commands.yaml` for any TableGen file that is being parsed. +// See: https://mlir.llvm.org/docs/Tools/MLIRLSP/#tablegen-lsp-language-server--tblgen-lsp-server +fn emit_tablegen_compile_commands(td_file: &str, includes: &[String]) { + let pwd = std::env::current_dir(); + if let Ok(pwd) = pwd { + let path = pwd.join(td_file); + let file = OpenOptions::new() + .write(true) + .append(true) + .create(true) + .open(pwd.join("tablegen_compile_commands.yml")); + if let Ok(mut file) = file { + writeln!(file, "--- !FileInfo:").unwrap(); + writeln!(file, " filepath: \"{}\"", path.to_str().unwrap()).unwrap(); + let _ = writeln!( + file, + " includes: \"{}\"", + includes + .iter() + .map(|s| pwd.join(s.as_str()).to_str().unwrap().to_owned()) + .collect::>() + .join(";") + ); + } + } +} + +pub fn generate_dialect(mut input: DialectMacroInput) -> Result> { + input.includes.push(llvm_config("--includedir").unwrap()); + + let mut td_parser = TableGenParser::new(); + + if let Some(source) = input.tablegen.as_ref() { + td_parser = td_parser + .add_source(source.as_str()) + .map_err(|e| syn::Error::new(Span::call_site(), format!("{}", e)))?; + } + if let Some(file) = input.td_file.as_ref() { + td_parser = td_parser + .add_source_file(file.as_str()) + .map_err(|e| syn::Error::new(Span::call_site(), format!("{}", e)))?; + } + for include in input.includes.iter() { + td_parser = td_parser.add_include_path(include.as_str()); + } + + if std::env::var("DIALECTGEN_TABLEGEN_COMPILE_COMMANDS").is_ok() { + if let Some(td_file) = input.td_file.as_ref() { + emit_tablegen_compile_commands(td_file, &input.includes); + } + } + + let keeper = td_parser.parse().map_err(|_| error::Error::ParseError)?; + + let dialect_def = keeper + .all_derived_definitions("Dialect") + .find_map(|def| { + def.str_value("name") + .ok() + .and_then(|n| if n == input.name { Some(def) } else { None }) + }) + .ok_or_else(|| syn::Error::new(Span::call_site(), "dialect not found"))?; + let dialect = dialect_module(&input.name, dialect_def, &keeper) + .map_err(|e| e.add_source_info(keeper.source_info()))?; + + Ok(quote! { + #dialect + } + .into()) +} + +fn llvm_config(argument: &str) -> Result> { + let prefix = env::var(format!("MLIR_SYS_{}0_PREFIX", LLVM_MAJOR_VERSION)) + .map(|path| Path::new(&path).join("bin")) + .unwrap_or_default(); + let call = format!( + "{} --link-static {}", + prefix.join("llvm-config").display(), + argument + ); + + Ok(std::str::from_utf8( + &if cfg!(target_os = "windows") { + Command::new("cmd").args(["/C", &call]).output()? + } else { + Command::new("sh").arg("-c").arg(&call).output()? + } + .stdout, + )? + .trim() + .to_string()) +} diff --git a/macro/src/dialect/operation/accessors.rs b/macro/src/dialect/operation/accessors.rs new file mode 100644 index 0000000000..51a168e749 --- /dev/null +++ b/macro/src/dialect/operation/accessors.rs @@ -0,0 +1,263 @@ +use super::{FieldKind, OperationField, SequenceInfo, VariadicKind}; + +use crate::utility::sanitize_name_snake; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +impl<'a> OperationField<'a> { + fn getter_impl(&self) -> Option { + match &self.kind { + FieldKind::Operand(t) | FieldKind::Result(t) => { + let kind = self.kind.as_str(); + let kind_ident = format_ident!("{}", kind); + let plural = format_ident!("{}s", kind); + let count = format_ident!("{}_count", kind); + let SequenceInfo { index, len } = self + .seq_info + .as_ref() + .expect("operands and results need sequence info"); + let variadic_kind = self + .variadic_info + .as_ref() + .expect("operands and results need variadic info"); + Some(match variadic_kind { + VariadicKind::Simple { + seen_variable_length, + } => { + // At most one variable length group + if t.is_variable_length() { + if t.is_optional() { + // Optional element, and some singular elements. + // Only present if the amount of groups is at least the number of elements. + quote! { + if self.operation.#count() < #len { + None + } else { + self.operation.#kind_ident(#index).ok() + } + } + } else { + // A variable length group + // Length computed by subtracting the amount of other + // singular elements from the number of elements. + quote! { + let group_length = self.operation.#count() - #len + 1; + self.operation.#plural().skip(#index).take(group_length) + } + } + } else if *seen_variable_length { + // Single element after variable length group + // Compute the length of that variable group and take the next element + let error = format!("operation should have this {}", kind); + quote! { + let group_length = self.operation.#count() - #len + 1; + self.operation.#kind_ident(#index + group_length - 1).expect(#error) + } + } else { + // All elements so far are singular + let error = format!("operation should have this {}", kind); + quote! { + self.operation.#kind_ident(#index).expect(#error) + } + } + } + VariadicKind::SameSize { + num_variable_length, + num_preceding_simple, + num_preceding_variadic, + } => { + let error = format!("operation should have this {}", kind); + let compute_start_length = quote! { + let total_var_len = self.operation.#count() - #num_variable_length + 1; + let group_len = total_var_len / #num_variable_length; + let start = #num_preceding_simple + #num_preceding_variadic * group_len; + }; + + let get_elements = if t.is_variable_length() { + quote! { + self.operation.#plural().skip(start).take(group_len) + } + } else { + quote! { + self.operation.#kind_ident(start).expect(#error) + } + }; + quote! { #compute_start_length #get_elements } + } + VariadicKind::AttrSized {} => { + let error = format!("operation should have this {}", kind); + let attr_name = format!("{}_segment_sizes", kind); + let attr_missing_error = format!("operation has {} attribute", attr_name); + let compute_start_length = quote! { + let attribute = + ::melior::ir::attribute::DenseI32ArrayAttribute::<'c>::try_from( + self.operation + .attribute(#attr_name) + .expect(#attr_missing_error) + ).expect("is a DenseI32ArrayAttribute"); + let start = (0..#index) + .map(|i| attribute.element(i) + .expect("has segment size")) + .sum::() as usize; + let group_len = attribute + .element(#index) + .expect("has segment size") as usize; + }; + let get_elements = if t.is_variable_length() { + if t.is_optional() { + quote! { + if group_len == 0 { + None + } else { + self.operation.#kind_ident(start).ok() + } + } + } else { + quote! { + self.operation.#plural().skip(start).take(group_len) + } + } + } else { + quote! { + self.operation.#kind_ident(start).expect(#error) + } + }; + quote! { #compute_start_length #get_elements } + } + }) + } + FieldKind::Successor(s) => { + let SequenceInfo { index, .. } = self + .seq_info + .as_ref() + .expect("successors need sequence info"); + Some(if s.is_variadic() { + // Only the last successor can be variadic + quote! { + self.operation.successors().skip(#index) + } + } else { + quote! { + self.operation.successor(#index).expect("operation should have this successor") + } + }) + } + FieldKind::Region(r) => { + let SequenceInfo { index, .. } = + self.seq_info.as_ref().expect("regions need sequence info"); + Some(if r.is_variadic() { + // Only the last region can be variadic + quote! { + self.operation.regions().skip(#index) + } + } else { + quote! { + self.operation.region(#index).expect("operation should have this region") + } + }) + } + FieldKind::Attribute(a) => { + let n = &self.name; + let attr_error = format!("operation should have attribute {}", n); + let type_error = format!("{} should be a {}", n, a.storage_type()); + Some(if a.is_unit() { + quote! { self.operation.attribute(#n).is_some() } + } else if a.is_optional() { + quote! { + self.operation + .attribute(#n) + .map(|a| a.try_into().expect(#type_error)) + } + } else { + quote! { + self.operation + .attribute(#n) + .expect(#attr_error) + .try_into() + .expect(#type_error) + } + }) + } + } + } + + fn remover_impl(&self) -> Option { + match &self.kind { + FieldKind::Attribute(a) => { + let n = &self.name; + + if a.is_unit() || a.is_optional() { + Some(quote! { + let _ = self.operation.remove_attribute(#n); + }) + } else { + None + } + } + _ => None, + } + } + + fn setter_impl(&self) -> Option { + match &self.kind { + FieldKind::Attribute(a) => { + let n = &self.name; + + Some(if a.is_unit() { + quote! { + if value { + self.operation.set_attribute(#n, Attribute::unit(&self.operation.context())); + } else { + let _ = self.operation.remove_attribute(#n); + } + } + } else { + quote! { + self.operation.set_attribute(#n, &value.into()); + } + }) + } + _ => None, + } + } + + pub fn accessors(&self) -> TokenStream { + let setter = { + let set_fn_ident = sanitize_name_snake(&format!("set_{}", self.name)); + self.setter_impl().map_or(quote!(), |imp| { + let param_type = &self.param_type; + quote! { + pub fn #set_fn_ident(&mut self, value: #param_type) { + #imp + } + } + }) + }; + let remover = { + let remove_fn_ident = sanitize_name_snake(&format!("remove_{}", self.name)); + self.remover_impl().map_or(quote!(), |imp| { + quote! { + pub fn #remove_fn_ident(&mut self) { + #imp + } + } + }) + }; + let getter = { + let get_fn_ident = &self.sanitized; + let return_type = &self.return_type; + self.getter_impl().map_or(quote!(), |imp| { + quote! { + pub fn #get_fn_ident(&self) -> #return_type { + #imp + } + } + }) + }; + quote! { + #getter + #setter + #remover + } + } +} diff --git a/macro/src/dialect/operation/builder.rs b/macro/src/dialect/operation/builder.rs new file mode 100644 index 0000000000..3f91753d14 --- /dev/null +++ b/macro/src/dialect/operation/builder.rs @@ -0,0 +1,360 @@ +use convert_case::{Case, Casing}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote}; + +use crate::utility::sanitize_name_snake; + +use super::{FieldKind, Operation}; + +#[derive(Debug)] +struct TypeStateItem { + pub(crate) field_name: String, + pub(crate) yes: Ident, + pub(crate) no: Ident, + pub(crate) t: Ident, +} + +impl TypeStateItem { + pub fn new(class_name: &str, field_name: &str) -> Self { + let new_field_name = field_name.to_string().to_case(Case::Pascal); + Self { + field_name: field_name.to_string(), + yes: format_ident!("{}__Yes__{}", class_name, new_field_name), + no: format_ident!("{}__No__{}", class_name, new_field_name), + t: format_ident!("{}__Any__{}", class_name, new_field_name), + } + } +} + +#[derive(Debug)] +struct TypeStateList(Vec); + +impl TypeStateList { + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_all_any(&self) -> impl Iterator { + self.0.iter().map(|i| &i.t) + } + + pub fn iter_all_any_without(&self, field_name: String) -> impl Iterator { + self.0.iter().filter_map(move |i| { + if i.field_name != field_name { + Some(&i.t) + } else { + None + } + }) + } + + pub fn iter_set_yes(&self, field_name: String) -> impl Iterator { + self.0.iter().map(move |i| { + if i.field_name == field_name { + &i.yes + } else { + &i.t + } + }) + } + + pub fn iter_set_no(&self, field_name: String) -> impl Iterator { + self.0.iter().map(move |i| { + if i.field_name == field_name { + &i.no + } else { + &i.t + } + }) + } + + pub fn iter_all_yes(&self) -> impl Iterator { + self.0.iter().map(|i| &i.yes) + } + + pub fn iter_all_no(&self) -> impl Iterator { + self.0.iter().map(|i| &i.no) + } +} + +pub struct OperationBuilder<'o, 'c> { + pub(crate) operation: &'c Operation<'o>, + type_state: TypeStateList, +} + +impl<'o, 'c> OperationBuilder<'o, 'c> { + pub fn new(operation: &'c Operation<'o>) -> Self { + let type_state = Self::create_type_state(operation); + Self { + operation, + type_state, + } + } + + pub fn methods<'a, 's: 'a>( + &'s self, + field_names: &'a [Ident], + phantoms: &'a [TokenStream], + ) -> impl Iterator + 'a { + let builder_ident = format_ident!("{}Builder", self.operation.class_name); + self.operation.fields.iter().map(move |f| { + let n = sanitize_name_snake(f.name); + let st = &f.param_type; + let args = quote! { #n: #st }; + let add = format_ident!("add_{}s", f.kind.as_str()); + + let add_args = { + let mlir_ident = { + let name_str = &f.name; + quote! { ::melior::ir::Identifier::new(self.context, #name_str) } + }; + + // Argument types can be singular and variadic, but add functions in melior + // are always variadic, so we need to create a slice or vec for singular arguments + match &f.kind { + FieldKind::Operand(tc) | FieldKind::Result(tc) => { + if tc.is_variable_length() && !tc.is_optional() { + quote! { #n } + } else { + quote! { &[#n] } + } + } + FieldKind::Attribute(_) => { + quote! { &[(#mlir_ident, #n.into())] } + } + FieldKind::Successor(sc) => { + if sc.is_variadic() { + quote! { #n } + } else { + quote! { &[#n] } + } + } + FieldKind::Region(rc) => { + if rc.is_variadic() { + quote! { #n } + } else { + quote! { vec![#n] } + } + } + } + }; + + if !f.optional && !f.has_default { + if let FieldKind::Result(_) = f.kind { + if self.operation.can_infer_type { + // Don't allow setting the result type when it can be inferred + return quote!(); + } + } + let iter_all_any_without = self.type_state.iter_all_any_without(f.name.to_string()); + let iter_set_yes = self.type_state.iter_set_yes(f.name.to_string()); + let iter_set_no = self.type_state.iter_set_no(f.name.to_string()); + quote! { + impl<'c, #(#iter_all_any_without),*> #builder_ident<'c, #(#iter_set_no),*> { + pub fn #n(mut self, #args) -> #builder_ident<'c, #(#iter_set_yes),*> { + self.builder = self.builder.#add(#add_args); + let Self { context, mut builder, #(#field_names),* } = self; + #builder_ident { + context, + builder, + #(#phantoms),* + } + } + } + } + } else { + let iter_all_any = self.type_state.iter_all_any().collect::>(); + quote! { + impl<'c, #(#iter_all_any),*> #builder_ident<'c, #(#iter_all_any),*> { + pub fn #n(mut self, #args) -> #builder_ident<'c, #(#iter_all_any),*> { + self.builder = self.builder.#add(#add_args); + self + } + } + } + } + }) + } + + pub fn builder(&self) -> TokenStream { + let type_state_structs = self.type_state_structs(); + let builder_ident = format_ident!("{}Builder", self.operation.class_name); + + let field_names = self + .type_state + .iter() + .map(|f| sanitize_name_snake(&f.field_name)) + .collect::>(); + + let fields = self + .type_state + .iter_all_any() + .zip(field_names.iter()) + .map(|(g, n)| { + Some(quote! { + #[doc(hidden)] + #n: ::std::marker::PhantomData<#g> + }) + }); + + let phantoms: Vec<_> = field_names + .iter() + .map(|n| quote! { #n: ::std::marker::PhantomData }) + .collect(); + + let methods = self.methods(field_names.as_slice(), phantoms.as_slice()); + + let new = { + let name_str = self.operation.name(); + let iter_all_no = self.type_state.iter_all_no(); + let phantoms = phantoms.clone(); + quote! { + impl<'c> #builder_ident<'c, #(#iter_all_no),*> { + pub fn new(location: ::melior::ir::Location<'c>) -> Self { + Self { + context: unsafe { location.context().to_ref() }, + builder: ::melior::ir::operation::OperationBuilder::new(#name_str, location), + #(#phantoms),* + } + } + } + } + }; + + let build = { + let iter_all_yes = self.type_state.iter_all_yes(); + let class_name = format_ident!("{}", &self.operation.class_name); + let err = format!("should be a valid {}", class_name); + let maybe_infer = if self.operation.can_infer_type { + quote! { .enable_result_type_inference() } + } else { + quote! {} + }; + quote! { + impl<'c> #builder_ident<'c, #(#iter_all_yes),*> { + pub fn build(self) -> #class_name<'c> { + self.builder #maybe_infer.build().try_into().expect(#err) + } + } + } + }; + + let doc = format!("Builder for {}", self.operation.summary); + + let iter_all_any = self.type_state.iter_all_any(); + quote! { + #type_state_structs + + #[doc = #doc] + pub struct #builder_ident <'c, #(#iter_all_any),* > { + #[doc(hidden)] + builder: ::melior::ir::operation::OperationBuilder<'c>, + #[doc(hidden)] + context: &'c ::melior::Context, + #(#fields),* + } + + #new + + #(#methods)* + + #build + } + } + + pub fn create_op_builder_fn(&self) -> TokenStream { + let builder_ident = format_ident!("{}Builder", self.operation.class_name); + let iter_all_no = self.type_state.iter_all_no(); + quote! { + pub fn builder(location: ::melior::ir::Location<'c>) -> #builder_ident<'c, #(#iter_all_no),*> { + #builder_ident::new(location) + } + } + } + + pub fn default_constructor(&self) -> TokenStream { + let class_name = format_ident!("{}", &self.operation.class_name); + let name = sanitize_name_snake(&self.operation.short_name()); + let mut args = self + .operation + .fields + .iter() + .filter_map(|f| { + if !f.optional && !f.has_default { + if let FieldKind::Result(_) = f.kind { + if self.operation.can_infer_type { + return None; + } + } + let param_type = &f.param_type; + let param_name = &f.sanitized; + Some(quote! { #param_name: #param_type }) + } else { + None + } + }) + .collect::>(); + let builder_calls = self.operation.fields.iter().filter_map(|f| { + if !f.optional && !f.has_default { + if let FieldKind::Result(_) = f.kind { + if self.operation.can_infer_type { + return None; + } + } + let param_name = &f.sanitized; + Some(quote! { .#param_name(#param_name) }) + } else { + None + } + }); + args.push(quote! { location: ::melior::ir::Location<'c> }); + + let doc = format!("Create a new {}", self.operation.summary); + quote! { + #[allow(clippy::too_many_arguments)] + #[doc = #doc] + pub fn #name<'c>(#(#args),*) -> #class_name<'c> { + #class_name::builder(location)#(#builder_calls)*.build() + } + } + } + + fn create_type_state(operation: &Operation) -> TypeStateList { + TypeStateList( + operation + .fields + .iter() + .filter_map(|f| { + if !f.optional && !f.has_default { + if let FieldKind::Result(_) = f.kind { + if operation.can_infer_type { + return None; + } + } + Some(TypeStateItem::new(operation.class_name, f.name)) + } else { + None + } + }) + .collect(), + ) + } + + fn type_state_structs(&self) -> TokenStream { + self.type_state + .iter() + .map(|item| { + let yes = &item.yes; + let no = &item.no; + quote! { + #[allow(non_camel_case_types)] + #[doc(hidden)] + pub struct #yes; + #[allow(non_camel_case_types)] + #[doc(hidden)] + pub struct #no; + } + }) + .collect() + } +} diff --git a/macro/src/dialect/operation/mod.rs b/macro/src/dialect/operation/mod.rs new file mode 100644 index 0000000000..2c849a1ffd --- /dev/null +++ b/macro/src/dialect/operation/mod.rs @@ -0,0 +1,584 @@ +mod accessors; +mod builder; + +use crate::dialect::{ + error::{Error, ExpectedSuperClassError}, + types::{AttributeConstraint, RegionConstraint, SuccessorConstraint, Trait, TypeConstraint}, +}; +use crate::utility::sanitize_name_snake; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; +use tblgen::error::WithLocation; +use tblgen::record::Record; + +use self::builder::OperationBuilder; + +#[derive(Debug, Clone, Copy)] +pub enum FieldKind<'a> { + Operand(TypeConstraint<'a>), + Result(TypeConstraint<'a>), + Attribute(AttributeConstraint<'a>), + Successor(SuccessorConstraint<'a>), + Region(RegionConstraint<'a>), +} + +impl<'a> FieldKind<'a> { + pub fn as_str(&self) -> &'static str { + match self { + Self::Operand(_) => "operand", + Self::Result(_) => "result", + Self::Attribute(_) => "attribute", + Self::Successor(_) => "successor", + Self::Region(_) => "region", + } + } +} + +#[derive(Debug, Clone)] +pub struct SequenceInfo { + index: usize, + len: usize, +} + +#[derive(Clone, Debug)] +pub enum VariadicKind { + Simple { + seen_variable_length: bool, + }, + SameSize { + num_variable_length: usize, + num_preceding_simple: usize, + num_preceding_variadic: usize, + }, + AttrSized {}, +} + +#[derive(Debug, Clone)] +pub struct OperationField<'a> { + name: &'a str, + param_type: TokenStream, + return_type: TokenStream, + optional: bool, + has_default: bool, + kind: FieldKind<'a>, + seq_info: Option, + variadic_info: Option, + pub(crate) sanitized: Ident, +} + +impl<'a> OperationField<'a> { + pub fn from_attribute(name: &'a str, ac: AttributeConstraint<'a>) -> Self { + let kind_type: TokenStream = + syn::parse_str(ac.storage_type()).expect("storage type strings are valid"); + let (param_type, return_type) = { + if ac.is_unit() { + (quote! { bool }, quote! { bool }) + } else if ac.is_optional() { + (quote! { #kind_type<'c> }, quote! { Option<#kind_type<'c>> }) + } else { + (quote! { #kind_type<'c> }, quote! { #kind_type<'c> }) + } + }; + let sanitized = sanitize_name_snake(name); + Self { + name, + sanitized, + param_type, + return_type, + optional: ac.is_optional(), + has_default: ac.has_default_value(), + seq_info: None, + variadic_info: None, + kind: FieldKind::Attribute(ac), + } + } + + pub fn from_region(name: &'a str, rc: RegionConstraint<'a>, seq_info: SequenceInfo) -> Self { + let sanitized = sanitize_name_snake(name); + + let (param_type, return_type) = { + if rc.is_variadic() { + ( + quote! { Vec<::melior::ir::Region<'c>> }, + quote! { impl Iterator> }, + ) + } else { + ( + quote! { ::melior::ir::Region<'c> }, + quote! { ::melior::ir::RegionRef<'c, '_> }, + ) + } + }; + + Self { + name, + sanitized, + param_type, + return_type, + optional: false, + has_default: false, + kind: FieldKind::Region(rc), + seq_info: Some(seq_info), + variadic_info: None, + } + } + + pub fn from_successor( + name: &'a str, + sc: SuccessorConstraint<'a>, + seq_info: SequenceInfo, + ) -> Self { + let sanitized = sanitize_name_snake(name); + + let (param_type, return_type) = { + if sc.is_variadic() { + ( + quote! { &[&::melior::ir::Block<'c>] }, + quote! { impl Iterator> }, + ) + } else { + ( + quote! { &::melior::ir::Block<'c> }, + quote! { ::melior::ir::BlockRef<'c, '_> }, + ) + } + }; + + Self { + name, + sanitized, + param_type, + return_type, + optional: false, + has_default: false, + kind: FieldKind::Successor(sc), + seq_info: Some(seq_info), + variadic_info: None, + } + } + + pub fn from_operand( + name: &'a str, + tc: TypeConstraint<'a>, + seq_info: SequenceInfo, + variadic_info: VariadicKind, + ) -> Self { + Self::from_element(name, tc, FieldKind::Operand(tc), seq_info, variadic_info) + } + + pub fn from_result( + name: &'a str, + tc: TypeConstraint<'a>, + seq_info: SequenceInfo, + variadic_info: VariadicKind, + ) -> Self { + Self::from_element(name, tc, FieldKind::Result(tc), seq_info, variadic_info) + } + + fn from_element( + name: &'a str, + tc: TypeConstraint<'a>, + kind: FieldKind<'a>, + seq_info: SequenceInfo, + variadic_info: VariadicKind, + ) -> Self { + let (param_kind_type, return_kind_type) = match &kind { + FieldKind::Operand(_) => ( + quote!(::melior::ir::Value<'c, '_>), + quote!(::melior::ir::Value<'c, '_>), + ), + FieldKind::Result(_) => ( + quote!(::melior::ir::Type<'c>), + quote!(::melior::ir::operation::OperationResult<'c, '_>), + ), + _ => unreachable!(), + }; + let (param_type, return_type) = { + if tc.is_variable_length() { + if tc.is_optional() { + ( + quote! { #param_kind_type }, + quote! { Option<#return_kind_type> }, + ) + } else { + ( + quote! { &[#param_kind_type] }, + quote! { impl Iterator }, + ) + } + } else { + (param_kind_type, return_kind_type) + } + }; + let sanitized = sanitize_name_snake(name); + Self { + name, + sanitized, + param_type, + return_type, + optional: tc.is_optional(), + has_default: false, + seq_info: Some(seq_info), + variadic_info: Some(variadic_info), + kind, + } + } +} + +#[derive(Debug, Clone)] +pub struct Operation<'a> { + def: Record<'a>, + pub(crate) dialect: Record<'a>, + pub(crate) class_name: &'a str, + pub(crate) fields: Vec>, + pub(crate) can_infer_type: bool, + pub(crate) summary: String, + pub(crate) description: String, +} + +impl<'a> Operation<'a> { + pub fn from_def(def: Record<'a>) -> Result { + let dialect = def.def_value("opDialect")?; + + let mut work_list: Vec<_> = vec![def.list_value("traits")?]; + let mut traits = Vec::new(); + while let Some(trait_def) = work_list.pop() { + for v in trait_def.iter() { + let trait_def: Record = v + .try_into() + .map_err(|e: tblgen::Error| e.set_location(def))?; + if trait_def.subclass_of("TraitList") { + work_list.push(trait_def.list_value("traits")?); + } else { + if trait_def.subclass_of("Interface") { + work_list.push(trait_def.list_value("baseInterfaces")?); + } + traits.push(Trait::new(trait_def)) + } + } + } + + let successors_dag = def.dag_value("successors")?; + let len = successors_dag.num_args(); + let successors = successors_dag.args().enumerate().map(|(i, (n, v))| { + Result::<_, Error>::Ok(OperationField::from_successor( + n, + SuccessorConstraint::new( + v.try_into() + .map_err(|e: tblgen::Error| e.set_location(def))?, + ), + SequenceInfo { index: i, len }, + )) + }); + + let regions_dag = def.dag_value("regions").expect("operation has regions"); + let len = regions_dag.num_args(); + let regions = regions_dag.args().enumerate().map(|(i, (n, v))| { + Ok(OperationField::from_region( + n, + RegionConstraint::new( + v.try_into() + .map_err(|e: tblgen::Error| e.set_location(def))?, + ), + SequenceInfo { index: i, len }, + )) + }); + + // Creates an initial `VariadicKind` instance based on SameSize and AttrSized traits. + let initial_variadic_kind = |num_variable_length: usize, kind_name_upper: &str| { + let same_size_trait = format!("::mlir::OpTrait::SameVariadic{}Size", kind_name_upper); + let attr_sized = format!("::mlir::OpTrait::AttrSized{}Segments", kind_name_upper); + if num_variable_length <= 1 { + VariadicKind::Simple { + seen_variable_length: false, + } + } else if traits.iter().any(|t| t.has_name(&same_size_trait)) { + VariadicKind::SameSize { + num_variable_length, + num_preceding_simple: 0, + num_preceding_variadic: 0, + } + } else if traits.iter().any(|t| t.has_name(&attr_sized)) { + VariadicKind::AttrSized {} + } else { + unimplemented!("unsupported {} structure", kind_name_upper) + } + }; + + // Updates the given `VariadicKind` and returns the original value. + let update_variadic_kind = |tc: &TypeConstraint, variadic_kind: &mut VariadicKind| { + let orig_variadic_kind = variadic_kind.clone(); + match variadic_kind { + VariadicKind::Simple { + seen_variable_length, + } => { + if tc.is_variable_length() { + *seen_variable_length = true; + } + variadic_kind.clone() + } + VariadicKind::SameSize { + num_preceding_simple, + num_preceding_variadic, + .. + } => { + if tc.is_variable_length() { + *num_preceding_variadic += 1; + } else { + *num_preceding_simple += 1; + } + orig_variadic_kind + } + VariadicKind::AttrSized {} => variadic_kind.clone(), + } + }; + + let results_dag = def.dag_value("results")?; + let results = results_dag.args().map(|(n, arg)| { + let mut arg_def: Record = arg + .try_into() + .map_err(|e: tblgen::Error| e.set_location(def))?; + + if arg_def.subclass_of("OpVariable") { + arg_def = arg_def.def_value("constraint")?; + } + + Ok((n, TypeConstraint::new(arg_def))) + }); + let num_results = results.clone().count(); + let num_variable_length_results = results + .clone() + .filter(|res| { + res.as_ref() + .map(|(_, tc)| tc.is_variable_length()) + .unwrap_or_default() + }) + .count(); + let mut kind = initial_variadic_kind(num_variable_length_results, "Result"); + let results = results.enumerate().map(|(i, res)| { + res.map(|(n, tc)| { + let current_kind = update_variadic_kind(&tc, &mut kind); + OperationField::from_result( + n, + tc, + SequenceInfo { + index: i, + len: num_results, + }, + current_kind, + ) + }) + }); + + let arguments_dag = def.dag_value("arguments")?; + let arguments = arguments_dag.args().map(|(name, arg)| { + let mut arg_def: Record = arg + .try_into() + .map_err(|e: tblgen::Error| e.set_location(def))?; + + if arg_def.subclass_of("OpVariable") { + arg_def = arg_def.def_value("constraint")?; + } + + Ok((name, arg_def)) + }); + + let operands = arguments.clone().filter_map(|res| { + res.map(|(n, arg_def)| { + if arg_def.subclass_of("TypeConstraint") { + Some((n, TypeConstraint::new(arg_def))) + } else { + None + } + }) + .transpose() + }); + let num_operands = operands.clone().count(); + let num_variable_length_operands = operands + .clone() + .filter(|res| { + res.as_ref() + .map(|(_, tc)| tc.is_variable_length()) + .unwrap_or_default() + }) + .count(); + let mut kind = initial_variadic_kind(num_variable_length_operands, "Operand"); + let operands = operands.enumerate().map(|(i, res)| { + res.map(|(name, tc)| { + let current_kind = update_variadic_kind(&tc, &mut kind); + OperationField::from_operand( + name, + tc, + SequenceInfo { + index: i, + len: num_operands, + }, + current_kind, + ) + }) + }); + + let attributes = arguments.clone().filter_map(|res| { + res.map(|(name, arg_def)| { + if arg_def.subclass_of("Attr") { + assert!(!name.is_empty()); + assert!(!arg_def.subclass_of("DerivedAttr")); + Some(OperationField::from_attribute( + name, + AttributeConstraint::new(arg_def), + )) + } else { + None + } + }) + .transpose() + }); + + let derived_attrs = def.values().map(Ok).filter_map(|val| { + val.and_then(|val| { + if let Ok(def) = Record::try_from(val) { + if def.subclass_of("Attr") { + def.subclass_of("DerivedAttr") + .then_some(()) + .ok_or_else(|| { + ExpectedSuperClassError("DerivedAttr".into()).with_location(def) + })?; + return Ok(Some(OperationField::from_attribute( + def.name()?, + AttributeConstraint::new(def), + ))); + } + } + Ok(None) + }) + .transpose() + }); + + let fields = successors + .chain(regions) + .chain(results) + .chain(operands) + .chain(attributes) + .chain(derived_attrs) + .collect::, _>>()?; + + let name = def.name().unwrap(); + let class_name = if !name.contains('_') { + // Class name with a leading underscore and without dialect prefix + name + } else if !name.starts_with('_') { + // Class name without dialect prefix + let mut split = name.split('_'); + split.next(); + split.next().unwrap() + } else { + name + }; + + let can_infer_type = traits.iter().any(|t| { + (t.has_name("::mlir::OpTrait::FirstAttrDerivedResultType") + || t.has_name("::mlir::OpTrait::SameOperandsAndResultType")) + && num_variable_length_results == 0 + || t.has_name("::mlir::InferTypeOpInterface::Trait") && regions_dag.num_args() == 0 + }); + + let short_name = def.string_value("opName").expect("operation has name"); + let summary = def.str_value("summary").unwrap_or(&short_name); + let description = def.str_value("description").unwrap_or(""); + + let summary = if !summary.is_empty() { + format!( + "[`{}`]({}) operation: {}", + short_name, + class_name, + summary[0..1].to_uppercase() + &summary[1..] + ) + } else { + format!("[`{}`]({}) operation", short_name, class_name) + }; + let description = unindent::unindent(description); + + Ok(Self { + def, + dialect, + class_name, + fields, + can_infer_type, + summary, + description, + }) + } + + fn short_name(&self) -> String { + self.def.string_value("opName").expect("operation has name") + } + + fn name(&self) -> String { + let op_name = self.def.string_value("opName").expect("operation has name"); + let dialect_name = self + .dialect + .string_value("name") + .expect("dialect name is string"); + if !dialect_name.is_empty() { + format!("{}.{}", dialect_name, op_name) + } else { + op_name + } + } +} + +impl<'a> ToTokens for Operation<'a> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let class_name = format_ident!("{}", &self.class_name); + let name = self.name(); + let accessors = self.fields.iter().map(|f| f.accessors()); + let builder = OperationBuilder::new(self); + let builder_tokens = builder.builder(); + let builder_fn = builder.create_op_builder_fn(); + let default_constructor = builder.default_constructor(); + let summary = &self.summary; + let description = &self.description; + tokens.append_all(quote! { + #[doc = #summary] + #[doc = "\n\n"] + #[doc = #description] + pub struct #class_name<'c> { + operation: ::melior::ir::operation::Operation<'c>, + } + + impl<'c> #class_name<'c> { + pub fn name() -> &'static str { + #name + } + + pub fn operation(&self) -> &::melior::ir::operation::Operation<'c> { + &self.operation + } + + #builder_fn + + #(#accessors)* + } + + #builder_tokens + + #default_constructor + + impl<'c> TryFrom<::melior::ir::operation::Operation<'c>> for #class_name<'c> { + type Error = ::melior::Error; + + fn try_from( + operation: ::melior::ir::operation::Operation<'c>, + ) -> Result { + Ok(Self { operation }) + } + } + + impl<'c> Into<::melior::ir::operation::Operation<'c>> for #class_name<'c> { + fn into(self) -> ::melior::ir::operation::Operation<'c> { + self.operation + } + } + }) + } +} diff --git a/macro/src/dialect/types.rs b/macro/src/dialect/types.rs new file mode 100644 index 0000000000..eeeb59e592 --- /dev/null +++ b/macro/src/dialect/types.rs @@ -0,0 +1,244 @@ +use std::collections::HashMap; + +use lazy_static::lazy_static; +use tblgen::record::Record; + +lazy_static! { + pub static ref ATTRIBUTE_TYPES: HashMap<&'static str, &'static str> = { + let mut m = HashMap::new(); + macro_rules! attr { + ($($mlir:ident => $melior:ident),* $(,)*) => { + $( + m.insert( + concat!("::mlir::", stringify!($mlir)), + concat!("::melior::ir::attribute::", stringify!($melior)), + ); + )* + }; + } + attr!( + ArrayAttr => ArrayAttribute, + Attribute => Attribute, + DenseElementsAttr => DenseElementsAttribute, + DenseI32ArrayAttr => DenseI32ArrayAttribute, + FlatSymbolRefAttr => FlatSymbolRefAttribute, + FloatAttr => FloatAttribute, + IntegerAttr => IntegerAttribute, + StringAttr => StringAttribute, + TypeAttr => TypeAttribute, + ); + m + }; +} + +#[derive(Debug, Clone, Copy)] +pub struct RegionConstraint<'a>(Record<'a>); + +#[allow(unused)] +impl<'a> RegionConstraint<'a> { + pub fn new(record: Record<'a>) -> Self { + Self(record) + } + pub fn is_variadic(&self) -> bool { + self.0.subclass_of("VariadicRegion") + } +} + +impl<'a> std::ops::Deref for RegionConstraint<'a> { + type Target = Record<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SuccessorConstraint<'a>(Record<'a>); + +#[allow(unused)] +impl<'a> SuccessorConstraint<'a> { + pub fn new(record: Record<'a>) -> Self { + Self(record) + } + pub fn is_variadic(&self) -> bool { + self.0.subclass_of("VariadicSuccessor") + } +} + +impl<'a> std::ops::Deref for SuccessorConstraint<'a> { + type Target = Record<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct TypeConstraint<'a>(Record<'a>); + +#[allow(unused)] +impl<'a> TypeConstraint<'a> { + pub fn new(record: Record<'a>) -> Self { + Self(record) + } + + pub fn is_optional(&self) -> bool { + self.0.subclass_of("Optional") + } + pub fn is_variadic(&self) -> bool { + self.0.subclass_of("Variadic") + } + pub fn is_variadic_of_variadic(&self) -> bool { + self.0.subclass_of("VariadicOfVariadic") + } + pub fn is_variable_length(&self) -> bool { + self.is_variadic() || self.is_optional() + } +} + +impl<'a> std::ops::Deref for TypeConstraint<'a> { + type Target = Record<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, Copy)] +pub struct AttributeConstraint<'a>(Record<'a>); + +#[allow(unused)] +impl<'a> AttributeConstraint<'a> { + pub fn new(record: Record<'a>) -> Self { + Self(record) + } + + pub fn is_derived(&self) -> bool { + self.0.subclass_of("DerivedAttr") + } + + pub fn is_type_attr(&self) -> bool { + self.0.subclass_of("TypeAttrBase") + } + + pub fn is_symbol_ref_attr(&self) -> bool { + self.0.name() == Ok("SymbolRefAttr") + || self.0.name() == Ok("FlatSymbolRefAttr") + || self.0.subclass_of("SymbolRefAttr") + || self.0.subclass_of("FlatSymbolRefAttr") + } + + pub fn is_enum_attr(&self) -> bool { + self.0.subclass_of("EnumAttrInfo") + } + + pub fn is_optional(&self) -> bool { + self.0.bit_value("isOptional").unwrap_or(false) + } + + pub fn storage_type(&self) -> &'static str { + self.0 + .string_value("storageType") + .ok() + .and_then(|v| ATTRIBUTE_TYPES.get(v.as_str().trim())) + .copied() + .unwrap_or("::melior::ir::attribute::Attribute") + } + + pub fn is_unit(&self) -> bool { + self.0 + .string_value("storageType") + .map(|v| v == "::mlir::UnitAttr") + .unwrap_or(false) + } + + pub fn has_default_value(&self) -> bool { + self.0 + .string_value("defaultValue") + .map(|s| !s.is_empty()) + .unwrap_or(false) + } +} + +impl<'a> std::ops::Deref for AttributeConstraint<'a> { + type Target = Record<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub enum TraitKind { + Native { name: String, is_structural: bool }, + Pred {}, + Internal { name: String }, + Interface { name: String }, +} + +#[derive(Debug, Clone)] +pub struct Trait<'a> { + kind: TraitKind, + def: Record<'a>, +} + +#[allow(unused)] +impl<'a> Trait<'a> { + pub fn new(def: Record<'a>) -> Self { + let kind = if def.subclass_of("PredTrait") { + TraitKind::Pred {} + } else if def.subclass_of("InterfaceTrait") { + TraitKind::Interface { + name: Self::name(def), + } + } else if def.subclass_of("NativeTrait") { + TraitKind::Native { + name: Self::name(def), + is_structural: def.subclass_of("StructuralOpTrait"), + } + } else if def.subclass_of("GenInternalTrait") { + TraitKind::Internal { + name: def + .string_value("trait") + .expect("trait def has trait value"), + } + } else { + unreachable!("invalid trait") + }; + Self { kind, def } + } + + pub fn has_name(&self, n: &str) -> bool { + match &self.kind { + TraitKind::Native { name, .. } + | TraitKind::Internal { name } + | TraitKind::Interface { name } => n == name, + TraitKind::Pred {} => false, + } + } + + fn name(def: Record) -> String { + let r#trait = def + .string_value("trait") + .expect("trait def has trait value"); + let namespace = def.string_value("cppNamespace").ok().and_then(|n| { + if n.is_empty() { + None + } else { + Some(n) + } + }); + if let Some(namespace) = namespace { + format!("{}::{}", namespace, r#trait) + } else { + r#trait + } + } + + pub fn kind(&self) -> &TraitKind { + &self.kind + } +} + +impl<'a> std::ops::Deref for Trait<'a> { + type Target = Record<'a>; + fn deref(&self) -> &Self::Target { + &self.def + } +} diff --git a/macro/src/lib.rs b/macro/src/lib.rs index d419e89459..05f593d743 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -1,16 +1,25 @@ mod attribute; +mod dialect; mod operation; mod parse; mod pass; mod r#type; mod utility; +use dialect::DialectMacroInput; use parse::{DialectOperationSet, IdentifierList}; use proc_macro::TokenStream; use quote::quote; use std::error::Error; use syn::parse_macro_input; +#[proc_macro] +pub fn dialect(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DialectMacroInput); + + convert_result(dialect::generate_dialect(input)) +} + #[proc_macro] pub fn binary_operations(stream: TokenStream) -> TokenStream { let set = parse_macro_input!(stream as DialectOperationSet); diff --git a/macro/src/utility.rs b/macro/src/utility.rs index c395fb594d..ab4ed1d5fe 100644 --- a/macro/src/utility.rs +++ b/macro/src/utility.rs @@ -1,6 +1,35 @@ +use convert_case::{Case, Casing}; use once_cell::sync::Lazy; +use proc_macro2::Ident; +use quote::format_ident; use regex::{Captures, Regex}; +static RESERVED_NAMES: &[&str] = &["name", "operation", "builder"]; + +pub fn sanitize_name_snake(name: &str) -> Ident { + sanitize_name(&name.to_case(Case::Snake)) +} + +pub fn sanitize_name(name: &str) -> Ident { + // Replace any "." with "_" + let mut name = name.replace('.', "_"); + + // Add "_" suffix to avoid conflicts with existing methods + if RESERVED_NAMES.contains(&name.as_str()) + || name + .chars() + .next() + .expect("name has at least one char") + .is_numeric() + { + name = format!("_{}", name); + } + + // Try to parse the string as an ident, and prefix the identifier + // with "r#" if it is not a valid identifier. + syn::parse_str::(&name).unwrap_or(format_ident!("r#{}", name)) +} + static PATTERN: Lazy = Lazy::new(|| { Regex::new(r#"(bf_16|f_16|f_32|f_64|i_8|i_16|i_32|i_64|float_8_e_[0-9]_m_[0-9](_fn)?)"#) .unwrap() diff --git a/macro/tests/ods_include/operand.td b/macro/tests/ods_include/operand.td new file mode 100644 index 0000000000..89abc52c94 --- /dev/null +++ b/macro/tests/ods_include/operand.td @@ -0,0 +1,20 @@ +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +def OperandTest_Dialect : Dialect { + let name = "operand_test"; + let cppNamespace = "::mlir::operand_test"; +} + +class OperandTest_Op traits = []> : + Op; + +def OperandTest_SimpleOp : OperandTest_Op<"simple"> { + let arguments = (ins I32:$lhs, I32:$rhs); + let results = (outs I32:$res); +} + +def OperandTest_VariadicOp : OperandTest_Op<"variadic"> { + let arguments = (ins I32:$first, Variadic:$others); + let results = (outs I32:$res); +} diff --git a/macro/tests/ods_include/region.td b/macro/tests/ods_include/region.td new file mode 100644 index 0000000000..bc8aa132db --- /dev/null +++ b/macro/tests/ods_include/region.td @@ -0,0 +1,18 @@ +include "mlir/IR/OpBase.td" + +def RegionTest_Dialect : Dialect { + let name = "region_test"; + let cppNamespace = "::mlir::region_test"; +} + +class RegionTest_Op traits = []> : + Op; + +def RegionTest_SingleOp : RegionTest_Op<"single"> { + let regions = (region SizedRegion<1>:$defaultRegion); +} + +def RegionTest_VariadicOp : RegionTest_Op<"variadic"> { + let regions = (region SizedRegion<1>:$defaultRegion, + VariadicRegion>:$otherRegions); +} diff --git a/macro/tests/operand.rs b/macro/tests/operand.rs new file mode 100644 index 0000000000..4a27a8b25b --- /dev/null +++ b/macro/tests/operand.rs @@ -0,0 +1,57 @@ +use melior::ir::{operation::OperationBuilder, Block, Location, Type, ValueLike}; + +mod utility; + +use utility::*; + +melior_macro::dialect! { + name: "operand_test", + td_file: "macro/tests/ods_include/operand.td", +} + +#[test] +fn simple() { + let context = create_test_context(); + context.set_allow_unregistered_dialects(true); + + let location = Location::unknown(&context); + + let r#type = Type::parse(&context, "i32").unwrap(); + let block = Block::new(&[(r#type, location), (r#type, location)]); + let op = operand_test::simple( + r#type, + block.argument(0).unwrap().into(), + block.argument(1).unwrap().into(), + location, + ); + + assert_eq!(op.lhs(), block.argument(0).unwrap().into()); + assert_eq!(op.rhs(), block.argument(1).unwrap().into()); + assert_eq!(op.operation().operand_count(), 2); +} + +#[test] +fn variadic_after_single() { + let context = create_test_context(); + context.set_allow_unregistered_dialects(true); + + let location = Location::unknown(&context); + + let r#type = Type::parse(&context, "i32").unwrap(); + let block = Block::new(&[(r#type, location), (r#type, location), (r#type, location)]); + let op = operand_test::variadic( + r#type, + block.argument(0).unwrap().into(), + &[ + block.argument(2).unwrap().into(), + block.argument(1).unwrap().into(), + ], + location, + ); + + assert_eq!(op.first(), block.argument(0).unwrap().into()); + assert_eq!(op.others().next(), Some(block.argument(2).unwrap().into())); + assert_eq!(op.others().nth(1), Some(block.argument(1).unwrap().into())); + assert_eq!(op.operation().operand_count(), 3); + assert_eq!(op.others().count(), 2); +} diff --git a/macro/tests/region.rs b/macro/tests/region.rs new file mode 100644 index 0000000000..24e39f13ed --- /dev/null +++ b/macro/tests/region.rs @@ -0,0 +1,59 @@ +use melior::ir::{Block, Location, Region}; + +mod utility; + +use utility::*; + +melior_macro::dialect! { + name: "region_test", + td_file: "macro/tests/ods_include/region.td", +} + +#[test] +fn single() { + let context = create_test_context(); + context.set_allow_unregistered_dialects(true); + + let location = Location::unknown(&context); + + let op = { + let block = Block::new(&[]); + let r1 = Region::new(); + r1.append_block(block); + region_test::single(r1, location) + }; + + assert!(op.default_region().first_block().is_some()); +} + +#[test] +fn variadic_after_single() { + let context = create_test_context(); + context.set_allow_unregistered_dialects(true); + + let location = Location::unknown(&context); + + let op = { + let block = Block::new(&[]); + let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); + r2.append_block(block); + region_test::variadic(r1, vec![r2, r3], location) + }; + + let op2 = { + let block = Block::new(&[]); + let (r1, r2, r3) = (Region::new(), Region::new(), Region::new()); + r2.append_block(block); + region_test::VariadicOp::builder(location) + .default_region(r1) + .other_regions(vec![r2, r3]) + .build() + }; + + assert_eq!(op.operation().to_string(), op2.operation().to_string()); + + assert!(op.default_region().first_block().is_none()); + assert_eq!(op.other_regions().count(), 2); + assert!(op.other_regions().next().unwrap().first_block().is_some()); + assert!(op.other_regions().nth(1).unwrap().first_block().is_none()); +} diff --git a/macro/tests/utility.rs b/macro/tests/utility.rs new file mode 100644 index 0000000000..eae7bc2d8f --- /dev/null +++ b/macro/tests/utility.rs @@ -0,0 +1,26 @@ +use melior::{ + dialect::DialectRegistry, + utility::{register_all_dialects, register_all_llvm_translations}, + Context, +}; + +pub fn load_all_dialects(context: &Context) { + let registry = DialectRegistry::new(); + register_all_dialects(®istry); + context.append_dialect_registry(®istry); + context.load_all_available_dialects(); +} + +pub fn create_test_context() -> Context { + let context = Context::new(); + + context.attach_diagnostic_handler(|diagnostic| { + eprintln!("{}", diagnostic); + true + }); + + load_all_dialects(&context); + register_all_llvm_translations(&context); + + context +} diff --git a/melior/Cargo.toml b/melior/Cargo.toml index c05240c4e8..4cdb999879 100644 --- a/melior/Cargo.toml +++ b/melior/Cargo.toml @@ -9,6 +9,9 @@ documentation = "https://raviqqe.github.io/melior/melior/" readme = "../README.md" keywords = ["mlir", "llvm"] +[features] +ods-dialects = [] + [dependencies] criterion = "0.5.1" dashmap = "5.5.0" diff --git a/melior/src/dialect.rs b/melior/src/dialect.rs index 1cac4672f8..a7cafbd2d3 100644 --- a/melior/src/dialect.rs +++ b/melior/src/dialect.rs @@ -18,6 +18,9 @@ use crate::{ use mlir_sys::{mlirDialectEqual, mlirDialectGetContext, mlirDialectGetNamespace, MlirDialect}; use std::marker::PhantomData; +#[cfg(feature = "ods-dialects")] +pub mod ods; + /// A dialect. #[derive(Clone, Copy, Debug)] pub struct Dialect<'c> { diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs new file mode 100644 index 0000000000..d00131991b --- /dev/null +++ b/melior/src/dialect/ods.rs @@ -0,0 +1,100 @@ +melior_macro::dialect! { + name: "affine", + tablegen: r#"include "mlir/Dialect/Affine/IR/AffineOps.td""# +} +melior_macro::dialect! { + name: "amdgpu", + tablegen: r#"include "mlir/Dialect/AMDGPU/AMDGPU.td""# +} +melior_macro::dialect! { + name: "arith", + tablegen: r#"include "mlir/Dialect/Arith/IR/ArithOps.td""# +} +melior_macro::dialect! { + name: "arm_neon", + tablegen: r#"include "mlir/Dialect/ArmNeon/ArmNeon.td""# +} +melior_macro::dialect! { + name: "arm_sve", + tablegen: r#"include "mlir/Dialect/ArmSVE/ArmSVE.td""# +} +melior_macro::dialect! { + name: "async", + tablegen: r#"include "mlir/Dialect/Async/IR/AsyncOps.td""# +} +melior_macro::dialect! { + name: "bufferization", + tablegen: r#"include "mlir/Dialect/Bufferization/IR/BufferizationOps.td""# +} +melior_macro::dialect! { + name: "cf", + tablegen: r#"include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td""# +} +melior_macro::dialect! { + name: "func", + tablegen: r#"include "mlir/Dialect/Func/IR/FuncOps.td""# +} +melior_macro::dialect! { + name: "index", + tablegen: r#"include "mlir/Dialect/Index/IR/IndexOps.td""# +} +melior_macro::dialect! { + name: "llvm", + tablegen: r#"include "mlir/Dialect/LLVMIR/LLVMOps.td""# +} +melior_macro::dialect! { + name: "memref", + tablegen: r#"include "mlir/Dialect/MemRef/IR/MemRefOps.td""# +} +melior_macro::dialect! { + name: "scf", + tablegen: r#"include "mlir/Dialect/SCF/IR/SCFOps.td""# +} +melior_macro::dialect! { + name: "pdl", + tablegen: r#"include "mlir/Dialect/PDL/IR/PDLOps.td""# +} +melior_macro::dialect! { + name: "pdl_interp", + tablegen: r#"include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td""# +} +melior_macro::dialect! { + name: "math", + tablegen: r#"include "mlir/Dialect/Math/IR/MathOps.td""# +} +melior_macro::dialect! { + name: "gpu", + tablegen: r#"include "mlir/Dialect/GPU/IR/GPUOps.td""# +} +melior_macro::dialect! { + name: "linalg", + tablegen: r#"include "mlir/Dialect/Linalg/IR/LinalgOps.td""# +} +melior_macro::dialect! { + name: "quant", + tablegen: r#"include "mlir/Dialect/Quant/QuantOps.td""# +} +melior_macro::dialect! { + name: "shape", + tablegen: r#"include "mlir/Dialect/Shape/IR/ShapeOps.td""# +} +melior_macro::dialect! { + name: "sparse_tensor", + tablegen: r#"include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td""# +} +melior_macro::dialect! { + name: "tensor", + tablegen: r#"include "mlir/Dialect/Tensor/IR/TensorOps.td""# +} +melior_macro::dialect! { + name: "tosa", + tablegen: r#"include "mlir/Dialect/Tosa/IR/TosaOps.td""# +} +melior_macro::dialect! { + name: "transform", + tablegen: r#"include "mlir/Dialect/Transform/IR/TransformOps.td""# +} +melior_macro::dialect! { + name: "vector", + tablegen: r#"include "mlir/Dialect/Vector/IR/VectorOps.td""# +} diff --git a/melior/src/lib.rs b/melior/src/lib.rs index b88c3d246d..eda303821f 100644 --- a/melior/src/lib.rs +++ b/melior/src/lib.rs @@ -1,5 +1,7 @@ #![doc = include_str!("../README.md")] +extern crate self as melior; + #[macro_use] mod r#macro; mod context; @@ -22,6 +24,8 @@ pub use self::{ string_ref::StringRef, }; +pub use melior_macro::dialect; + #[cfg(test)] mod tests { use crate::{