From ece5b8516d8db4639aa69acd24f6a9eaadfa0766 Mon Sep 17 00:00:00 2001 From: Nico Lehmann Date: Thu, 28 Nov 2024 00:08:42 -0800 Subject: [PATCH] More cleanup --- crates/flux-desugar/src/desugar.rs | 44 ++++------- .../flux-driver/src/collector/extern_specs.rs | 12 +-- crates/flux-fhir-analysis/src/lib.rs | 4 +- crates/flux-fhir-analysis/src/wf/mod.rs | 6 +- crates/flux-infer/src/infer.rs | 13 ++-- crates/flux-middle/src/queries.rs | 6 +- crates/flux-middle/src/rty/mod.rs | 78 ++++++++++++------- crates/flux-middle/src/rty/refining.rs | 29 +++---- crates/flux-refineck/src/checker.rs | 6 +- crates/flux-syntax/src/grammar.lalrpop | 10 +-- crates/flux-syntax/src/surface.rs | 4 +- crates/flux-syntax/src/surface/visit.rs | 8 +- 12 files changed, 106 insertions(+), 114 deletions(-) diff --git a/crates/flux-desugar/src/desugar.rs b/crates/flux-desugar/src/desugar.rs index 4c3a7326d7..588afb5dbf 100644 --- a/crates/flux-desugar/src/desugar.rs +++ b/crates/flux-desugar/src/desugar.rs @@ -9,10 +9,11 @@ use flux_middle::{ try_alloc_slice, MaybeExternId, ResolverOutput, }; use flux_syntax::{ - surface::{self, visit::Visitor as _, ConstructorArgs, FieldExpr, NodeId, Spread}, + surface::{self, visit::Visitor as _, ConstructorArg, NodeId}, walk_list, }; use hir::{def::DefKind, ItemKind}; +use itertools::{Either, Itertools}; use rustc_data_structures::fx::FxIndexSet; use rustc_errors::{Diagnostic, ErrorGuaranteed}; use rustc_hash::FxHashSet; @@ -1307,20 +1308,17 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> { self.genv().alloc(e2?), ) } - surface::ExprKind::Constructor(path, constructor_args) => { - let path = match path { - Some(path) => Some(self.desugar_constructor_path(path)?), - None => None, - }; - let field_exprs = constructor_args - .iter() - .filter_map(|arg| { - match arg { - ConstructorArgs::FieldExpr(e) => Some(e), - ConstructorArgs::Spread(_) => None, - } - }) - .collect::>(); + surface::ExprKind::Constructor(path, args) => { + let path = path + .as_ref() + .map(|p| self.desugar_constructor_path(p)) + .transpose()?; + let (field_exprs, spreads): (Vec<_>, Vec<_>) = args.iter().partition_map(|arg| { + match arg { + ConstructorArg::FieldExpr(e) => Either::Left(e), + ConstructorArg::Spread(s) => Either::Right(s), + } + }); let field_exprs = try_alloc_slice!(self.genv(), field_exprs, |field_expr| { let e = self.desugar_expr(&field_expr.expr)?; @@ -1332,16 +1330,6 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> { }) })?; - let spreads = constructor_args - .iter() - .filter_map(|arg| { - match arg { - ConstructorArgs::FieldExpr(_) => None, - ConstructorArgs::Spread(s) => Some(s), - } - }) - .collect::>(); - let spread = match &spreads[..] { [] => None, [s] => { @@ -1353,10 +1341,8 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> { Some(self.genv().alloc(spread)) } [s1, s2, ..] => { - // Multiple spreads found - emit an error - return Err(self.emit_err(errors::MultipleSpreadsInConstructor::new( - s1.span, s2.span, - ))); + let err = errors::MultipleSpreadsInConstructor::new(s1.span, s2.span); + return Err(self.emit_err(err)); } }; fhir::ExprKind::Constructor(path, field_exprs, spread) diff --git a/crates/flux-driver/src/collector/extern_specs.rs b/crates/flux-driver/src/collector/extern_specs.rs index 258f6b993b..9ab78f876e 100644 --- a/crates/flux-driver/src/collector/extern_specs.rs +++ b/crates/flux-driver/src/collector/extern_specs.rs @@ -26,18 +26,12 @@ struct ExternImplItem { item_id: DefId, } -impl<'mismatch_check, 'sess, 'tcx> ExternSpecCollector<'mismatch_check, 'sess, 'tcx> { - pub(super) fn collect( - inner: &'mismatch_check mut SpecCollector<'sess, 'tcx>, - body_id: BodyId, - ) -> Result { +impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> { + pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result { Self::new(inner, body_id)?.run() } - fn new( - inner: &'mismatch_check mut SpecCollector<'sess, 'tcx>, - body_id: BodyId, - ) -> Result { + fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result { let body = inner.tcx.hir().body(body_id); if let hir::ExprKind::Block(block, _) = body.value.kind { Ok(Self { inner, block }) diff --git a/crates/flux-fhir-analysis/src/lib.rs b/crates/flux-fhir-analysis/src/lib.rs index c05b3b6edc..3dedaeb8da 100644 --- a/crates/flux-fhir-analysis/src/lib.rs +++ b/crates/flux-fhir-analysis/src/lib.rs @@ -403,9 +403,9 @@ fn refinement_generics_of( }) => { let wfckresults = genv.check_wf(local_id)?; let params = conv::conv_refinement_generics(generics.refinement_params, &wfckresults)?; - Ok(rty::RefinementGenerics { parent, parent_count, params }) + Ok(rty::RefinementGenerics { parent, parent_count, own_params: params }) } - _ => Ok(rty::RefinementGenerics { parent, parent_count, params: rty::List::empty() }), + _ => Ok(rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() }), } } diff --git a/crates/flux-fhir-analysis/src/wf/mod.rs b/crates/flux-fhir-analysis/src/wf/mod.rs index eadd3c3f3b..4b9c65e1a0 100644 --- a/crates/flux-fhir-analysis/src/wf/mod.rs +++ b/crates/flux-fhir-analysis/src/wf/mod.rs @@ -372,15 +372,15 @@ impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> { return; }; - if path.refine.len() != generics.params.len() { + if path.refine.len() != generics.own_params.len() { self.errors.emit(errors::EarlyBoundArgCountMismatch::new( path.span, - generics.params.len(), + generics.own_params.len(), path.refine.len(), )); } - for (expr, param) in iter::zip(path.refine, &generics.params) { + for (expr, param) in iter::zip(path.refine, &generics.own_params) { self.infcx .check_expr(expr, ¶m.sort) .collect_err(&mut self.errors); diff --git a/crates/flux-infer/src/infer.rs b/crates/flux-infer/src/infer.rs index ac086c5d08..6ad5d9eb45 100644 --- a/crates/flux-infer/src/infer.rs +++ b/crates/flux-infer/src/infer.rs @@ -11,8 +11,8 @@ use flux_middle::{ fold::TypeFoldable, AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate, ESpan, EVar, EVarGen, EarlyBinder, Expr, ExprKind, GenericArg, GenericArgs, HoleKind, InferMode, - Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind, - Var, + Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, RefineArgs, RefineArgsExt, + Region, Sort, Ty, TyKind, Var, }, }; use itertools::{izip, Itertools}; @@ -164,11 +164,10 @@ impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> { InferCtxtAt { infcx: self, span } } - pub fn instantiate_refine_args(&mut self, callee_def_id: DefId) -> InferResult> { - Ok(self - .genv - .refinement_generics_of(callee_def_id)? - .collect_all_params(self.genv, |param| self.fresh_infer_var(¶m.sort, param.mode))?) + pub fn instantiate_refine_args(&mut self, callee_def_id: DefId) -> InferResult> { + Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| { + self.fresh_infer_var(¶m.sort, param.mode) + })?) } pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec { diff --git a/crates/flux-middle/src/queries.rs b/crates/flux-middle/src/queries.rs index 03f692cb08..7f75e859d9 100644 --- a/crates/flux-middle/src/queries.rs +++ b/crates/flux-middle/src/queries.rs @@ -462,7 +462,11 @@ impl<'genv, 'tcx> Queries<'genv, 'tcx> { |def_id| genv.cstore().refinement_generics_of(def_id), |def_id| { let parent = genv.tcx().generics_of(def_id).parent; - Ok(rty::RefinementGenerics { parent, parent_count: 0, params: List::empty() }) + Ok(rty::RefinementGenerics { + parent, + parent_count: 0, + own_params: List::empty(), + }) }, ) }) diff --git a/crates/flux-middle/src/rty/mod.rs b/crates/flux-middle/src/rty/mod.rs index 7cff5e585e..e0184c54ad 100644 --- a/crates/flux-middle/src/rty/mod.rs +++ b/crates/flux-middle/src/rty/mod.rs @@ -217,7 +217,7 @@ impl Generics { pub struct RefinementGenerics { pub parent: Option, pub parent_count: usize, - pub params: List, + pub own_params: List, } #[derive(PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable)] @@ -1725,16 +1725,41 @@ pub type RefineArgs = List; #[extension(pub trait RefineArgsExt)] impl RefineArgs { fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult { + Self::for_item(genv, def_id, |param, exprs| { + let index = exprs.len() as u32; + Expr::var(Var::EarlyParam(EarlyReftParam { index, name: param.name })) + }) + } + + fn for_item(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult + where + F: FnMut(&RefineParam, &[Expr]) -> Expr, + { let reft_generics = genv.refinement_generics_of(def_id)?; - let mut args = vec![]; - for i in 0..reft_generics.count() { - let param = reft_generics.param_at(i, genv)?; - let expr = - Expr::var(Var::EarlyParam(EarlyReftParam { index: i as u32, name: param.name })); - args.push(expr); - } + let count = reft_generics.count(); + let mut args = Vec::with_capacity(count); + Self::fill_item(genv, &mut args, &reft_generics, &mut mk)?; Ok(List::from_vec(args)) } + + fn fill_item( + genv: GlobalEnv, + args: &mut Vec, + reft_generics: &RefinementGenerics, + mk: &mut F, + ) -> QueryResult<()> + where + F: FnMut(&RefineParam, &[Expr]) -> Expr, + { + if let Some(def_id) = reft_generics.parent { + let parent_generics = genv.refinement_generics_of(def_id)?; + Self::fill_item(genv, args, &parent_generics, mk)?; + } + for param in &reft_generics.own_params { + args.push(mk(param, args)); + } + Ok(()) + } } /// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias @@ -1926,11 +1951,8 @@ impl GenericArg { Ok(List::from_vec(args)) } - pub fn identity_for_item( - genv: GlobalEnv, - def_id: impl Into, - ) -> QueryResult { - Self::for_item(genv, def_id.into(), |param, _| GenericArg::from_param_def(param)) + pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult { + Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param)) } fn fill_item( @@ -2106,12 +2128,12 @@ impl CoroutineObligPredicate { impl RefinementGenerics { pub fn count(&self) -> usize { - self.parent_count + self.params.len() + self.parent_count + self.own_params.len() } pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult { if let Some(index) = param_index.checked_sub(self.parent_count) { - Ok(self.params[index].clone()) + Ok(self.own_params[index].clone()) } else { let parent = self.parent.expect("parent_count > 0 but no parent?"); genv.refinement_generics_of(parent)? @@ -2119,19 +2141,19 @@ impl RefinementGenerics { } } - /// Iterate and collect all parameters in this item including parents - pub fn collect_all_params( - &self, - genv: GlobalEnv, - mut f: impl FnMut(RefineParam) -> T, - ) -> QueryResult - where - S: FromIterator, - { - (0..self.count()) - .map(|i| Ok(f(self.param_at(i, genv)?))) - .try_collect() - } + // /// Iterate and collect all parameters in this item including parents + // pub fn collect_all_params( + // &self, + // genv: GlobalEnv, + // mut f: impl FnMut(RefineParam) -> T, + // ) -> QueryResult + // where + // S: FromIterator, + // { + // (0..self.count()) + // .map(|i| Ok(f(self.param_at(i, genv)?))) + // .try_collect() + // } } impl EarlyBinder { diff --git a/crates/flux-middle/src/rty/refining.rs b/crates/flux-middle/src/rty/refining.rs index 72f3061663..f2fba810d6 100644 --- a/crates/flux-middle/src/rty/refining.rs +++ b/crates/flux-middle/src/rty/refining.rs @@ -11,7 +11,7 @@ use rustc_hir::{def::DefKind, def_id::DefId}; use rustc_middle::ty::ParamTy; use rustc_target::abi::VariantIdx; -use super::fold::TypeFoldable; +use super::{fold::TypeFoldable, RefineArgsExt}; use crate::{ global_env::GlobalEnv, queries::{QueryErr, QueryResult}, @@ -318,10 +318,15 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { let def_id = alias_ty.def_id; let args = self.refine_generic_args(def_id, &alias_ty.args)?; - let refine_args = self.refine_refine_args_for_alias_ty(def_id, alias_kind)?; + let refine_args = if let ty::AliasKind::Opaque = alias_kind { + rty::RefineArgs::for_item(self.genv, def_id, |param, _| { + rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone())) + })? + } else { + List::empty() + }; - let res = rty::AliasTy::new(def_id, args, refine_args); - Ok(res) + Ok(rty::AliasTy::new(def_id, args, refine_args)) } pub fn refine_ty(&self, ty: &ty::Ty) -> QueryResult { @@ -418,22 +423,6 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> { self.genv.generics_of(def_id) } - fn refine_refine_args_for_alias_ty( - &self, - def_id: DefId, - alias_kind: ty::AliasKind, - ) -> QueryResult { - if let ty::AliasKind::Opaque = alias_kind { - self.genv - .refinement_generics_of(def_id)? - .collect_all_params(self.genv, |param| { - rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone())) - }) - } else { - Ok(List::empty()) - } - } - fn param(&self, param_ty: ParamTy) -> QueryResult { self.generics.param_at(param_ty.index as usize, self.genv) } diff --git a/crates/flux-refineck/src/checker.rs b/crates/flux-refineck/src/checker.rs index 4f10394e92..2ff690acfc 100644 --- a/crates/flux-refineck/src/checker.rs +++ b/crates/flux-refineck/src/checker.rs @@ -348,7 +348,7 @@ pub(crate) fn trait_impl_subtyping( let trait_fn_sig = genv.fn_sig(trait_method_id).with_span(span)?; let tcx = genv.tcx(); let impl_id = tcx.impl_of_method(def_id.to_def_id()).unwrap(); - let impl_args = GenericArg::identity_for_item(genv, def_id).with_span(span)?; + let impl_args = GenericArg::identity_for_item(genv, def_id.to_def_id()).with_span(span)?; let trait_args = impl_args.rebase_onto(&tcx, impl_id, &trait_ref.args); let trait_refine_args = RefineArgs::identity_for_item(genv, trait_method_id).with_span(span)?; @@ -782,9 +782,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> { .instantiate_refine_args(callee_def_id) .with_span(span)? } - None => { - vec![] - } + None => rty::List::empty(), }; let clauses = match callee_def_id { diff --git a/crates/flux-syntax/src/grammar.lalrpop b/crates/flux-syntax/src/grammar.lalrpop index 643bdc4e8a..81da483fdb 100644 --- a/crates/flux-syntax/src/grammar.lalrpop +++ b/crates/flux-syntax/src/grammar.lalrpop @@ -515,14 +515,14 @@ Level9 = { > } Level10: surface::Expr = { - "{" > "}" if AllowStruct == "true" => { + "{" > "}" if AllowStruct == "true" => { surface::Expr { kind: surface::ExprKind::Constructor(Some(name), args) , node_id: cx.next_node_id(), span: cx.map_span(lo, hi), } }, - "{" > "}" if AllowStruct == "true" => { + "{" > "}" if AllowStruct == "true" => { surface::Expr { kind: surface::ExprKind::Constructor(None, args) , node_id: cx.next_node_id(), @@ -573,9 +573,9 @@ Level10: surface::Expr = { "(" > ")" } -ConstructorArgs: surface::ConstructorArgs = { +ConstructorArg: surface::ConstructorArg = { ":" > => { - surface::ConstructorArgs::FieldExpr(surface::FieldExpr { + surface::ConstructorArg::FieldExpr(surface::FieldExpr { ident: name, expr: arg, node_id: cx.next_node_id(), @@ -583,7 +583,7 @@ ConstructorArgs: surface::ConstructorArgs = { }) }, ".." > => { - surface::ConstructorArgs::Spread(surface::Spread { + surface::ConstructorArg::Spread(surface::Spread { expr: spread, node_id: cx.next_node_id(), span: cx.map_span(lo, hi) diff --git a/crates/flux-syntax/src/surface.rs b/crates/flux-syntax/src/surface.rs index 4bf787c019..efc52227a6 100644 --- a/crates/flux-syntax/src/surface.rs +++ b/crates/flux-syntax/src/surface.rs @@ -468,7 +468,7 @@ pub struct Spread { } #[derive(Debug)] -pub enum ConstructorArgs { +pub enum ConstructorArg { FieldExpr(FieldExpr), Spread(Spread), } @@ -490,7 +490,7 @@ pub enum ExprKind { App(Ident, Vec), Alias(AliasReft, Vec), IfThenElse(Box<[Expr; 3]>), - Constructor(Option, Vec), + Constructor(Option, Vec), } /// A [`Path`] but for refinement expressions diff --git a/crates/flux-syntax/src/surface/visit.rs b/crates/flux-syntax/src/surface/visit.rs index 8bf131a896..7527685801 100644 --- a/crates/flux-syntax/src/surface/visit.rs +++ b/crates/flux-syntax/src/surface/visit.rs @@ -1,7 +1,7 @@ use rustc_span::symbol::Ident; use super::{ - AliasReft, Async, BaseSort, BaseTy, BaseTyKind, ConstArg, ConstructorArgs, Ensures, EnumDef, + AliasReft, Async, BaseSort, BaseTy, BaseTyKind, ConstArg, ConstructorArg, Ensures, EnumDef, Expr, ExprKind, ExprPath, ExprPathSegment, FieldExpr, FnInput, FnOutput, FnRetTy, FnSig, GenericArg, GenericArgKind, GenericParam, Generics, Impl, ImplAssocReft, Indices, Lit, Path, PathSegment, Qualifier, RefineArg, RefineParam, Sort, SortPath, SpecFunc, StructDef, Trait, @@ -153,10 +153,10 @@ pub trait Visitor: Sized { walk_expr(self, expr); } - fn visit_constructor_args(&mut self, expr: &ConstructorArgs) { + fn visit_constructor_args(&mut self, expr: &ConstructorArg) { match expr { - ConstructorArgs::FieldExpr(field_expr) => walk_field_expr(self, field_expr), - ConstructorArgs::Spread(spread) => self.visit_expr(&spread.expr), + ConstructorArg::FieldExpr(field_expr) => walk_field_expr(self, field_expr), + ConstructorArg::Spread(spread) => self.visit_expr(&spread.expr), } }