Skip to content

Commit

Permalink
Represent variants as set of equations in rty
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Nov 19, 2023
1 parent 14e8687 commit 3083b9f
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 40 deletions.
33 changes: 25 additions & 8 deletions crates/flux-fhir-analysis/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,20 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> {
enum_def
.variants
.iter()
.map(|variant_def| ConvCtxt::conv_enum_variant(genv, variant_def, wfckresults))
.map(|variant_def| {
ConvCtxt::conv_enum_variant(
genv,
enum_def.owner_id.to_def_id(),
variant_def,
wfckresults,
)
})
.try_collect()
}

fn conv_enum_variant(
genv: &GlobalEnv,
adt_def_id: DefId,
variant: &fhir::VariantDef,
wfckresults: &fhir::WfckResults,
) -> QueryResult<rty::PolyVariant> {
Expand All @@ -512,14 +520,19 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> {
let mut env = Env::new(&[]);
env.push_layer(Layer::list(&cx, 0, &variant.params, true));

let adt_def = genv.adt_def(adt_def_id)?;
let fields = variant
.fields
.iter()
.map(|field| cx.conv_ty(&mut env, &field.ty))
.try_collect()?;
let args = rty::Index::from(cx.conv_refine_arg(&mut env, &variant.ret.idx));
let ret = cx.conv_indexed_type(&mut env, &variant.ret.bty, args)?;
let variant = rty::VariantSig::new(fields, ret);
let idxs = cx.conv_refine_arg(&mut env, &variant.ret.idx).0;
let variant = rty::VariantSig::new(
adt_def,
rty::GenericArgs::identity_for_item(genv, adt_def_id)?,
fields,
idxs,
);

Ok(rty::Binder::new(variant, env.pop_layer().into_bound_vars()))
}
Expand All @@ -535,21 +548,25 @@ impl<'a, 'tcx> ConvCtxt<'a, 'tcx> {

let def_id = struct_def.owner_id.def_id;
if let fhir::StructKind::Transparent { fields } = &struct_def.kind {
let adt_def = genv.adt_def(def_id)?;

let fields = fields
.iter()
.map(|field_def| cx.conv_ty(&mut env, &field_def.ty))
.try_collect()?;

let args = rty::GenericArgs::identity_for_item(genv, def_id)?;

let vars = env.pop_layer().into_bound_vars();
let idx = rty::Expr::tuple(
(0..vars.len())
.map(|idx| rty::Expr::late_bvar(INNERMOST, idx as u32))
.collect_vec(),
);
let ret = rty::Ty::indexed(rty::BaseTy::adt(genv.adt_def(def_id)?, args), idx);
let variant = rty::VariantSig::new(fields, ret);
let variant = rty::VariantSig::new(
adt_def,
rty::GenericArgs::identity_for_item(genv, def_id)?,
fields,
idx,
);
Ok(rty::Opaqueness::Transparent(rty::Binder::new(variant, vars)))
} else {
Ok(rty::Opaqueness::Opaque)
Expand Down
3 changes: 1 addition & 2 deletions crates/flux-middle/src/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,8 @@ impl<'tcx> Queries<'tcx> {
.iter()
.map(|field| Ok(genv.lower_type_of(field.did)?.skip_binder()))
.try_collect_vec::<_, QueryErr>()?;
let ret = genv.lower_type_of(def_id)?.skip_binder();
Refiner::default(genv, &genv.generics_of(def_id)?)
.refine_variant_def(&fields, &ret)
.refine_variant_def(def_id, &fields)
})
.try_collect()?;
Ok(rty::Opaqueness::Transparent(rty::EarlyBinder(variants)))
Expand Down
13 changes: 5 additions & 8 deletions crates/flux-middle/src/rty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,19 +633,16 @@ impl TypeVisitable for VariantSig {
self.fields
.iter()
.try_for_each(|ty| ty.visit_with(visitor))?;
self.ret.visit_with(visitor)
self.idx.visit_with(visitor)
}
}

impl TypeFoldable for VariantSig {
fn try_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
let fields = self
.fields
.iter()
.map(|ty| ty.try_fold_with(folder))
.try_collect()?;
let ret = self.ret.try_fold_with(folder)?;
Ok(VariantSig::new(fields, ret))
let args = self.args.try_fold_with(folder)?;
let fields = self.fields.try_fold_with(folder)?;
let idx = self.idx.try_fold_with(folder)?;
Ok(VariantSig::new(self.adt_def.clone(), args, fields, idx))
}
}

Expand Down
25 changes: 18 additions & 7 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,10 @@ pub type PolyVariant = Binder<VariantSig>;

#[derive(Clone, Eq, PartialEq, Hash, TyEncodable, TyDecodable)]
pub struct VariantSig {
pub adt_def: AdtDef,
pub args: GenericArgs,
pub fields: List<Ty>,
pub ret: Ty,
pub idx: Expr,
}

#[derive(Clone, PartialEq, Eq, Hash, Debug, TyEncodable, TyDecodable)]
Expand Down Expand Up @@ -823,6 +825,10 @@ impl<T> EarlyBinder<T> {
EarlyBinder(f(self.0))
}

