Skip to content

Commit

Permalink
Refactor pass macros (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Dec 4, 2023
1 parent db45fe9 commit 86f261c
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 91 deletions.
46 changes: 5 additions & 41 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod r#type;
mod utility;

use dialect::DialectInput;
use parse::{DialectOperationSet, IdentifierList};
use parse::{DialectOperationSet, IdentifierList, PassSet};
use proc_macro::TokenStream;
use quote::quote;
use std::error::Error;
Expand Down Expand Up @@ -68,15 +68,6 @@ pub fn attribute_check_functions(stream: TokenStream) -> TokenStream {
convert_result(attribute::generate(identifiers.identifiers()))
}

#[proc_macro]
pub fn async_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Async").unwrap().into()
}))
}

#[proc_macro]
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
Expand All @@ -90,38 +81,11 @@ pub fn conversion_passes(stream: TokenStream) -> TokenStream {
}

#[proc_macro]
pub fn gpu_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("GPU").unwrap().into()
}))
}

#[proc_macro]
pub fn transform_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Transforms").unwrap().into()
}))
}

#[proc_macro]
pub fn linalg_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Linalg").unwrap().into()
}))
}

#[proc_macro]
pub fn sparse_tensor_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
pub fn passes(stream: TokenStream) -> TokenStream {
let set = parse_macro_input!(stream as PassSet);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("SparseTensor").unwrap().into()
convert_result(pass::generate(set.identifiers(), |name| {
name.strip_prefix(&set.prefix().value()).unwrap().into()
}))
}

Expand Down
33 changes: 32 additions & 1 deletion macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use syn::{
bracketed,
parse::{Parse, ParseStream},
punctuated::Punctuated,
Result, Token,
LitStr, Result, Token,
};

pub struct IdentifierList {
Expand Down Expand Up @@ -56,3 +56,34 @@ impl Parse for DialectOperationSet {
})
}
}

pub struct PassSet {
prefix: LitStr,
identifiers: IdentifierList,
}

impl PassSet {
pub const fn prefix(&self) -> &LitStr {
&self.prefix
}

pub fn identifiers(&self) -> &[Ident] {
self.identifiers.identifiers()
}
}

