Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Nov 28, 2024
1 parent f3d79e7 commit ece5b85
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 114 deletions.
44 changes: 15 additions & 29 deletions crates/flux-desugar/src/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use flux_middle::{
try_alloc_slice, MaybeExternId, ResolverOutput,
};
use flux_syntax::{
surface::{self, visit::Visitor as _, ConstructorArgs, FieldExpr, NodeId, Spread},
surface::{self, visit::Visitor as _, ConstructorArg, NodeId},
walk_list,
};
use hir::{def::DefKind, ItemKind};
use itertools::{Either, Itertools};
use rustc_data_structures::fx::FxIndexSet;
use rustc_errors::{Diagnostic, ErrorGuaranteed};
use rustc_hash::FxHashSet;
Expand Down Expand Up @@ -1307,20 +1308,17 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
self.genv().alloc(e2?),
)
}
surface::ExprKind::Constructor(path, constructor_args) => {
let path = match path {
Some(path) => Some(self.desugar_constructor_path(path)?),
None => None,
};
let field_exprs = constructor_args
.iter()
.filter_map(|arg| {
match arg {
ConstructorArgs::FieldExpr(e) => Some(e),
ConstructorArgs::Spread(_) => None,
}
})
.collect::<Vec<&FieldExpr>>();
surface::ExprKind::Constructor(path, args) => {
let path = path
.as_ref()
.map(|p| self.desugar_constructor_path(p))
.transpose()?;
let (field_exprs, spreads): (Vec<_>, Vec<_>) = args.iter().partition_map(|arg| {
match arg {
ConstructorArg::FieldExpr(e) => Either::Left(e),
ConstructorArg::Spread(s) => Either::Right(s),
}
});

let field_exprs = try_alloc_slice!(self.genv(), field_exprs, |field_expr| {
let e = self.desugar_expr(&field_expr.expr)?;
Expand All @@ -1332,16 +1330,6 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
})
})?;

let spreads = constructor_args
.iter()
.filter_map(|arg| {
match arg {
ConstructorArgs::FieldExpr(_) => None,
ConstructorArgs::Spread(s) => Some(s),
}
})
.collect::<Vec<&Spread>>();

let spread = match &spreads[..] {
[] => None,
[s] => {
Expand All @@ -1353,10 +1341,8 @@ trait DesugarCtxt<'genv, 'tcx: 'genv> {
Some(self.genv().alloc(spread))
}
[s1, s2, ..] => {
// Multiple spreads found - emit an error
return Err(self.emit_err(errors::MultipleSpreadsInConstructor::new(
s1.span, s2.span,
)));
let err = errors::MultipleSpreadsInConstructor::new(s1.span, s2.span);
return Err(self.emit_err(err));
}
};
fhir::ExprKind::Constructor(path, field_exprs, spread)
Expand Down
12 changes: 3 additions & 9 deletions crates/flux-driver/src/collector/extern_specs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,12 @@ struct ExternImplItem {
item_id: DefId,
}

