From d352179057dacdd25fc0b255fd02213371c9bc9b Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Mon, 4 Dec 2023 22:16:24 +0900 Subject: [PATCH] Fix --- macro/src/lib.rs | 46 +++++----------------------------------------- macro/src/parse.rs | 12 ++++++------ 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/macro/src/lib.rs b/macro/src/lib.rs index fb8fb39339..64e01cf6e0 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -7,7 +7,7 @@ mod r#type; mod utility; use dialect::DialectInput; -use parse::{DialectOperationSet, IdentifierList}; +use parse::{DialectOperationSet, IdentifierList, PassSet}; use proc_macro::TokenStream; use quote::quote; use std::error::Error; @@ -68,15 +68,6 @@ pub fn attribute_check_functions(stream: TokenStream) -> TokenStream { convert_result(attribute::generate(identifiers.identifiers())) } -#[proc_macro] -pub fn async_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Async").unwrap().into() - })) -} - #[proc_macro] pub fn conversion_passes(stream: TokenStream) -> TokenStream { let identifiers = parse_macro_input!(stream as IdentifierList); @@ -90,38 +81,11 @@ pub fn conversion_passes(stream: TokenStream) -> TokenStream { } #[proc_macro] -pub fn gpu_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("GPU").unwrap().into() - })) -} - -#[proc_macro] -pub fn transform_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Transforms").unwrap().into() - })) -} - -#[proc_macro] -pub fn linalg_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); - - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("Linalg").unwrap().into() - })) -} - -#[proc_macro] -pub fn sparse_tensor_passes(stream: TokenStream) -> TokenStream { - let identifiers = parse_macro_input!(stream as IdentifierList); +pub fn passes(stream: TokenStream) -> TokenStream { + let set = parse_macro_input!(stream as PassSet); - convert_result(pass::generate(identifiers.identifiers(), |name| { - name.strip_prefix("SparseTensor").unwrap().into() + convert_result(pass::generate(set.identifiers(), |name| { + name.strip_prefix(&set.prefix().value()).unwrap().into() })) } diff --git a/macro/src/parse.rs b/macro/src/parse.rs index d250d4c581..0f8cdbae1d 100644 --- a/macro/src/parse.rs +++ b/macro/src/parse.rs @@ -3,7 +3,7 @@ use syn::{ bracketed, parse::{Parse, ParseStream}, punctuated::Punctuated, - Lit, Result, Token, + LitStr, Result, Token, }; pub struct IdentifierList { @@ -58,13 +58,13 @@ impl Parse for DialectOperationSet { } pub struct PassSet { - prefix: Lit, + prefix: LitStr, identifiers: IdentifierList, } impl PassSet { - pub const fn dialect(&self) -> &Ident { - &self.dialect + pub const fn prefix(&self) -> &LitStr { + &self.prefix } pub fn identifiers(&self) -> &[Ident] { @@ -74,11 +74,11 @@ impl PassSet { impl Parse for PassSet { fn parse(input: ParseStream) -> Result { - let dialect = Ident::parse(input)?; + let prefix = input.parse()?; ::parse(input)?; Ok(Self { - dialect, + prefix, identifiers: { let content; bracketed!(content in input);