diff --git a/lib/flux-attrs/src/ast.rs b/lib/flux-attrs/src/ast.rs index f99ea4edf4..33fe59798e 100644 --- a/lib/flux-attrs/src/ast.rs +++ b/lib/flux-attrs/src/ast.rs @@ -9,7 +9,7 @@ use syn::{ parse::{Parse, ParseStream, Peek}, punctuated::Punctuated, token::{self, Mut, Paren}, - Attribute, Generics, Ident, Result, Token, Visibility, + Attribute, Ident, Result, Token, Visibility, }; use crate::flux_tool_attrs; @@ -33,6 +33,61 @@ pub struct ItemFn { pub block: Block, } +#[derive(Debug)] +pub struct Generics { + pub lt_token: Option, + pub params: Punctuated, + pub gt_token: Option]>, + pub where_clause: Option, +} + +impl Default for Generics { + fn default() -> Self { + Generics { lt_token: None, params: Punctuated::new(), gt_token: None, where_clause: None } + } +} + +#[derive(Debug)] +pub enum GenericParam { + /// A lifetime parameter: `'a: 'b + 'c + 'd`. + Lifetime(syn::LifetimeParam), + + /// A generic type parameter: `T: Into`. + Type(TypeParam), + + /// A const generic parameter: `const LENGTH: usize`. + Const(syn::ConstParam), +} + +#[derive(Debug)] +pub struct TypeParam { + pub attrs: Vec, + pub ident: Ident, + pub as_token: Option, + pub param_kind: ParamKind, + pub colon_token: Option, + pub bounds: Punctuated, + // pub eq_token: Option, + // pub default: Option, +} + +#[derive(Debug)] +pub enum ParamKind { + Type(Token![type]), + Base(kw::base), + Default, +} + +impl ToTokens for ParamKind { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + ParamKind::Type(t) => t.to_tokens(tokens), + ParamKind::Base(t) => t.to_tokens(tokens), + ParamKind::Default => {} + } + } +} + pub struct ItemStruct { pub attrs: Vec, pub vis: Visibility, @@ -544,6 +599,138 @@ impl Parse for ItemStruct { } } +impl Parse for Generics { + fn parse(input: ParseStream) -> Result { + if !input.peek(Token![<]) { + return Ok(Generics::default()); + } + + let lt_token: Token![<] = input.parse()?; + + let mut params = Punctuated::new(); + loop { + if input.peek(Token![>]) { + break; + } + + let attrs = input.call(Attribute::parse_outer)?; + let lookahead = input.lookahead1(); + if lookahead.peek(syn::Lifetime) { + params.push_value(GenericParam::Lifetime(syn::LifetimeParam { + attrs, + ..input.parse()? + })); + } else if lookahead.peek(Ident) { + params.push_value(GenericParam::Type(TypeParam { attrs, ..input.parse()? })); + } else if lookahead.peek(Token![const]) { + params.push_value(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? })); + } else if input.peek(Token![_]) { + params.push_value(GenericParam::Type(TypeParam { + attrs, + ident: input.call(Ident::parse_any)?, + as_token: None, + param_kind: ParamKind::Default, + colon_token: None, + bounds: Punctuated::new(), + // eq_token: None, + // default: None, + })); + } else { + return Err(lookahead.error()); + } + + if input.peek(Token![>]) { + break; + } + let punct = input.parse()?; + params.push_punct(punct); + } + + let gt_token: Token![>] = input.parse()?; + + Ok(Generics { + lt_token: Some(lt_token), + params, + gt_token: Some(gt_token), + where_clause: None, + }) + } +} + +impl Parse for GenericParam { + fn parse(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + + let lookahead = input.lookahead1(); + if lookahead.peek(Ident) { + Ok(GenericParam::Type(TypeParam { attrs, ..input.parse()? })) + } else if lookahead.peek(syn::Lifetime) { + Ok(GenericParam::Lifetime(syn::LifetimeParam { attrs, ..input.parse()? })) + } else if lookahead.peek(Token![const]) { + Ok(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? })) + } else { + Err(lookahead.error()) + } + } +} + +impl Parse for TypeParam { + fn parse(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + let ident: Ident = input.parse()?; + + let as_token: Option = input.parse()?; + let mut param_kind = ParamKind::Default; + if as_token.is_some() { + param_kind = input.parse()?; + } + + let colon_token: Option = input.parse()?; + + let mut bounds = Punctuated::new(); + if colon_token.is_some() { + loop { + if input.peek(Token![,]) || input.peek(Token![>]) || input.peek(Token![=]) { + break; + } + let value: syn::TypeParamBound = input.parse()?; + bounds.push_value(value); + if !input.peek(Token![+]) { + break; + } + let punct: Token![+] = input.parse()?; + bounds.push_punct(punct); + } + } + // let eq_token: Option = input.parse()?; + // let default = if eq_token.is_some() { Some(input.parse::()?) } else { None }; + + Ok(TypeParam { + attrs, + ident, + as_token, + param_kind, + colon_token, + bounds, + // eq_token, + // default, + }) + } +} + +impl Parse for ParamKind { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![type]) { + input.parse().map(ParamKind::Type) + } else if lookahead.peek(kw::base) { + input.parse().map(ParamKind::Base) + } else { + Err(lookahead.error()) + } + } +} + impl Parse for ItemEnum { fn parse(input: ParseStream) -> Result { let mut attrs = input.call(Attribute::parse_outer)?; @@ -857,7 +1044,7 @@ impl Parse for Signature { let content; let fn_token = input.parse()?; let ident = input.parse()?; - let mut generics: syn::Generics = input.parse()?; + let mut generics: Generics = input.parse()?; let paren_token = parenthesized!(content in input); let inputs = content.parse_terminated(FnArg::parse, Token![,])?; let output = input.parse()?; @@ -1185,6 +1372,7 @@ mod kw { syn::custom_keyword!(requires); syn::custom_keyword!(refined); syn::custom_keyword!(by); + syn::custom_keyword!(base); } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1235,7 +1423,7 @@ impl ToTokens for ItemStruct { self.vis.to_tokens(tokens); self.struct_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); self.fields.to_tokens(tokens, |field, tokens| { #[cfg(flux_sysroot)] field.flux_tool_attr().to_tokens(tokens); @@ -1255,7 +1443,7 @@ impl ToTokens for ItemEnum { self.vis.to_tokens(tokens); self.enum_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); self.brace_token.surround(tokens, |tokens| { self.variants.to_tokens(tokens); }); @@ -1348,7 +1536,7 @@ impl ItemType { } self.type_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, mode); if let Some(params) = &self.index_params { params.to_tokens_inner(tokens, mode); } @@ -1360,6 +1548,65 @@ impl ItemType { } } +impl Generics { + fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) { + if self.params.is_empty() { + return; + } + + tokens_or_default(self.lt_token.as_ref(), tokens); + + for param in self.params.pairs() { + match mode { + Mode::Rust => { + param.to_tokens(tokens); + } + Mode::Flux => { + if let GenericParam::Type(p) = param.value() { + p.to_tokens(tokens, mode); + param.punct().to_tokens(tokens); + } + } + } + } + + tokens_or_default(self.gt_token.as_ref(), tokens); + } +} + +impl ToTokens for GenericParam { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + GenericParam::Lifetime(p) => p.to_tokens(tokens), + GenericParam::Type(p) => p.to_tokens(tokens, Mode::Rust), + GenericParam::Const(p) => p.to_tokens(tokens), + } + } +} + +impl TypeParam { + fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) { + tokens.append_all(outer(&self.attrs)); + self.ident.to_tokens(tokens); + + if mode == Mode::Flux { + if let Some(as_token) = self.as_token { + as_token.to_tokens(tokens); + self.param_kind.to_tokens(tokens); + } + } + + if !self.bounds.is_empty() && mode == Mode::Rust { + tokens_or_default(self.colon_token.as_ref(), tokens); + self.bounds.to_tokens(tokens); + } + // if let Some(default) = &self.default { + // tokens_or_default(self.eq_token.as_ref(), tokens); + // default.to_tokens(tokens); + // } + } +} + impl IndexParams { fn to_tokens_inner(&self, tokens: &mut TokenStream, mode: Mode) { if mode == Mode::Flux { @@ -1377,7 +1624,7 @@ impl ToTokens for ItemImpl { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.append_all(&self.attrs); self.impl_token.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); if let Some((trait_, for_token)) = &self.trait_ { trait_.to_tokens(tokens); for_token.to_tokens(tokens); @@ -1435,8 +1682,8 @@ impl Signature { self.fn_token.to_tokens(tokens); if mode == Mode::Rust { self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); } + self.generics.to_tokens(tokens, mode); self.paren_token.surround(tokens, |tokens| { for fn_arg in self.inputs.pairs() { fn_arg.value().to_tokens_inner(tokens, mode); @@ -1735,3 +1982,20 @@ impl ToTokens for Block { .surround(tokens, |tokens| self.stmts.to_tokens(tokens)); } } + +fn tokens_or_default(x: Option<&T>, tokens: &mut TokenStream) { + match x { + Some(t) => t.to_tokens(tokens), + None => T::default().to_tokens(tokens), + } +} + +fn outer(attrs: &[Attribute]) -> impl Iterator { + fn is_outer(attr: &&Attribute) -> bool { + match attr.style { + syn::AttrStyle::Outer => true, + syn::AttrStyle::Inner(_) => false, + } + } + attrs.iter().filter(is_outer) +}