Skip to content

Commit

Permalink
Implement refined generics
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Nov 16, 2023
1 parent f434409 commit c8bcbbb
Showing 1 changed file with 271 additions and 7 deletions.
278 changes: 271 additions & 7 deletions lib/flux-attrs/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use syn::{
parse::{Parse, ParseStream, Peek},
punctuated::Punctuated,
token::{self, Mut, Paren},
Attribute, Generics, Ident, Result, Token, Visibility,
Attribute, Ident, Result, Token, Visibility,
};

use crate::flux_tool_attrs;
Expand All @@ -33,6 +33,61 @@ pub struct ItemFn {
pub block: Block,
}

#[derive(Debug)]
pub struct Generics {
pub lt_token: Option<Token![<]>,
pub params: Punctuated<GenericParam, Token![,]>,
pub gt_token: Option<Token![>]>,
pub where_clause: Option<syn::WhereClause>,
}

impl Default for Generics {
fn default() -> Self {
Generics { lt_token: None, params: Punctuated::new(), gt_token: None, where_clause: None }
}
}

#[derive(Debug)]
pub enum GenericParam {
/// A lifetime parameter: `'a: 'b + 'c + 'd`.
Lifetime(syn::LifetimeParam),

/// A generic type parameter: `T: Into<String>`.
Type(TypeParam),

/// A const generic parameter: `const LENGTH: usize`.
Const(syn::ConstParam),
}

#[derive(Debug)]
pub struct TypeParam {
pub attrs: Vec<Attribute>,
pub ident: Ident,
pub as_token: Option<Token![as]>,
pub param_kind: ParamKind,
pub colon_token: Option<Token![:]>,
pub bounds: Punctuated<syn::TypeParamBound, Token![+]>,
// pub eq_token: Option<Token![=]>,
// pub default: Option<Type>,
}

#[derive(Debug)]
pub enum ParamKind {
Type(Token![type]),
Base(kw::base),
Default,
}

impl ToTokens for ParamKind {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
ParamKind::Type(t) => t.to_tokens(tokens),
ParamKind::Base(t) => t.to_tokens(tokens),
ParamKind::Default => {}
}
}
}

pub struct ItemStruct {
pub attrs: Vec<Attribute>,
pub vis: Visibility,
Expand Down Expand Up @@ -544,6 +599,138 @@ impl Parse for ItemStruct {
}
}

impl Parse for Generics {
fn parse(input: ParseStream) -> Result<Self> {
if !input.peek(Token![<]) {
return Ok(Generics::default());
}

let lt_token: Token![<] = input.parse()?;

let mut params = Punctuated::new();
loop {
if input.peek(Token![>]) {
break;
}

let attrs = input.call(Attribute::parse_outer)?;
let lookahead = input.lookahead1();
if lookahead.peek(syn::Lifetime) {
params.push_value(GenericParam::Lifetime(syn::LifetimeParam {
attrs,
..input.parse()?
}));
} else if lookahead.peek(Ident) {
params.push_value(GenericParam::Type(TypeParam { attrs, ..input.parse()? }));
} else if lookahead.peek(Token![const]) {
params.push_value(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? }));
} else if input.peek(Token![_]) {
params.push_value(GenericParam::Type(TypeParam {
attrs,
ident: input.call(Ident::parse_any)?,
as_token: None,
param_kind: ParamKind::Default,
colon_token: None,
bounds: Punctuated::new(),
// eq_token: None,
// default: None,
}));
} else {
return Err(lookahead.error());
}

if input.peek(Token![>]) {
break;
}
let punct = input.parse()?;
params.push_punct(punct);
}

let gt_token: Token![>] = input.parse()?;

Ok(Generics {
lt_token: Some(lt_token),
params,
gt_token: Some(gt_token),
where_clause: None,
})
}
}

impl Parse for GenericParam {
fn parse(input: ParseStream) -> Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;

let lookahead = input.lookahead1();
if lookahead.peek(Ident) {
Ok(GenericParam::Type(TypeParam { attrs, ..input.parse()? }))
} else if lookahead.peek(syn::Lifetime) {
Ok(GenericParam::Lifetime(syn::LifetimeParam { attrs, ..input.parse()? }))
} else if lookahead.peek(Token![const]) {
Ok(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? }))
} else {
Err(lookahead.error())
}
}
}

impl Parse for TypeParam {
fn parse(input: ParseStream) -> Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let ident: Ident = input.parse()?;

let as_token: Option<Token![as]> = input.parse()?;
let mut param_kind = ParamKind::Default;
if as_token.is_some() {
param_kind = input.parse()?;
}

let colon_token: Option<Token![:]> = input.parse()?;

let mut bounds = Punctuated::new();
if colon_token.is_some() {
loop {
if input.peek(Token![,]) || input.peek(Token![>]) || input.peek(Token![=]) {
break;
}
let value: syn::TypeParamBound = input.parse()?;
bounds.push_value(value);
if !input.peek(Token![+]) {
break;
}
let punct: Token![+] = input.parse()?;
bounds.push_punct(punct);
}
}
// let eq_token: Option<Token![=]> = input.parse()?;
// let default = if eq_token.is_some() { Some(input.parse::<Type>()?) } else { None };

Ok(TypeParam {
attrs,
ident,
as_token,
param_kind,
colon_token,
bounds,
// eq_token,
// default,
})
}
}