impl Parse for PassSet {
fn parse(input: ParseStream) -> Result<Self> {
let prefix = input.parse()?;
<Token![,]>::parse(input)?;

Ok(Self {
prefix,
identifiers: {
let content;
bracketed!(content in input);
content.parse::<IdentifierList>()?
},
})
}
}
17 changes: 10 additions & 7 deletions melior/src/pass/async.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
//! Async passes.
melior_macro::async_passes!(
mlirCreateAsyncAsyncFuncToAsyncRuntime,
mlirCreateAsyncAsyncParallelFor,
mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting,
mlirCreateAsyncAsyncRuntimeRefCounting,
mlirCreateAsyncAsyncRuntimeRefCountingOpt,
mlirCreateAsyncAsyncToAsyncRuntime,
melior_macro::passes!(
"Async",
[
mlirCreateAsyncAsyncFuncToAsyncRuntime,
mlirCreateAsyncAsyncParallelFor,
mlirCreateAsyncAsyncRuntimePolicyBasedRefCounting,
mlirCreateAsyncAsyncRuntimeRefCounting,
mlirCreateAsyncAsyncRuntimeRefCountingOpt,
mlirCreateAsyncAsyncToAsyncRuntime,
]
);
15 changes: 9 additions & 6 deletions melior/src/pass/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! GPU passes.
melior_macro::gpu_passes!(
// spell-checker: disable-next-line
mlirCreateGPUGpuAsyncRegionPass,
mlirCreateGPUGpuKernelOutlining,
mlirCreateGPUGpuLaunchSinkIndexComputations,
mlirCreateGPUGpuMapParallelLoopsPass,
melior_macro::passes!(
"GPU",
[
// spell-checker: disable-next-line
mlirCreateGPUGpuAsyncRegionPass,
mlirCreateGPUGpuKernelOutlining,
mlirCreateGPUGpuLaunchSinkIndexComputations,
mlirCreateGPUGpuMapParallelLoopsPass,
]
);
27 changes: 15 additions & 12 deletions melior/src/pass/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
//! Linalg passes.
melior_macro::linalg_passes!(
mlirCreateLinalgConvertElementwiseToLinalg,
mlirCreateLinalgLinalgBufferize,
mlirCreateLinalgLinalgDetensorize,
mlirCreateLinalgLinalgElementwiseOpFusion,
mlirCreateLinalgLinalgFoldUnitExtentDims,
mlirCreateLinalgLinalgGeneralization,
mlirCreateLinalgLinalgInlineScalarOperands,
mlirCreateLinalgLinalgLowerToAffineLoops,
mlirCreateLinalgLinalgLowerToLoops,
mlirCreateLinalgLinalgLowerToParallelLoops,
mlirCreateLinalgLinalgNamedOpConversion,
melior_macro::passes!(
"Linalg",
[
mlirCreateLinalgConvertElementwiseToLinalg,
mlirCreateLinalgLinalgBufferize,
mlirCreateLinalgLinalgDetensorize,
mlirCreateLinalgLinalgElementwiseOpFusion,
mlirCreateLinalgLinalgFoldUnitExtentDims,
mlirCreateLinalgLinalgGeneralization,
mlirCreateLinalgLinalgInlineScalarOperands,
mlirCreateLinalgLinalgLowerToAffineLoops,
mlirCreateLinalgLinalgLowerToLoops,
mlirCreateLinalgLinalgLowerToParallelLoops,
mlirCreateLinalgLinalgNamedOpConversion,
]
);
21 changes: 12 additions & 9 deletions melior/src/pass/sparse_tensor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Sparse tensor passes.
melior_macro::sparse_tensor_passes!(
mlirCreateSparseTensorPostSparsificationRewrite,
mlirCreateSparseTensorPreSparsificationRewrite,
mlirCreateSparseTensorSparseBufferRewrite,
mlirCreateSparseTensorSparseTensorCodegen,
mlirCreateSparseTensorSparseTensorConversionPass,
mlirCreateSparseTensorSparseVectorization,
mlirCreateSparseTensorSparsificationPass,
mlirCreateSparseTensorStorageSpecifierToLLVM,
melior_macro::passes!(
"SparseTensor",
[
mlirCreateSparseTensorPostSparsificationRewrite,
mlirCreateSparseTensorPreSparsificationRewrite,
mlirCreateSparseTensorSparseBufferRewrite,
mlirCreateSparseTensorSparseTensorCodegen,
mlirCreateSparseTensorSparseTensorConversionPass,
mlirCreateSparseTensorSparseVectorization,
mlirCreateSparseTensorSparsificationPass,
mlirCreateSparseTensorStorageSpecifierToLLVM,
]
);
33 changes: 18 additions & 15 deletions melior/src/pass/transform.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
//! Transform passes.
melior_macro::transform_passes!(
mlirCreateTransformsCSE,
mlirCreateTransformsCanonicalizer,
mlirCreateTransformsControlFlowSink,
mlirCreateTransformsGenerateRuntimeVerification,
mlirCreateTransformsInliner,
mlirCreateTransformsLocationSnapshot,
mlirCreateTransformsLoopInvariantCodeMotion,
mlirCreateTransformsPrintOpStats,
mlirCreateTransformsSCCP,
mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE,
mlirCreateTransformsSymbolPrivatize,
mlirCreateTransformsTopologicalSort,
mlirCreateTransformsViewOpGraph,
melior_macro::passes!(
"Transforms",
[
mlirCreateTransformsCSE,
mlirCreateTransformsCanonicalizer,
mlirCreateTransformsControlFlowSink,
mlirCreateTransformsGenerateRuntimeVerification,
mlirCreateTransformsInliner,
mlirCreateTransformsLocationSnapshot,
mlirCreateTransformsLoopInvariantCodeMotion,
mlirCreateTransformsPrintOpStats,
mlirCreateTransformsSCCP,
mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE,
mlirCreateTransformsSymbolPrivatize,
mlirCreateTransformsTopologicalSort,
mlirCreateTransformsViewOpGraph,
]
);

0 comments on commit 86f261c

Please sign in to comment.