diff --git a/crates/flux-fhir-analysis/src/conv.rs b/crates/flux-fhir-analysis/src/conv.rs index 423f83aa4a..3c95ba1fca 100644 --- a/crates/flux-fhir-analysis/src/conv.rs +++ b/crates/flux-fhir-analysis/src/conv.rs @@ -498,12 +498,20 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { enum_def .variants .iter() - .map(|variant_def| ConvCtxt::conv_enum_variant(genv, variant_def, wfckresults)) + .map(|variant_def| { + ConvCtxt::conv_enum_variant( + genv, + enum_def.owner_id.to_def_id(), + variant_def, + wfckresults, + ) + }) .try_collect() } fn conv_enum_variant( genv: &GlobalEnv, + adt_def_id: DefId, variant: &fhir::VariantDef, wfckresults: &fhir::WfckResults, ) -> QueryResult { @@ -512,14 +520,19 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { let mut env = Env::new(&[]); env.push_layer(Layer::list(&cx, 0, &variant.params, true)); + let adt_def = genv.adt_def(adt_def_id)?; let fields = variant .fields .iter() .map(|field| cx.conv_ty(&mut env, &field.ty)) .try_collect()?; - let args = rty::Index::from(cx.conv_refine_arg(&mut env, &variant.ret.idx)); - let ret = cx.conv_indexed_type(&mut env, &variant.ret.bty, args)?; - let variant = rty::VariantSig::new(fields, ret); + let idxs = cx.conv_refine_arg(&mut env, &variant.ret.idx).0; + let variant = rty::VariantSig::new( + adt_def, + rty::GenericArgs::identity_for_item(genv, adt_def_id)?, + fields, + idxs, + ); Ok(rty::Binder::new(variant, env.pop_layer().into_bound_vars())) } @@ -535,21 +548,25 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> { let def_id = struct_def.owner_id.def_id; if let fhir::StructKind::Transparent { fields } = &struct_def.kind { + let adt_def = genv.adt_def(def_id)?; + let fields = fields .iter() .map(|field_def| cx.conv_ty(&mut env, &field_def.ty)) .try_collect()?; - let args = rty::GenericArgs::identity_for_item(genv, def_id)?; - let vars = env.pop_layer().into_bound_vars(); let idx = rty::Expr::tuple( (0..vars.len()) .map(|idx| rty::Expr::late_bvar(INNERMOST, idx as u32)) .collect_vec(), ); - let ret = rty::Ty::indexed(rty::BaseTy::adt(genv.adt_def(def_id)?, args), idx); - let variant = rty::VariantSig::new(fields, ret); + let variant = rty::VariantSig::new( + adt_def, + rty::GenericArgs::identity_for_item(genv, def_id)?, + fields, + idx, + ); Ok(rty::Opaqueness::Transparent(rty::Binder::new(variant, vars))) } else { Ok(rty::Opaqueness::Opaque) diff --git a/crates/flux-middle/src/queries.rs b/crates/flux-middle/src/queries.rs index e98398ff98..d17e769643 100644 --- a/crates/flux-middle/src/queries.rs +++ b/crates/flux-middle/src/queries.rs @@ -296,9 +296,8 @@ impl<'tcx> Queries<'tcx> { .iter() .map(|field| Ok(genv.lower_type_of(field.did)?.skip_binder())) .try_collect_vec::<_, QueryErr>()?; - let ret = genv.lower_type_of(def_id)?.skip_binder(); Refiner::default(genv, &genv.generics_of(def_id)?) - .refine_variant_def(&fields, &ret) + .refine_variant_def(def_id, &fields) }) .try_collect()?; Ok(rty::Opaqueness::Transparent(rty::EarlyBinder(variants))) diff --git a/crates/flux-middle/src/rty/fold.rs b/crates/flux-middle/src/rty/fold.rs index 5f4063bd35..f9e96d03be 100644 --- a/crates/flux-middle/src/rty/fold.rs +++ b/crates/flux-middle/src/rty/fold.rs @@ -633,19 +633,16 @@ impl TypeVisitable for VariantSig { self.fields .iter() .try_for_each(|ty| ty.visit_with(visitor))?; - self.ret.visit_with(visitor) + self.idx.visit_with(visitor) } } impl TypeFoldable for VariantSig { fn try_fold_with(&self, folder: &mut F) -> Result { - let fields = self - .fields - .iter() - .map(|ty| ty.try_fold_with(folder)) - .try_collect()?; - let ret = self.ret.try_fold_with(folder)?; - Ok(VariantSig::new(fields, ret)) + let args = self.args.try_fold_with(folder)?; + let fields = self.fields.try_fold_with(folder)?; + let idx = self.idx.try_fold_with(folder)?; + Ok(VariantSig::new(self.adt_def.clone(), args, fields, idx)) } } diff --git a/crates/flux-middle/src/rty/mod.rs b/crates/flux-middle/src/rty/mod.rs index 6f770b0280..3fdf86fa6a 100644 --- a/crates/flux-middle/src/rty/mod.rs +++ b/crates/flux-middle/src/rty/mod.rs @@ -234,8 +234,10 @@ pub type PolyVariant = Binder; #[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)] pub struct VariantSig { + pub adt_def: AdtDef, + pub args: GenericArgs, pub fields: List, - pub ret: Ty, + pub idx: Expr, } #[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)] @@ -823,6 +825,10 @@ impl EarlyBinder { EarlyBinder(f(self.0)) } + pub fn try_map(self, f: impl FnOnce(T) -> Result) -> Result, E> { + Ok(EarlyBinder(f(self.0)?)) + } + pub fn skip_binder(self) -> T { self.0 } @@ -943,13 +949,18 @@ impl EarlyBinder { } impl VariantSig { - pub fn new(fields: Vec, ret: Ty) -> Self { - VariantSig { fields: List::from_vec(fields), ret } + pub fn new(adt_def: AdtDef, args: GenericArgs, fields: List, idx: Expr) -> Self { + VariantSig { adt_def, args, fields, idx } } pub fn fields(&self) -> &[Ty] { &self.fields } + + pub fn ret(&self) -> Ty { + let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone()); + Ty::indexed(bty, self.idx.clone()) + } } impl TupleTree @@ -1138,9 +1149,9 @@ impl Opaqueness> { impl EarlyBinder { pub fn to_poly_fn_sig(&self) -> EarlyBinder { - self.as_ref().map(|poly_fn_sig| { - poly_fn_sig.as_ref().map(|variant| { - let ret = variant.ret.shift_in_escaping(1); + self.as_ref().map(|poly_variant| { + poly_variant.as_ref().map(|variant| { + let ret = variant.ret().shift_in_escaping(1); let output = Binder::new(FnOutput::new(ret, vec![]), List::empty()); FnSig::new(vec![], variant.fields.clone(), output) }) @@ -2096,7 +2107,7 @@ mod pretty { impl Pretty for VariantSig { fn fmt(&self, cx: &PPrintCx, f: &mut fmt::Formatter<'_>) -> fmt::Result { define_scoped!(cx, f); - w!("({:?}) -> {:?}", join!(", ", self.fields()), &self.ret) + w!("({:?}) -> {:?}", join!(", ", self.fields()), &self.idx) } } diff --git a/crates/flux-middle/src/rty/refining.rs b/crates/flux-middle/src/rty/refining.rs index 014aa0c4f1..2c2e4d517f 100644 --- a/crates/flux-middle/src/rty/refining.rs +++ b/crates/flux-middle/src/rty/refining.rs @@ -176,19 +176,17 @@ impl<'a, 'tcx> Refiner<'a, 'tcx> { pub(crate) fn refine_variant_def( &self, + adt_def_id: DefId, fields: &[rustc::ty::Ty], - ret: &rustc::ty::Ty, ) -> QueryResult { + let adt_def = self.adt_def(adt_def_id)?; let fields = fields.iter().map(|ty| self.refine_ty(ty)).try_collect()?; - let rustc::ty::TyKind::Adt(adt_def, args) = ret.kind() else { - bug!(); - }; - let args = self.iter_with_generic_params(self.generics, args, |param, arg| { - self.refine_generic_arg(param, arg) - })?; - let bty = rty::BaseTy::adt(self.adt_def(adt_def.did())?, args); - let ret = rty::Ty::indexed(bty, rty::Expr::unit()); - let value = rty::VariantSig::new(fields, ret); + let value = rty::VariantSig::new( + adt_def, + rty::GenericArgs::identity_for_item(self.genv, adt_def_id)?, + fields, + rty::Expr::unit(), + ); Ok(rty::Binder::new(value, List::empty())) } diff --git a/crates/flux-refineck/src/constraint_gen.rs b/crates/flux-refineck/src/constraint_gen.rs index c6298e9f77..85ec88f0e6 100644 --- a/crates/flux-refineck/src/constraint_gen.rs +++ b/crates/flux-refineck/src/constraint_gen.rs @@ -355,7 +355,7 @@ impl<'a, 'tcx> ConstrGen<'a, 'tcx> { let evars_sol = infcx.solve()?; rcx.replace_evars(&evars_sol); - Ok(variant.ret.replace_evars(&evars_sol)) + Ok(variant.ret().replace_evars(&evars_sol)) } pub(crate) fn check_mk_array( diff --git a/crates/flux-refineck/src/invariants.rs b/crates/flux-refineck/src/invariants.rs index 53241384d5..3932d441a9 100644 --- a/crates/flux-refineck/src/invariants.rs +++ b/crates/flux-refineck/src/invariants.rs @@ -55,8 +55,7 @@ fn check_invariant( let ty = rcx.unpack(ty, crate::refine_tree::AssumeInvariants::No); rcx.assume_invariants(&ty, checker_config.check_overflow); } - let (.., idx) = variant.ret.expect_adt(); - let pred = invariant.pred.replace_bound_expr(&idx.expr); + let pred = invariant.pred.replace_bound_expr(&variant.idx); rcx.check_pred(pred, Tag::new(ConstrReason::Other, DUMMY_SP)); } let mut fcx = FixpointCtxt::new(genv, def_id, KVarStore::default()); diff --git a/crates/flux-refineck/src/type_env/projection.rs b/crates/flux-refineck/src/type_env/projection.rs index 3a52911e4c..cd8aebd6eb 100644 --- a/crates/flux-refineck/src/type_env/projection.rs +++ b/crates/flux-refineck/src/type_env/projection.rs @@ -761,10 +761,9 @@ fn downcast_enum( .instantiate(args, &[]) .replace_bound_exprs_with(|sort, _| rcx.define_vars(sort)); - let (.., idx2) = variant_def.ret.expect_adt(); // FIXME(nilehmann) flatten indices let exprs1 = idx1.expr.expect_tuple(); - let exprs2 = idx2.expr.expect_tuple(); + let exprs2 = variant_def.idx.expect_tuple(); debug_assert_eq!(exprs1.len(), exprs2.len()); let constr = Expr::and(iter::zip(exprs1, exprs2).filter_map(|(e1, e2)| { if !e1.is_abs() && !e2.is_abs() {