From f434409f1e6f08debfba28234375e1c039ad3a4a Mon Sep 17 00:00:00 2001 From: Nico Lehmann Date: Thu, 16 Nov 2023 10:50:49 -0800 Subject: [PATCH 1/2] Implement syntax for enums --- crates/flux-desugar/src/desugar.rs | 6 +- crates/flux-desugar/src/desugar/gather.rs | 4 +- crates/flux-desugar/src/resolver.rs | 4 +- crates/flux-middle/src/fhir/lift.rs | 40 ++- crates/flux-syntax/src/grammar.lalrpop | 2 +- crates/flux-syntax/src/surface.rs | 2 +- crates/flux-syntax/src/surface/visit.rs | 4 +- lib/flux-attrs/src/ast.rs | 362 +++++++++++++++++----- xtask/src/main.rs | 46 ++- 9 files changed, 355 insertions(+), 115 deletions(-) diff --git a/crates/flux-desugar/src/desugar.rs b/crates/flux-desugar/src/desugar.rs index 162b4409ad..e4fa52b1fa 100644 --- a/crates/flux-desugar/src/desugar.rs +++ b/crates/flux-desugar/src/desugar.rs @@ -441,7 +441,11 @@ impl<'a, 'tcx> RustItemCtxt<'a, 'tcx> { }) .try_collect_exhaust()?; - let ret = self.desugar_variant_ret(&variant_def.ret, &mut env)?; + let ret = if let Some(ret) = &variant_def.ret { + self.desugar_variant_ret(ret, &mut env)? + } else { + self.as_lift_cx().lift_variant_ret() + }; Ok(fhir::VariantDef { def_id: hir_variant.def_id, diff --git a/crates/flux-desugar/src/desugar/gather.rs b/crates/flux-desugar/src/desugar/gather.rs index e3e45e4dc9..c1751f82fa 100644 --- a/crates/flux-desugar/src/desugar/gather.rs +++ b/crates/flux-desugar/src/desugar/gather.rs @@ -148,7 +148,9 @@ impl RustItemCtxt<'_, '_> { self.gather_params_ty(None, ty, TypePos::Input, &mut env)?; } - self.gather_params_variant_ret(&variant_def.ret, &mut env)?; + if let Some(ret) = &variant_def.ret { + self.gather_params_variant_ret(ret, &mut env)?; + } self.check_param_uses(&mut env, |vis| vis.visit_variant(variant_def))?; diff --git a/crates/flux-desugar/src/resolver.rs b/crates/flux-desugar/src/resolver.rs index 8db62b1b69..57b6e7bfbb 100644 --- a/crates/flux-desugar/src/resolver.rs +++ b/crates/flux-desugar/src/resolver.rs @@ -157,7 +157,9 @@ impl<'a> ItemLikeResolver<'a> { .fields .iter() .try_for_each_exhaust(|ty| self.resolve_ty(ty))?; - self.resolve_variant_ret(&variant_def.ret)?; + if let Some(ret) = &variant_def.ret { + self.resolve_variant_ret(ret)?; + } } Ok(()) } diff --git a/crates/flux-middle/src/fhir/lift.rs b/crates/flux-middle/src/fhir/lift.rs index d06871e718..c816e64a43 100644 --- a/crates/flux-middle/src/fhir/lift.rs +++ b/crates/flux-middle/src/fhir/lift.rs @@ -364,9 +364,7 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { variant: &hir::Variant, ) -> Result { let item = self.tcx.hir().expect_item(self.owner.def_id); - let hir::ItemKind::Enum(_, generics) = &item.kind else { - bug!("expected an enum or struct") - }; + let hir::ItemKind::Enum(_, generics) = &item.kind else { bug!("expected an enum") }; let fields = variant .data @@ -375,6 +373,29 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { .map(|field| self.lift_field_def(field)) .try_collect_exhaust()?; + let ret = self.lift_variant_ret_inner(item, generics); + + Ok(fhir::VariantDef { + def_id: variant.def_id, + params: vec![], + fields, + ret, + span: variant.span, + lifted: true, + }) + } + + pub fn lift_variant_ret(&mut self) -> fhir::VariantRet { + let item = self.tcx.hir().expect_item(self.owner.def_id); + let hir::ItemKind::Enum(_, generics) = &item.kind else { bug!("expected an enum") }; + self.lift_variant_ret_inner(item, generics) + } + + fn lift_variant_ret_inner( + &mut self, + item: &hir::Item, + generics: &hir::Generics, + ) -> fhir::VariantRet { let span = item.ident.span.to(generics.span); let path = fhir::Path { res: fhir::Res::SelfTyAlias { alias_to: self.owner.to_def_id(), is_trait_impl: false }, @@ -384,7 +405,7 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { span, }; let bty = fhir::BaseTy::from(fhir::QPath::Resolved(None, path)); - let ret = fhir::VariantRet { + fhir::VariantRet { bty, idx: fhir::RefineArg::Record( self.owner.to_def_id(), @@ -392,16 +413,7 @@ impl<'a, 'tcx> LiftCtxt<'a, 'tcx> { vec![], generics.span.shrink_to_hi(), ), - }; - - Ok(fhir::VariantDef { - def_id: variant.def_id, - params: vec![], - fields, - ret, - span: variant.span, - lifted: true, - }) + } } fn lift_ty(&mut self, ty: &hir::Ty) -> Result { diff --git a/crates/flux-syntax/src/grammar.lalrpop b/crates/flux-syntax/src/grammar.lalrpop index 96ab9a968d..851c954ba3 100644 --- a/crates/flux-syntax/src/grammar.lalrpop +++ b/crates/flux-syntax/src/grammar.lalrpop @@ -192,7 +192,7 @@ Async: surface::Async = { } pub Variant: surface::VariantDef = { - => { + => { let fields = match tys { Some(fields) => fields, None => vec![], diff --git a/crates/flux-syntax/src/surface.rs b/crates/flux-syntax/src/surface.rs index 7acf97f177..ca89450137 100644 --- a/crates/flux-syntax/src/surface.rs +++ b/crates/flux-syntax/src/surface.rs @@ -120,7 +120,7 @@ impl EnumDef { #[derive(Debug)] pub struct VariantDef { pub fields: Vec, - pub ret: VariantRet, + pub ret: Option, pub node_id: NodeId, pub span: Span, } diff --git a/crates/flux-syntax/src/surface/visit.rs b/crates/flux-syntax/src/surface/visit.rs index 640a6baaf3..5ed10ab69b 100644 --- a/crates/flux-syntax/src/surface/visit.rs +++ b/crates/flux-syntax/src/surface/visit.rs @@ -196,7 +196,9 @@ pub fn walk_enum_def(vis: &mut V, enum_def: &EnumDef) { pub fn walk_variant(vis: &mut V, variant: &VariantDef) { walk_list!(vis, visit_ty, &variant.fields); - vis.visit_variant_ret(&variant.ret); + if let Some(ret) = &variant.ret { + vis.visit_variant_ret(ret); + } } pub fn walk_variant_ret(vis: &mut V, ret: &VariantRet) { diff --git a/lib/flux-attrs/src/ast.rs b/lib/flux-attrs/src/ast.rs index b55d5bd67f..f99ea4edf4 100644 --- a/lib/flux-attrs/src/ast.rs +++ b/lib/flux-attrs/src/ast.rs @@ -18,7 +18,7 @@ pub struct Items(Vec); pub enum Item { Struct(ItemStruct), - Enum(syn::ItemEnum), + Enum(ItemEnum), Use(syn::ItemUse), Type(ItemType), Fn(ItemFn), @@ -44,18 +44,94 @@ pub struct ItemStruct { pub semi_token: Option, } +#[derive(Debug)] +pub struct ItemEnum { + pub attrs: Vec, + pub vis: Visibility, + pub enum_token: Token![enum], + pub ident: Ident, + pub generics: Generics, + pub refined_by: Option, + pub brace_token: token::Brace, + pub variants: Punctuated, +} + +#[derive(Debug)] +pub struct Variant { + pub attrs: Vec, + + /// Name of the variant. + pub ident: Ident, + + /// Content stored in the variant. + pub fields: Fields, + + /// Explicit discriminant: `Variant = 1` + pub discriminant: Option<(Token![=], syn::Expr)>, + + pub ret: Option, +} + +impl Variant { + #[cfg(flux_sysroot)] + fn flux_tool_attr(&self) -> TokenStream { + let variant = ToTokensFlux(self); + quote! { + #[flux_tool::variant(#variant)] + } + } +} + +impl ToTokens for ToTokensFlux<&Variant> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let variant = &self.0; + variant + .fields + .to_tokens(tokens, |f, tokens| f.ty.to_tokens_inner(tokens, Mode::Flux)); + if let Some(ret) = &variant.ret { + if !matches!(variant.fields, Fields::Unit) { + ret.arrow_token.to_tokens(tokens); + } + ret.path.to_tokens_inner(tokens, Mode::Flux); + ret.bracket_token.surround(tokens, |tokens| { + ret.indices.to_tokens(tokens); + }); + } + } +} + +#[derive(Debug)] +pub struct VariantRet { + pub arrow_token: Option]>, + pub path: Path, + pub bracket_token: token::Bracket, + pub indices: TokenStream, +} + +#[derive(Debug)] pub struct RefinedBy { pub refined_by: Option<(kw::refined, kw::by)>, pub bracket_token: token::Bracket, pub params: Punctuated, } +impl RefinedBy { + #[cfg(flux_sysroot)] + fn flux_tool_attr(&self) -> TokenStream { + quote! { + #[flux_tool::refined_by(#self)] + } + } +} + +#[derive(Debug)] pub struct RefinedByParam { pub ident: Ident, pub colon_token: Option, pub sort: Ident, } +#[derive(Debug)] pub enum Fields { /// Named fields of a struct or struct variant such as `Point { x: f64, /// y: f64 }`. @@ -68,16 +144,43 @@ pub enum Fields { Unit, } +impl Fields { + fn to_tokens(&self, tokens: &mut TokenStream, mut f: impl FnMut(&Field, &mut TokenStream)) { + match self { + Fields::Named(fields) => { + fields.brace_token.surround(tokens, |tokens| { + for param in fields.named.pairs() { + f(param.value(), tokens); + param.punct().to_tokens(tokens); + } + }); + } + Fields::Unnamed(fields) => { + fields.paren_token.surround(tokens, |tokens| { + for param in fields.unnamed.pairs() { + f(param.value(), tokens); + param.punct().to_tokens(tokens); + } + }); + } + Fields::Unit => {} + } + } +} + +#[derive(Debug)] pub struct FieldsNamed { pub brace_token: token::Brace, pub named: Punctuated, } +#[derive(Debug)] pub struct FieldsUnnamed { pub paren_token: token::Paren, pub unnamed: Punctuated, } +#[derive(Debug)] pub struct Field { pub attrs: Vec, @@ -95,6 +198,52 @@ pub struct Field { pub ty: Type, } +impl Field { + fn parse_unnamed(input: ParseStream) -> Result { + Ok(Field { + attrs: input.call(Attribute::parse_outer)?, + vis: input.parse()?, + mutability: syn::FieldMutability::None, + ident: None, + colon_token: None, + ty: input.parse()?, + }) + } + + fn parse_named(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + let vis: Visibility = input.parse()?; + + let ident = input.parse()?; + let colon_token: Token![:] = input.parse()?; + let ty = input.parse()?; + Ok(Field { + attrs, + vis, + mutability: syn::FieldMutability::None, + ident: Some(ident), + colon_token: Some(colon_token), + ty, + }) + } + + #[cfg(flux_sysroot)] + fn flux_tool_attr(&self) -> TokenStream { + let flux_ty = ToTokensFlux(&self.ty); + quote! { + #[flux_tool::field(#flux_ty)] + } + } + + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.append_all(&self.attrs); + self.vis.to_tokens(tokens); + self.ident.to_tokens(tokens); + self.colon_token.to_tokens(tokens); + self.ty.to_tokens_inner(tokens, Mode::Rust); + } +} + pub struct ItemType { pub attrs: Vec, pub vis: Visibility, @@ -375,7 +524,7 @@ impl Parse for Item { impl Parse for ItemStruct { fn parse(input: ParseStream) -> Result { let mut attrs = input.call(Attribute::parse_outer)?; - flux_tool_attrs(&mut attrs, &["opaque", "invariant"]); + flux_tool_attrs(&mut attrs, FLUX_ATTRS); let vis = input.parse::()?; let struct_token = input.parse::()?; let ident = input.parse::()?; @@ -395,6 +544,29 @@ impl Parse for ItemStruct { } } +impl Parse for ItemEnum { + fn parse(input: ParseStream) -> Result { + let mut attrs = input.call(Attribute::parse_outer)?; + flux_tool_attrs(&mut attrs, FLUX_ATTRS); + let vis = input.parse::()?; + let enum_token = input.parse::()?; + let ident = input.parse::()?; + let generics = input.parse::()?; + let refined_by = parse_opt_refined_by(input)?; + let (where_clause, brace_token, variants) = data_enum(input)?; + Ok(ItemEnum { + attrs, + vis, + enum_token, + ident, + generics: Generics { where_clause, ..generics }, + refined_by, + brace_token, + variants, + }) + } +} + fn parse_opt_refined_by(input: ParseStream) -> Result> { if input.peek(kw::refined) || input.peek(token::Bracket) { input.parse().map(Some) @@ -426,6 +598,72 @@ impl Parse for RefinedByParam { } } +fn data_enum( + input: ParseStream, +) -> Result<(Option, token::Brace, Punctuated)> { + let where_clause = input.parse()?; + + let content; + let brace = braced!(content in input); + let variants = content.parse_terminated(Variant::parse, Token![,])?; + + Ok((where_clause, brace, variants)) +} + +impl Parse for Variant { + fn parse(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + let _visibility: Visibility = input.parse()?; + let ident: Ident = input.parse()?; + let fields = if input.peek(token::Brace) { + Fields::Named(input.parse()?) + } else if input.peek(token::Paren) { + Fields::Unnamed(input.parse()?) + } else { + Fields::Unit + }; + let discriminant = if input.peek(Token![=]) { + let eq_token: Token![=] = input.parse()?; + let discriminant: syn::Expr = input.parse()?; + Some((eq_token, discriminant)) + } else { + None + }; + let ret = parse_opt_variant_ret(input)?; + Ok(Variant { attrs, ident, fields, discriminant, ret }) + } +} + +fn parse_opt_variant_ret(input: ParseStream) -> Result> { + if input.peek(Token![->]) { + input.parse().map(Some) + } else { + Ok(None) + } +} + +impl Parse for VariantRet { + fn parse(input: ParseStream) -> Result { + let mut indices = TokenStream::new(); + let content; + Ok(VariantRet { + arrow_token: input.parse()?, + path: input.parse()?, + bracket_token: bracketed!(content in input), + indices: { + loop { + if content.is_empty() { + break; + } + let tt: TokenTree = content.parse()?; + indices.append(tt); + } + indices + }, + }) + } +} + fn data_struct( input: ParseStream, ) -> Result<(Option, Fields, Option)> { @@ -482,36 +720,6 @@ impl Parse for FieldsNamed { } } -impl Field { - fn parse_unnamed(input: ParseStream) -> Result { - Ok(Field { - attrs: input.call(Attribute::parse_outer)?, - vis: input.parse()?, - mutability: syn::FieldMutability::None, - ident: None, - colon_token: None, - ty: input.parse()?, - }) - } - - fn parse_named(input: ParseStream) -> Result { - let attrs = input.call(Attribute::parse_outer)?; - let vis: Visibility = input.parse()?; - - let ident = input.parse()?; - let colon_token: Token![:] = input.parse()?; - let ty = input.parse()?; - Ok(Field { - attrs, - vis, - mutability: syn::FieldMutability::None, - ident: Some(ident), - colon_token: Some(colon_token), - ty, - }) - } -} - impl Parse for ItemFn { fn parse(input: ParseStream) -> Result { Ok(ItemFn { @@ -990,7 +1198,7 @@ impl Item { match self { Item::Fn(ItemFn { attrs, .. }) | Item::Impl(ItemImpl { attrs, .. }) - | Item::Enum(syn::ItemEnum { attrs, .. }) + | Item::Enum(ItemEnum { attrs, .. }) | Item::Struct(ItemStruct { attrs, .. }) | Item::Use(syn::ItemUse { attrs, .. }) | Item::Type(ItemType { attrs, .. }) => mem::replace(attrs, new), @@ -1021,83 +1229,67 @@ impl ToTokens for ItemStruct { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.append_all(&self.attrs); #[cfg(flux_sysroot)] - { - let refined_by = &self.refined_by; - quote! { - #[flux_tool::refined_by(#refined_by)] - } - .to_tokens(tokens); + if let Some(refined_by) = &self.refined_by { + refined_by.flux_tool_attr().to_tokens(tokens); } self.vis.to_tokens(tokens); self.struct_token.to_tokens(tokens); self.ident.to_tokens(tokens); self.generics.to_tokens(tokens); - self.fields.to_tokens(tokens); + self.fields.to_tokens(tokens, |field, tokens| { + #[cfg(flux_sysroot)] + field.flux_tool_attr().to_tokens(tokens); + field.to_tokens(tokens); + }); self.semi_token.to_tokens(tokens); } } -impl ToTokens for RefinedBy { +impl ToTokens for ItemEnum { fn to_tokens(&self, tokens: &mut TokenStream) { - for param in self.params.pairs() { - param.value().to_tokens(tokens); - param.punct().to_tokens(tokens); + tokens.append_all(&self.attrs); + #[cfg(flux_sysroot)] + if let Some(refined_by) = &self.refined_by { + refined_by.flux_tool_attr().to_tokens(tokens); } - } -} - -impl ToTokens for RefinedByParam { - fn to_tokens(&self, tokens: &mut TokenStream) { + self.vis.to_tokens(tokens); + self.enum_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.colon_token.to_tokens(tokens); - self.sort.to_tokens(tokens); + self.generics.to_tokens(tokens); + self.brace_token.surround(tokens, |tokens| { + self.variants.to_tokens(tokens); + }); } } -impl ToTokens for Fields { +impl ToTokens for Variant { fn to_tokens(&self, tokens: &mut TokenStream) { - match self { - Fields::Named(fields) => fields.to_tokens(tokens), - Fields::Unnamed(fields) => fields.to_tokens(tokens), - Fields::Unit => {} + #[cfg(flux_sysroot)] + self.flux_tool_attr().to_tokens(tokens); + tokens.append_all(&self.attrs); + self.ident.to_tokens(tokens); + self.fields.to_tokens(tokens, Field::to_tokens); + if let Some((eq_token, expr)) = &self.discriminant { + eq_token.to_tokens(tokens); + expr.to_tokens(tokens); } } } -impl ToTokens for FieldsNamed { - fn to_tokens(&self, tokens: &mut TokenStream) { - let FieldsNamed { brace_token, named } = self; - brace_token.surround(tokens, |tokens| { - named.to_tokens(tokens); - }); - } -} - -impl ToTokens for FieldsUnnamed { +impl ToTokens for RefinedBy { fn to_tokens(&self, tokens: &mut TokenStream) { - let FieldsUnnamed { paren_token, unnamed } = self; - paren_token.surround(tokens, |tokens| { - unnamed.to_tokens(tokens); - }); + for param in self.params.pairs() { + param.value().to_tokens(tokens); + param.punct().to_tokens(tokens); + } } } -impl ToTokens for Field { +impl ToTokens for RefinedByParam { fn to_tokens(&self, tokens: &mut TokenStream) { - let Field { attrs, vis, mutability: _, ident, colon_token, ty } = self; - #[cfg(flux_sysroot)] - { - let flux_ty = ToTokensFlux(ty); - quote! { - #[flux_tool::field(#flux_ty)] - } - .to_tokens(tokens); - } - tokens.append_all(attrs); - vis.to_tokens(tokens); - ident.to_tokens(tokens); - colon_token.to_tokens(tokens); - ty.to_tokens_inner(tokens, Mode::Rust); + self.ident.to_tokens(tokens); + self.colon_token.to_tokens(tokens); + self.sort.to_tokens(tokens); } } diff --git a/xtask/src/main.rs b/xtask/src/main.rs index b24f4891b4..c626416c49 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,6 +1,5 @@ use std::{ env, - ffi::OsString, path::{Path, PathBuf}, }; @@ -20,7 +19,12 @@ xflags::xflags! { /// Input file required input: PathBuf /// Extra options to pass to the flux binary, e.g. `cargo xtask run file.rs -- -Zdump-mir=y` - repeated opts: OsString + repeated opts: String + } + /// Expand flux macros + cmd expand { + /// Input file + required input: PathBuf } /// Install flux binaries to ~/.cargo/bin and precompiled libraries and driver to ~/.flux cmd install { @@ -61,14 +65,19 @@ fn main() -> anyhow::Result<()> { XtaskCmd::Doc(args) => doc(sh, args), XtaskCmd::BuildSysroot(_) => build_sysroot(&sh), XtaskCmd::Uninstall(_) => uninstall(&sh), + XtaskCmd::Expand(args) => expand(&sh, args), } } -fn test(sh: Shell, args: Test) -> anyhow::Result<()> { - let Test { filter } = args; - build_sysroot(&sh)?; +fn prepare(sh: &Shell) -> Result<(), anyhow::Error> { + build_sysroot(sh)?; cmd!(sh, "cargo build").run()?; + Ok(()) +} +fn test(sh: Shell, args: Test) -> anyhow::Result<()> { + let Test { filter } = args; + prepare(&sh)?; if let Some(filter) = filter { cmd!(sh, "cargo test -p flux-tests -- --test-args {filter}").run()?; } else { @@ -78,16 +87,33 @@ fn test(sh: Shell, args: Test) -> anyhow::Result<()> { } fn run(sh: Shell, args: Run) -> anyhow::Result<()> { - let Run { input, opts } = args; - build_sysroot(&sh)?; - cmd!(sh, "cargo build").run()?; + run_inner( + &sh, + args.input, + ["-Ztrack-diagnostics=y".to_string()] + .into_iter() + .chain(args.opts), + )?; + Ok(()) +} + +fn expand(sh: &Shell, args: Expand) -> Result<(), anyhow::Error> { + run_inner(sh, args.input, ["-Zunpretty=expanded".to_string()])?; + Ok(()) +} +fn run_inner( + sh: &Shell, + input: PathBuf, + flags: impl IntoIterator, +) -> Result<(), anyhow::Error> { + prepare(sh)?; let flux_path = find_flux_path(); let _env = sh.push_env(FLUX_SYSROOT, flux_path.parent().unwrap()); let mut rustc_flags = flux_tests::rustc_flags(); - rustc_flags.extend(["-Ztrack-diagnostics=y".to_string()]); + rustc_flags.extend(flags); - cmd!(sh, "{flux_path} {rustc_flags...} {opts...} {input}").run()?; + cmd!(sh, "{flux_path} {rustc_flags...} {input}").run()?; Ok(()) } From c8bcbbb88fae972e9bc08e72945bae08a438a03c Mon Sep 17 00:00:00 2001 From: Nico Lehmann Date: Thu, 16 Nov 2023 13:58:42 -0800 Subject: [PATCH 2/2] Implement refined generics --- lib/flux-attrs/src/ast.rs | 278 +++++++++++++++++++++++++++++++++++++- 1 file changed, 271 insertions(+), 7 deletions(-) diff --git a/lib/flux-attrs/src/ast.rs b/lib/flux-attrs/src/ast.rs index f99ea4edf4..33fe59798e 100644 --- a/lib/flux-attrs/src/ast.rs +++ b/lib/flux-attrs/src/ast.rs @@ -9,7 +9,7 @@ use syn::{ parse::{Parse, ParseStream, Peek}, punctuated::Punctuated, token::{self, Mut, Paren}, - Attribute, Generics, Ident, Result, Token, Visibility, + Attribute, Ident, Result, Token, Visibility, }; use crate::flux_tool_attrs; @@ -33,6 +33,61 @@ pub struct ItemFn { pub block: Block, } +#[derive(Debug)] +pub struct Generics { + pub lt_token: Option, + pub params: Punctuated, + pub gt_token: Option]>, + pub where_clause: Option, +} + +impl Default for Generics { + fn default() -> Self { + Generics { lt_token: None, params: Punctuated::new(), gt_token: None, where_clause: None } + } +} + +#[derive(Debug)] +pub enum GenericParam { + /// A lifetime parameter: `'a: 'b + 'c + 'd`. + Lifetime(syn::LifetimeParam), + + /// A generic type parameter: `T: Into`. + Type(TypeParam), + + /// A const generic parameter: `const LENGTH: usize`. + Const(syn::ConstParam), +} + +#[derive(Debug)] +pub struct TypeParam { + pub attrs: Vec, + pub ident: Ident, + pub as_token: Option, + pub param_kind: ParamKind, + pub colon_token: Option, + pub bounds: Punctuated, + // pub eq_token: Option, + // pub default: Option, +} + +#[derive(Debug)] +pub enum ParamKind { + Type(Token![type]), + Base(kw::base), + Default, +} + +impl ToTokens for ParamKind { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + ParamKind::Type(t) => t.to_tokens(tokens), + ParamKind::Base(t) => t.to_tokens(tokens), + ParamKind::Default => {} + } + } +} + pub struct ItemStruct { pub attrs: Vec, pub vis: Visibility, @@ -544,6 +599,138 @@ impl Parse for ItemStruct { } } +impl Parse for Generics { + fn parse(input: ParseStream) -> Result { + if !input.peek(Token![<]) { + return Ok(Generics::default()); + } + + let lt_token: Token![<] = input.parse()?; + + let mut params = Punctuated::new(); + loop { + if input.peek(Token![>]) { + break; + } + + let attrs = input.call(Attribute::parse_outer)?; + let lookahead = input.lookahead1(); + if lookahead.peek(syn::Lifetime) { + params.push_value(GenericParam::Lifetime(syn::LifetimeParam { + attrs, + ..input.parse()? + })); + } else if lookahead.peek(Ident) { + params.push_value(GenericParam::Type(TypeParam { attrs, ..input.parse()? })); + } else if lookahead.peek(Token![const]) { + params.push_value(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? })); + } else if input.peek(Token![_]) { + params.push_value(GenericParam::Type(TypeParam { + attrs, + ident: input.call(Ident::parse_any)?, + as_token: None, + param_kind: ParamKind::Default, + colon_token: None, + bounds: Punctuated::new(), + // eq_token: None, + // default: None, + })); + } else { + return Err(lookahead.error()); + } + + if input.peek(Token![>]) { + break; + } + let punct = input.parse()?; + params.push_punct(punct); + } + + let gt_token: Token![>] = input.parse()?; + + Ok(Generics { + lt_token: Some(lt_token), + params, + gt_token: Some(gt_token), + where_clause: None, + }) + } +} + +impl Parse for GenericParam { + fn parse(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + + let lookahead = input.lookahead1(); + if lookahead.peek(Ident) { + Ok(GenericParam::Type(TypeParam { attrs, ..input.parse()? })) + } else if lookahead.peek(syn::Lifetime) { + Ok(GenericParam::Lifetime(syn::LifetimeParam { attrs, ..input.parse()? })) + } else if lookahead.peek(Token![const]) { + Ok(GenericParam::Const(syn::ConstParam { attrs, ..input.parse()? })) + } else { + Err(lookahead.error()) + } + } +} + +impl Parse for TypeParam { + fn parse(input: ParseStream) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + let ident: Ident = input.parse()?; + + let as_token: Option = input.parse()?; + let mut param_kind = ParamKind::Default; + if as_token.is_some() { + param_kind = input.parse()?; + } + + let colon_token: Option = input.parse()?; + + let mut bounds = Punctuated::new(); + if colon_token.is_some() { + loop { + if input.peek(Token![,]) || input.peek(Token![>]) || input.peek(Token![=]) { + break; + } + let value: syn::TypeParamBound = input.parse()?; + bounds.push_value(value); + if !input.peek(Token![+]) { + break; + } + let punct: Token![+] = input.parse()?; + bounds.push_punct(punct); + } + } + // let eq_token: Option = input.parse()?; + // let default = if eq_token.is_some() { Some(input.parse::()?) } else { None }; + + Ok(TypeParam { + attrs, + ident, + as_token, + param_kind, + colon_token, + bounds, + // eq_token, + // default, + }) + } +} + +impl Parse for ParamKind { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![type]) { + input.parse().map(ParamKind::Type) + } else if lookahead.peek(kw::base) { + input.parse().map(ParamKind::Base) + } else { + Err(lookahead.error()) + } + } +} + impl Parse for ItemEnum { fn parse(input: ParseStream) -> Result { let mut attrs = input.call(Attribute::parse_outer)?; @@ -857,7 +1044,7 @@ impl Parse for Signature { let content; let fn_token = input.parse()?; let ident = input.parse()?; - let mut generics: syn::Generics = input.parse()?; + let mut generics: Generics = input.parse()?; let paren_token = parenthesized!(content in input); let inputs = content.parse_terminated(FnArg::parse, Token![,])?; let output = input.parse()?; @@ -1185,6 +1372,7 @@ mod kw { syn::custom_keyword!(requires); syn::custom_keyword!(refined); syn::custom_keyword!(by); + syn::custom_keyword!(base); } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1235,7 +1423,7 @@ impl ToTokens for ItemStruct { self.vis.to_tokens(tokens); self.struct_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); self.fields.to_tokens(tokens, |field, tokens| { #[cfg(flux_sysroot)] field.flux_tool_attr().to_tokens(tokens); @@ -1255,7 +1443,7 @@ impl ToTokens for ItemEnum { self.vis.to_tokens(tokens); self.enum_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); self.brace_token.surround(tokens, |tokens| { self.variants.to_tokens(tokens); }); @@ -1348,7 +1536,7 @@ impl ItemType { } self.type_token.to_tokens(tokens); self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, mode); if let Some(params) = &self.index_params { params.to_tokens_inner(tokens, mode); } @@ -1360,6 +1548,65 @@ impl ItemType { } } +impl Generics { + fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) { + if self.params.is_empty() { + return; + } + + tokens_or_default(self.lt_token.as_ref(), tokens); + + for param in self.params.pairs() { + match mode { + Mode::Rust => { + param.to_tokens(tokens); + } + Mode::Flux => { + if let GenericParam::Type(p) = param.value() { + p.to_tokens(tokens, mode); + param.punct().to_tokens(tokens); + } + } + } + } + + tokens_or_default(self.gt_token.as_ref(), tokens); + } +} + +impl ToTokens for GenericParam { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + GenericParam::Lifetime(p) => p.to_tokens(tokens), + GenericParam::Type(p) => p.to_tokens(tokens, Mode::Rust), + GenericParam::Const(p) => p.to_tokens(tokens), + } + } +} + +impl TypeParam { + fn to_tokens(&self, tokens: &mut TokenStream, mode: Mode) { + tokens.append_all(outer(&self.attrs)); + self.ident.to_tokens(tokens); + + if mode == Mode::Flux { + if let Some(as_token) = self.as_token { + as_token.to_tokens(tokens); + self.param_kind.to_tokens(tokens); + } + } + + if !self.bounds.is_empty() && mode == Mode::Rust { + tokens_or_default(self.colon_token.as_ref(), tokens); + self.bounds.to_tokens(tokens); + } + // if let Some(default) = &self.default { + // tokens_or_default(self.eq_token.as_ref(), tokens); + // default.to_tokens(tokens); + // } + } +} + impl IndexParams { fn to_tokens_inner(&self, tokens: &mut TokenStream, mode: Mode) { if mode == Mode::Flux { @@ -1377,7 +1624,7 @@ impl ToTokens for ItemImpl { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.append_all(&self.attrs); self.impl_token.to_tokens(tokens); - self.generics.to_tokens(tokens); + self.generics.to_tokens(tokens, Mode::Rust); if let Some((trait_, for_token)) = &self.trait_ { trait_.to_tokens(tokens); for_token.to_tokens(tokens); @@ -1435,8 +1682,8 @@ impl Signature { self.fn_token.to_tokens(tokens); if mode == Mode::Rust { self.ident.to_tokens(tokens); - self.generics.to_tokens(tokens); } + self.generics.to_tokens(tokens, mode); self.paren_token.surround(tokens, |tokens| { for fn_arg in self.inputs.pairs() { fn_arg.value().to_tokens_inner(tokens, mode); @@ -1735,3 +1982,20 @@ impl ToTokens for Block { .surround(tokens, |tokens| self.stmts.to_tokens(tokens)); } } + +fn tokens_or_default(x: Option<&T>, tokens: &mut TokenStream) { + match x { + Some(t) => t.to_tokens(tokens), + None => T::default().to_tokens(tokens), + } +} + +fn outer(attrs: &[Attribute]) -> impl Iterator { + fn is_outer(attr: &&Attribute) -> bool { + match attr.style { + syn::AttrStyle::Outer => true, + syn::AttrStyle::Inner(_) => false, + } + } + attrs.iter().filter(is_outer) +}