pub fn try_map<U, E>(self, f: impl FnOnce(T) -> Result<U, E>) -> Result<EarlyBinder<U>, E> {
Ok(EarlyBinder(f(self.0)?))
}

pub fn skip_binder(self) -> T {
self.0
}
Expand Down Expand Up @@ -943,13 +949,18 @@ impl EarlyBinder<GenericPredicates> {
}

impl VariantSig {
pub fn new(fields: Vec<Ty>, ret: Ty) -> Self {
VariantSig { fields: List::from_vec(fields), ret }
pub fn new(adt_def: AdtDef, args: GenericArgs, fields: List<Ty>, idx: Expr) -> Self {
VariantSig { adt_def, args, fields, idx }
}

pub fn fields(&self) -> &[Ty] {
&self.fields
}

pub fn ret(&self) -> Ty {
let bty = BaseTy::Adt(self.adt_def.clone(), self.args.clone());
Ty::indexed(bty, self.idx.clone())
}
}

impl<T> TupleTree<T>
Expand Down Expand Up @@ -1138,9 +1149,9 @@ impl<T, E> Opaqueness<Result<T, E>> {

impl EarlyBinder<PolyVariant> {
pub fn to_poly_fn_sig(&self) -> EarlyBinder<PolyFnSig> {
self.as_ref().map(|poly_fn_sig| {
poly_fn_sig.as_ref().map(|variant| {
let ret = variant.ret.shift_in_escaping(1);
self.as_ref().map(|poly_variant| {
poly_variant.as_ref().map(|variant| {
let ret = variant.ret().shift_in_escaping(1);
let output = Binder::new(FnOutput::new(ret, vec![]), List::empty());
FnSig::new(vec![], variant.fields.clone(), output)
})
Expand Down Expand Up @@ -2096,7 +2107,7 @@ mod pretty {
impl Pretty for VariantSig {
fn fmt(&self, cx: &PPrintCx, f: &mut fmt::Formatter<'_>) -> fmt::Result {
define_scoped!(cx, f);
w!("({:?}) -> {:?}", join!(", ", self.fields()), &self.ret)
w!("({:?}) -> {:?}", join!(", ", self.fields()), &self.idx)
}
}

Expand Down
18 changes: 8 additions & 10 deletions crates/flux-middle/src/rty/refining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,17 @@ impl<'a, 'tcx> Refiner<'a, 'tcx> {

pub(crate) fn refine_variant_def(
&self,
adt_def_id: DefId,
fields: &[rustc::ty::Ty],
ret: &rustc::ty::Ty,
) -> QueryResult<rty::PolyVariant> {
let adt_def = self.adt_def(adt_def_id)?;
let fields = fields.iter().map(|ty| self.refine_ty(ty)).try_collect()?;
let rustc::ty::TyKind::Adt(adt_def, args) = ret.kind() else {
bug!();
};
let args = self.iter_with_generic_params(self.generics, args, |param, arg| {
self.refine_generic_arg(param, arg)
})?;
let bty = rty::BaseTy::adt(self.adt_def(adt_def.did())?, args);
let ret = rty::Ty::indexed(bty, rty::Expr::unit());
let value = rty::VariantSig::new(fields, ret);
let value = rty::VariantSig::new(
adt_def,
rty::GenericArgs::identity_for_item(self.genv, adt_def_id)?,
fields,
rty::Expr::unit(),
);
Ok(rty::Binder::new(value, List::empty()))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/flux-refineck/src/constraint_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ impl<'a, 'tcx> ConstrGen<'a, 'tcx> {
let evars_sol = infcx.solve()?;
rcx.replace_evars(&evars_sol);

Ok(variant.ret.replace_evars(&evars_sol))
Ok(variant.ret().replace_evars(&evars_sol))
}

pub(crate) fn check_mk_array(
Expand Down
3 changes: 1 addition & 2 deletions crates/flux-refineck/src/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ fn check_invariant(
let ty = rcx.unpack(ty, crate::refine_tree::AssumeInvariants::No);
rcx.assume_invariants(&ty, checker_config.check_overflow);
}
let (.., idx) = variant.ret.expect_adt();
let pred = invariant.pred.replace_bound_expr(&idx.expr);
let pred = invariant.pred.replace_bound_expr(&variant.idx);
rcx.check_pred(pred, Tag::new(ConstrReason::Other, DUMMY_SP));
}
let mut fcx = FixpointCtxt::new(genv, def_id, KVarStore::default());
Expand Down
3 changes: 1 addition & 2 deletions crates/flux-refineck/src/type_env/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,10 +761,9 @@ fn downcast_enum(
.instantiate(args, &[])
.replace_bound_exprs_with(|sort, _| rcx.define_vars(sort));

let (.., idx2) = variant_def.ret.expect_adt();
// FIXME(nilehmann) flatten indices
let exprs1 = idx1.expr.expect_tuple();
let exprs2 = idx2.expr.expect_tuple();
let exprs2 = variant_def.idx.expect_tuple();
debug_assert_eq!(exprs1.len(), exprs2.len());
let constr = Expr::and(iter::zip(exprs1, exprs2).filter_map(|(e1, e2)| {
if !e1.is_abs() && !e2.is_abs() {
Expand Down

0 comments on commit 3083b9f

Please sign in to comment.