Skip to content

Commit

Permalink
Avoid generating 1-tuples in fixpoint encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Nov 26, 2024
1 parent c5a037f commit c7d3f46
Showing 1 changed file with 54 additions and 56 deletions.
110 changes: 54 additions & 56 deletions crates/flux-infer/src/fixpoint_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,27 @@ impl SortEncodingCtxt {
fixpoint::Sort::App(fixpoint::SortCtor::Map, args)
}
rty::Sort::App(rty::SortCtor::Adt(sort_def), args) => {
self.declare_tuple(sort_def.fields());
let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sort_def.fields()));
let args = sort_def
.field_sorts(args)
.iter()
.map(|s| self.sort_to_fixpoint(s))
.collect_vec();
fixpoint::Sort::App(ctor, args)
let sorts = sort_def.field_sorts(args);
// do not generate 1-tuples
if let [sort] = &sorts[..] {
self.sort_to_fixpoint(sort)
} else {
self.declare_tuple(sorts.len());
let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect_vec();
fixpoint::Sort::App(ctor, args)
}
}
rty::Sort::Tuple(sorts) => {
self.declare_tuple(sorts.len());
let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect();
fixpoint::Sort::App(ctor, args)
// do not generate 1-tuples
if let [sort] = &sorts[..] {
self.sort_to_fixpoint(sort)
} else {
self.declare_tuple(sorts.len());
let ctor = fixpoint::SortCtor::Data(fixpoint::DataSort::Tuple(sorts.len()));
let args = sorts.iter().map(|s| self.sort_to_fixpoint(s)).collect();
fixpoint::Sort::App(ctor, args)
}
}
rty::Sort::Func(sort) => self.func_sort_to_fixpoint(sort),
rty::Sort::Var(k) => fixpoint::Sort::Var(k.index()),
Expand Down Expand Up @@ -933,25 +940,27 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
rty::ExprKind::Constant(c) => fixpoint::Expr::Constant(const_to_fixpoint(*c)),
rty::ExprKind::BinaryOp(op, e1, e2) => self.bin_op_to_fixpoint(op, e1, e2, scx)?,
rty::ExprKind::UnaryOp(op, e) => self.un_op_to_fixpoint(*op, e, scx)?,
rty::ExprKind::FieldProj(e, proj) => {
let proj = fixpoint::Expr::Var(self.proj_to_fixpoint(*proj, scx)?);
fixpoint::Expr::App(Box::new(proj), vec![self.expr_to_fixpoint(e, scx)?])
}
rty::ExprKind::FieldProj(e, proj) => self.proj_to_fixpoint(e, *proj, scx)?,
rty::ExprKind::Aggregate(_, flds) => {
scx.declare_tuple(flds.len());
let ctor = fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity: flds.len() });
let args = flds
.iter()
.map(|e| self.expr_to_fixpoint(e, scx))
.try_collect()?;
fixpoint::Expr::App(Box::new(ctor), args)
// do not generate 1-tuples
if let [fld] = &flds[..] {
self.expr_to_fixpoint(fld, scx)?
} else {
scx.declare_tuple(flds.len());
let ctor = fixpoint::Expr::Var(fixpoint::Var::TupleCtor { arity: flds.len() });
let args = flds
.iter()
.map(|fld| self.expr_to_fixpoint(fld, scx))
.try_collect()?;
fixpoint::Expr::App(Box::new(ctor), args)
}
}
rty::ExprKind::ConstDefId(did) => {
let var = self.register_rust_const(*did);
fixpoint::Expr::Var(var.into())
}
rty::ExprKind::App(func, args) => {
let func = self.func_to_fixpoint(func, scx)?;
let func = self.expr_to_fixpoint(func, scx)?;
let args = self.exprs_to_fixpoint(args, scx)?;
fixpoint::Expr::App(Box::new(func), args)
}
Expand Down Expand Up @@ -982,10 +991,18 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
let var = self.register_const_for_lambda(lam, scx);
fixpoint::Expr::Var(var.into())
}
rty::ExprKind::GlobalFunc(_, SpecFuncKind::Thy(sym)) => {
fixpoint::Expr::Var(fixpoint::Var::Itf(*sym))
}
rty::ExprKind::GlobalFunc(sym, SpecFuncKind::Uif) => {
fixpoint::Expr::Var(self.register_uif(*sym, scx).into())
}
rty::ExprKind::GlobalFunc(sym, SpecFuncKind::Def) => {
span_bug!(self.def_span, "unexpected global function `{sym}`. Function must be normalized away at this point")
}
rty::ExprKind::Hole(..)
| rty::ExprKind::KVar(_)
| rty::ExprKind::Local(_)
| rty::ExprKind::GlobalFunc(..)
| rty::ExprKind::PathProj(..)
| rty::ExprKind::ForAll(_) => {
span_bug!(self.def_span, "unexpected expr: `{expr:?}`")
Expand All @@ -1007,14 +1024,21 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {

fn proj_to_fixpoint(
&mut self,
e: &rty::Expr,
proj: rty::FieldProj,
scx: &mut SortEncodingCtxt,
) -> QueryResult<fixpoint::Var> {
) -> QueryResult<fixpoint::Expr> {
let arity = proj.arity(self.genv)?;
let field = proj.field_idx();

scx.declare_tuple(arity);
Ok(fixpoint::Var::TupleProj { arity, field })
// we encode 1-tuples as the single element inside so no projection necessary here
if arity == 1 {
self.expr_to_fixpoint(e, scx)
} else {
let field = proj.field_idx();
scx.declare_tuple(arity);
let proj = fixpoint::Var::TupleProj { arity, field };
let proj = fixpoint::Expr::Var(proj);
Ok(fixpoint::Expr::App(Box::new(proj), vec![self.expr_to_fixpoint(e, scx)?]))
}
}

fn un_op_to_fixpoint(
Expand Down Expand Up @@ -1176,32 +1200,6 @@ impl<'genv, 'tcx> ExprEncodingCtxt<'genv, 'tcx> {
))
}

fn func_to_fixpoint(
&mut self,
func: &rty::Expr,
scx: &mut SortEncodingCtxt,
) -> QueryResult<fixpoint::Expr> {
match func.kind() {
rty::ExprKind::Var(var) => Ok(fixpoint::Expr::Var(self.var_to_fixpoint(var))),
rty::ExprKind::GlobalFunc(_, SpecFuncKind::Thy(sym)) => {
Ok(fixpoint::Expr::Var(fixpoint::Var::Itf(*sym)))
}
rty::ExprKind::GlobalFunc(sym, SpecFuncKind::Uif) => {
Ok(fixpoint::Expr::Var(self.register_uif(*sym, scx).into()))
}
rty::ExprKind::FieldProj(e, proj) => {
let proj = fixpoint::Expr::Var(self.proj_to_fixpoint(*proj, scx)?);
Ok(fixpoint::Expr::App(Box::new(proj), vec![self.func_to_fixpoint(e, scx)?]))
}
rty::ExprKind::GlobalFunc(sym, SpecFuncKind::Def) => {
span_bug!(self.def_span, "unexpected global function `{sym}`. Function must be normalized away at this point")
}
_ => {
span_bug!(self.def_span, "unexpected expr `{func:?}` in function position")
}
}
}

fn imm(
&mut self,
arg: &rty::Expr,
Expand Down

0 comments on commit c7d3f46

Please sign in to comment.