impl Parse for ParamKind {
fn parse(input: ParseStream) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(Token![type]) {
input.parse().map(ParamKind::Type)
} else if lookahead.peek(kw::base) {
input.parse().map(ParamKind::Base)
} else {
Err(lookahead.error())
}
}
}

impl Parse for ItemEnum {
fn parse(input: ParseStream) -> Result<Self> {
let mut attrs = input.call(Attribute::parse_outer)?;
Expand Down Expand Up @@ -857,7 +1044,7 @@ impl Parse for Signature {
let content;
let fn_token = input.parse()?;
let ident = input.parse()?;
let mut generics: syn::Generics = input.parse()?;
let mut generics: Generics = input.parse()?;
let paren_token = parenthesized!(content in input);
let inputs = content.parse_terminated(FnArg::parse, Token![,])?;
let output = input.parse()?;
Expand Down Expand Up @@ -1185,6 +1372,7 @@ mod kw {
syn::custom_keyword!(requires);
syn::custom_keyword!(refined);
syn::custom_keyword!(by);
syn::custom_keyword!(base);
}

#[derive(Copy, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -1235,7 +1423,7 @@ impl ToTokens for ItemStruct {
self.vis.to_tokens(tokens);
self.struct_token.to_tokens(tokens);
self.ident.to_tokens(tokens);
self.generics.to_tokens(tokens);
self.generics.to_tokens(tokens, Mode::Rust);
self.fields.to_tokens(tokens, |field, tokens| {
#[cfg(flux_sysroot)]
field.flux_tool_attr().to_tokens(tokens);
Expand All @@ -1255,7 +1443,7 @@ impl ToTokens for ItemEnum {
self.vis.to_tokens(tokens);
self.enum_token.to_tokens(tokens);
self.ident.to_tokens(tokens);
self.generics.to_tokens(tokens);
self.generics.to_tokens(tokens, Mode::Rust);
self.brace_token.surround(tokens, |tokens| {
self.variants.to_tokens(tokens);
});
Expand Down Expand Up @@ -1348,7 +1536,7 @@ impl ItemType {
}
self.type_token.to_tokens(tokens);
self.ident.to_tokens(tokens);
self.generics.to_tokens(tokens);
self.generics.to_tokens(tokens, mode);
if let Some(params) = &self.index_params {
params.to_tokens_inner(tokens, mode);
}
Expand All @@ -1360,6 +1548,65 @@ impl ItemType {
}
}

impl Generics {
fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) {
if self.params.is_empty() {
return;
}

tokens_or_default(self.lt_token.as_ref(), tokens);

for param in self.params.pairs() {
match mode {
Mode::Rust => {
param.to_tokens(tokens);
}
Mode::Flux => {
if let GenericParam::Type(p) = param.value() {
p.to_tokens(tokens, mode);
param.punct().to_tokens(tokens);
}
}
}
}

tokens_or_default(self.gt_token.as_ref(), tokens);
}
}

impl ToTokens for GenericParam {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
GenericParam::Lifetime(p) => p.to_tokens(tokens),
GenericParam::Type(p) => p.to_tokens(tokens, Mode::Rust),
GenericParam::Const(p) => p.to_tokens(tokens),
}
}
}

impl TypeParam {
fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) {
tokens.append_all(outer(&self.attrs));
self.ident.to_tokens(tokens);

if mode == Mode::Flux {
if let Some(as_token) = self.as_token {
as_token.to_tokens(tokens);
self.param_kind.to_tokens(tokens);
}
}

if !self.bounds.is_empty() && mode == Mode::Rust {
tokens_or_default(self.colon_token.as_ref(), tokens);
self.bounds.to_tokens(tokens);
}
// if let Some(default) = &self.default {
// tokens_or_default(self.eq_token.as_ref(), tokens);
// default.to_tokens(tokens);
// }
}
}

impl IndexParams {
fn to_tokens_inner(&self, tokens: &mut TokenStream, mode: Mode) {
if mode == Mode::Flux {
Expand All @@ -1377,7 +1624,7 @@ impl ToTokens for ItemImpl {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.append_all(&self.attrs);
self.impl_token.to_tokens(tokens);
self.generics.to_tokens(tokens);
self.generics.to_tokens(tokens, Mode::Rust);
if let Some((trait_, for_token)) = &self.trait_ {
trait_.to_tokens(tokens);
for_token.to_tokens(tokens);
Expand Down Expand Up @@ -1435,8 +1682,8 @@ impl Signature {
self.fn_token.to_tokens(tokens);
if mode == Mode::Rust {
self.ident.to_tokens(tokens);
self.generics.to_tokens(tokens);
}
self.generics.to_tokens(tokens, mode);
self.paren_token.surround(tokens, |tokens| {
for fn_arg in self.inputs.pairs() {
fn_arg.value().to_tokens_inner(tokens, mode);
Expand Down Expand Up @@ -1735,3 +1982,20 @@ impl ToTokens for Block {
.surround(tokens, |tokens| self.stmts.to_tokens(tokens));
}
}

fn tokens_or_default<T: ToTokens + Default>(x: Option<&T>, tokens: &mut TokenStream) {
match x {
Some(t) => t.to_tokens(tokens),
None => T::default().to_tokens(tokens),
}
}

fn outer(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
fn is_outer(attr: &&Attribute) -> bool {
match attr.style {
syn::AttrStyle::Outer => true,
syn::AttrStyle::Inner(_) => false,
}
}
attrs.iter().filter(is_outer)
}

0 comments on commit c8bcbbb

Please sign in to comment.