impl<'mismatch_check, 'sess, 'tcx> ExternSpecCollector<'mismatch_check, 'sess, 'tcx> {
pub(super) fn collect(
inner: &'mismatch_check mut SpecCollector<'sess, 'tcx>,
body_id: BodyId,
) -> Result {
impl<'a, 'sess, 'tcx> ExternSpecCollector<'a, 'sess, 'tcx> {
pub(super) fn collect(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result {
Self::new(inner, body_id)?.run()
}

fn new(
inner: &'mismatch_check mut SpecCollector<'sess, 'tcx>,
body_id: BodyId,
) -> Result<Self> {
fn new(inner: &'a mut SpecCollector<'sess, 'tcx>, body_id: BodyId) -> Result<Self> {
let body = inner.tcx.hir().body(body_id);
if let hir::ExprKind::Block(block, _) = body.value.kind {
Ok(Self { inner, block })
Expand Down
4 changes: 2 additions & 2 deletions crates/flux-fhir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,9 @@ fn refinement_generics_of(
}) => {
let wfckresults = genv.check_wf(local_id)?;
let params = conv::conv_refinement_generics(generics.refinement_params, &wfckresults)?;
Ok(rty::RefinementGenerics { parent, parent_count, params })
Ok(rty::RefinementGenerics { parent, parent_count, own_params: params })
}
_ => Ok(rty::RefinementGenerics { parent, parent_count, params: rty::List::empty() }),
_ => Ok(rty::RefinementGenerics { parent, parent_count, own_params: rty::List::empty() }),
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/flux-fhir-analysis/src/wf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,15 @@ impl<'genv> fhir::visit::Visitor<'genv> for Wf<'_, 'genv, '_> {
return;
};

if path.refine.len() != generics.params.len() {
if path.refine.len() != generics.own_params.len() {
self.errors.emit(errors::EarlyBoundArgCountMismatch::new(
path.span,
generics.params.len(),
generics.own_params.len(),
path.refine.len(),
));
}

for (expr, param) in iter::zip(path.refine, &generics.params) {
for (expr, param) in iter::zip(path.refine, &generics.own_params) {
self.infcx
.check_expr(expr, &param.sort)
.collect_err(&mut self.errors);
Expand Down
13 changes: 6 additions & 7 deletions crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use flux_middle::{
fold::TypeFoldable,
AliasKind, AliasTy, BaseTy, Binder, BoundVariableKinds, CoroutineObligPredicate, ESpan,
EVar, EVarGen, EarlyBinder, Expr, ExprKind, GenericArg, GenericArgs, HoleKind, InferMode,
Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, Region, Sort, Ty, TyKind,
Var,
Lambda, List, Loc, Mutability, Path, PolyVariant, PtrKind, Ref, RefineArgs, RefineArgsExt,
Region, Sort, Ty, TyKind, Var,
},
};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -164,11 +164,10 @@ impl<'infcx, 'genv, 'tcx> InferCtxt<'infcx, 'genv, 'tcx> {
InferCtxtAt { infcx: self, span }
}

pub fn instantiate_refine_args(&mut self, callee_def_id: DefId) -> InferResult<Vec<Expr>> {
Ok(self
.genv
.refinement_generics_of(callee_def_id)?
.collect_all_params(self.genv, |param| self.fresh_infer_var(&param.sort, param.mode))?)
pub fn instantiate_refine_args(&mut self, callee_def_id: DefId) -> InferResult<List<Expr>> {
Ok(RefineArgs::for_item(self.genv, callee_def_id, |param, _| {
self.fresh_infer_var(&param.sort, param.mode)
})?)
}

pub fn instantiate_generic_args(&mut self, args: &[GenericArg]) -> Vec<GenericArg> {
Expand Down
6 changes: 5 additions & 1 deletion crates/flux-middle/src/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,11 @@ impl<'genv, 'tcx> Queries<'genv, 'tcx> {
|def_id| genv.cstore().refinement_generics_of(def_id),
|def_id| {
let parent = genv.tcx().generics_of(def_id).parent;
Ok(rty::RefinementGenerics { parent, parent_count: 0, params: List::empty() })
Ok(rty::RefinementGenerics {
parent,
parent_count: 0,
own_params: List::empty(),
})
},
)
})
Expand Down
78 changes: 50 additions & 28 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl Generics {
pub struct RefinementGenerics {
pub parent: Option<DefId>,
pub parent_count: usize,
pub params: List<RefineParam>,
pub own_params: List<RefineParam>,
}

#[derive(PartialEq, Eq, Debug, Clone, Hash, TyEncodable, TyDecodable)]
Expand Down Expand Up @@ -1725,16 +1725,41 @@ pub type RefineArgs = List<Expr>;
#[extension(pub trait RefineArgsExt)]
impl RefineArgs {
fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<RefineArgs> {
Self::for_item(genv, def_id, |param, exprs| {
let index = exprs.len() as u32;
Expr::var(Var::EarlyParam(EarlyReftParam { index, name: param.name }))
})
}

fn for_item<F>(genv: GlobalEnv, def_id: DefId, mut mk: F) -> QueryResult<RefineArgs>
where
F: FnMut(&RefineParam, &[Expr]) -> Expr,
{
let reft_generics = genv.refinement_generics_of(def_id)?;
let mut args = vec![];
for i in 0..reft_generics.count() {
let param = reft_generics.param_at(i, genv)?;
let expr =
Expr::var(Var::EarlyParam(EarlyReftParam { index: i as u32, name: param.name }));
args.push(expr);
}
let count = reft_generics.count();
let mut args = Vec::with_capacity(count);
Self::fill_item(genv, &mut args, &reft_generics, &mut mk)?;
Ok(List::from_vec(args))
}

fn fill_item<F>(
genv: GlobalEnv,
args: &mut Vec<Expr>,
reft_generics: &RefinementGenerics,
mk: &mut F,
) -> QueryResult<()>
where
F: FnMut(&RefineParam, &[Expr]) -> Expr,
{
if let Some(def_id) = reft_generics.parent {
let parent_generics = genv.refinement_generics_of(def_id)?;
Self::fill_item(genv, args, &parent_generics, mk)?;
}
for param in &reft_generics.own_params {
args.push(mk(param, args));
}
Ok(())
}
}

/// A type constructor meant to be used as generic a argument of [kind base]. This is just an alias
Expand Down Expand Up @@ -1926,11 +1951,8 @@ impl GenericArg {
Ok(List::from_vec(args))
}

pub fn identity_for_item(
genv: GlobalEnv,
def_id: impl Into<DefId>,
) -> QueryResult<GenericArgs> {
Self::for_item(genv, def_id.into(), |param, _| GenericArg::from_param_def(param))
pub fn identity_for_item(genv: GlobalEnv, def_id: DefId) -> QueryResult<GenericArgs> {
Self::for_item(genv, def_id, |param, _| GenericArg::from_param_def(param))
}

fn fill_item<F>(
Expand Down Expand Up @@ -2106,32 +2128,32 @@ impl CoroutineObligPredicate {

impl RefinementGenerics {
pub fn count(&self) -> usize {
self.parent_count + self.params.len()
self.parent_count + self.own_params.len()
}

pub fn param_at(&self, param_index: usize, genv: GlobalEnv) -> QueryResult<RefineParam> {
if let Some(index) = param_index.checked_sub(self.parent_count) {
Ok(self.params[index].clone())
Ok(self.own_params[index].clone())
} else {
let parent = self.parent.expect("parent_count > 0 but no parent?");
genv.refinement_generics_of(parent)?
.param_at(param_index, genv)
}
}

/// Iterate and collect all parameters in this item including parents
pub fn collect_all_params<T, S>(
&self,
genv: GlobalEnv,
mut f: impl FnMut(RefineParam) -> T,
) -> QueryResult<S>
where
S: FromIterator<T>,
{
(0..self.count())
.map(|i| Ok(f(self.param_at(i, genv)?)))
.try_collect()
}
// /// Iterate and collect all parameters in this item including parents
// pub fn collect_all_params<T, S>(
// &self,
// genv: GlobalEnv,
// mut f: impl FnMut(RefineParam) -> T,
// ) -> QueryResult<S>
// where
// S: FromIterator<T>,
// {
// (0..self.count())
// .map(|i| Ok(f(self.param_at(i, genv)?)))
// .try_collect()
// }
}

impl EarlyBinder<GenericPredicates> {
Expand Down
29 changes: 9 additions & 20 deletions crates/flux-middle/src/rty/refining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rustc_hir::{def::DefKind, def_id::DefId};
use rustc_middle::ty::ParamTy;
use rustc_target::abi::VariantIdx;

use super::fold::TypeFoldable;
use super::{fold::TypeFoldable, RefineArgsExt};
use crate::{
global_env::GlobalEnv,
queries::{QueryErr, QueryResult},
Expand Down Expand Up @@ -318,10 +318,15 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
let def_id = alias_ty.def_id;
let args = self.refine_generic_args(def_id, &alias_ty.args)?;

let refine_args = self.refine_refine_args_for_alias_ty(def_id, alias_kind)?;
let refine_args = if let ty::AliasKind::Opaque = alias_kind {
rty::RefineArgs::for_item(self.genv, def_id, |param, _| {
rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone()))
})?
} else {
List::empty()
};

let res = rty::AliasTy::new(def_id, args, refine_args);
Ok(res)
Ok(rty::AliasTy::new(def_id, args, refine_args))
}

pub fn refine_ty(&self, ty: &ty::Ty) -> QueryResult<rty::Ty> {
Expand Down Expand Up @@ -418,22 +423,6 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
self.genv.generics_of(def_id)
}

fn refine_refine_args_for_alias_ty(
&self,
def_id: DefId,
alias_kind: ty::AliasKind,
) -> QueryResult<rty::RefineArgs> {
if let ty::AliasKind::Opaque = alias_kind {
self.genv
.refinement_generics_of(def_id)?
.collect_all_params(self.genv, |param| {
rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone()))
})
} else {
Ok(List::empty())
}
}

fn param(&self, param_ty: ParamTy) -> QueryResult<rty::GenericParamDef> {
self.generics.param_at(param_ty.index as usize, self.genv)
}
Expand Down
6 changes: 2 additions & 4 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ pub(crate) fn trait_impl_subtyping(
let trait_fn_sig = genv.fn_sig(trait_method_id).with_span(span)?;
let tcx = genv.tcx();
let impl_id = tcx.impl_of_method(def_id.to_def_id()).unwrap();
let impl_args = GenericArg::identity_for_item(genv, def_id).with_span(span)?;
let impl_args = GenericArg::identity_for_item(genv, def_id.to_def_id()).with_span(span)?;
let trait_args = impl_args.rebase_onto(&tcx, impl_id, &trait_ref.args);
let trait_refine_args =
RefineArgs::identity_for_item(genv, trait_method_id).with_span(span)?;
Expand Down Expand Up @@ -782,9 +782,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
.instantiate_refine_args(callee_def_id)
.with_span(span)?
}
None => {
vec![]
}
None => rty::List::empty(),
};

let clauses = match callee_def_id {
Expand Down
10 changes: 5 additions & 5 deletions crates/flux-syntax/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,14 @@ Level9<AllowStruct> = {
<Level10<AllowStruct>>
}
Level10<AllowStruct>: surface::Expr = {
<lo:@L> <name:ExprPath> "{" <args:Comma<ConstructorArgs>> "}" <hi:@L> if AllowStruct == "true" => {
<lo:@L> <name:ExprPath> "{" <args:Comma<ConstructorArg>> "}" <hi:@L> if AllowStruct == "true" => {
surface::Expr {
kind: surface::ExprKind::Constructor(Some(name), args) ,
node_id: cx.next_node_id(),
span: cx.map_span(lo, hi),
}
},
<lo:@L> "{" <args:Comma<ConstructorArgs>> "}" <hi:@L> if AllowStruct == "true" => {
<lo:@L> "{" <args:Comma<ConstructorArg>> "}" <hi:@L> if AllowStruct == "true" => {
surface::Expr {
kind: surface::ExprKind::Constructor(None, args) ,
node_id: cx.next_node_id(),
Expand Down Expand Up @@ -573,17 +573,17 @@ Level10<AllowStruct>: surface::Expr = {
"(" <Level1<AllowStruct>> ")"
}

ConstructorArgs: surface::ConstructorArgs = {
ConstructorArg: surface::ConstructorArg = {
<lo:@L> <name:Ident> ":" <arg:Level1<"true">> <hi:@L> => {
surface::ConstructorArgs::FieldExpr(surface::FieldExpr {
surface::ConstructorArg::FieldExpr(surface::FieldExpr {
ident: name,
expr: arg,
node_id: cx.next_node_id(),
span: cx.map_span(lo, hi)
})
},
<lo:@L> ".." <spread:Level1<"true">> <hi:@L> => {
surface::ConstructorArgs::Spread(surface::Spread {
surface::ConstructorArg::Spread(surface::Spread {
expr: spread,
node_id: cx.next_node_id(),
span: cx.map_span(lo, hi)
Expand Down
Loading

0 comments on commit ece5b85

Please sign in to comment.