Skip to content

Commit

Permalink
Make normalize_projection take an InferCtxt
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Jan 6, 2025
1 parent dd5b938 commit 6b63c42
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 122 deletions.
3 changes: 2 additions & 1 deletion crates/flux-driver/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {
}
DefKind::Impl { of_trait } => {
if of_trait {
refineck::compare_impl_item::check_impl_against_trait(self.genv, def_id)?;
refineck::compare_impl_item::check_impl_against_trait(self.genv, def_id)
.emit(&self.genv)?;
}
Ok(())
}
Expand Down
17 changes: 6 additions & 11 deletions crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,16 +453,11 @@ impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
let impl_elem = BaseTy::projection(projection_pred.projection_ty)
.to_ty()
.normalize_projections(
self.infcx.genv,
self.infcx.region_infcx,
self.infcx.def_id,
)?;
let term = projection_pred.term.to_ty().normalize_projections(
self.infcx.genv,
self.infcx.region_infcx,
self.infcx.def_id,
)?;
.normalize_projections(self.infcx)?;
let term = projection_pred
.term
.to_ty()
.normalize_projections(self.infcx)?;

// TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374
self.subtyping(&impl_elem, &term, reason)?;
Expand Down Expand Up @@ -964,7 +959,7 @@ impl<'a, E: LocEnv> Sub<'a, E> {
let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
.to_ty()
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)?;
.normalize_projections(infcx)?;
let ty2 = pred.term.to_ty();
self.tys(infcx, &ty1, &ty2)?;
}
Expand Down
82 changes: 39 additions & 43 deletions crates/flux-infer/src/projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,47 @@ use flux_middle::{
};
use flux_rustc_bridge::{lowering::Lower, ToRustc};
use rustc_hir::def_id::DefId;
use rustc_infer::{infer::InferCtxt, traits::Obligation};
use rustc_infer::traits::Obligation;
use rustc_middle::{
traits::{ImplSource, ObligationCause},
ty::TyCtxt,
};
use rustc_trait_selection::traits::SelectionContext;

use crate::infer::InferCtxt;

pub trait NormalizeExt: TypeFoldable {
fn normalize_projections<'tcx>(
&self,
genv: GlobalEnv<'_, 'tcx>,
infcx: &rustc_infer::infer::InferCtxt<'tcx>,
callsite_def_id: DefId,
) -> QueryResult<Self>;
fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self>;

Check warning on line 28 in crates/flux-infer/src/projections.rs

View workflow job for this annotation

GitHub Actions / clippy

this lifetime isn't used in the function definition

warning: this lifetime isn't used in the function definition --> crates/flux-infer/src/projections.rs:28:30 | 28 | fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self>; | ^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#extra_unused_lifetimes = note: `#[warn(clippy::extra_unused_lifetimes)]` on by default
}

impl<T: TypeFoldable> NormalizeExt for T {
fn normalize_projections<'tcx>(
&self,
genv: GlobalEnv<'_, 'tcx>,
infcx: &rustc_infer::infer::InferCtxt<'tcx>,
callsite_def_id: DefId,
) -> QueryResult<Self> {
let mut normalizer = Normalizer::new(genv, infcx, callsite_def_id)?;
fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self> {
let mut normalizer = Normalizer::new(infcx.branch())?;
self.erase_regions().try_fold_with(&mut normalizer)
}
}

struct Normalizer<'genv, 'tcx, 'cx> {
genv: GlobalEnv<'genv, 'tcx>,
selcx: SelectionContext<'cx, 'tcx>,
def_id: DefId,
struct Normalizer<'infcx, 'genv, 'tcx> {
infcx: InferCtxt<'infcx, 'genv, 'tcx>,
selcx: SelectionContext<'infcx, 'tcx>,
param_env: List<Clause>,
}

impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
fn new(
genv: GlobalEnv<'genv, 'tcx>,
infcx: &'cx InferCtxt<'tcx>,
callsite_def_id: DefId,
) -> QueryResult<Self> {
let param_env = genv
.predicates_of(callsite_def_id)?
impl<'infcx, 'genv, 'tcx> Normalizer<'infcx, 'genv, 'tcx> {
fn new(infcx: InferCtxt<'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
let param_env = infcx
.genv
.predicates_of(infcx.def_id)?
.instantiate_identity()
.predicates
.clone();
let selcx = SelectionContext::new(infcx);
Ok(Normalizer { genv, selcx, def_id: callsite_def_id, param_env })
let selcx = SelectionContext::new(infcx.region_infcx);
Ok(Normalizer { infcx, selcx, param_env })
}

fn get_impl_id_of_alias_reft(&mut self, alias_reft: &AliasReft) -> QueryResult<Option<DefId>> {
let tcx = self.tcx();
let def_id = self.def_id;
let def_id = self.def_id();
let selcx = &mut self.selcx;

let trait_pred = Obligation::new(
Expand All @@ -90,7 +78,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
) -> QueryResult<Expr> {
if let Some(impl_def_id) = self.get_impl_id_of_alias_reft(alias_reft)? {
let impl_trait_ref = self
.genv
.genv()
.impl_trait_ref(impl_def_id)?
.unwrap()
.skip_binder();
Expand All @@ -105,7 +93,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
let tcx = self.tcx();

let pred = self
.genv
.genv()
.assoc_refinement_def(impl_def_id, alias_reft.name)?
.instantiate(tcx, &args, &[]);

Expand All @@ -125,7 +113,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
) -> QueryResult<SubsetTyCtor> {
let projection_ty = obligation.to_rustc(self.tcx());
let cause = ObligationCause::dummy();
let param_env = self.tcx().param_env(self.def_id);
let param_env = self.rustc_param_env();

let ty = rustc_trait_selection::traits::normalize_projection_ty(
&mut self.selcx,
Expand All @@ -137,7 +125,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
)
.expect_type();
let rustc_ty = ty.lower(self.tcx()).unwrap();
Ok(Refiner::default_for_item(self.genv, self.def_id)?
Ok(Refiner::default_for_item(self.genv(), self.def_id())?
.refine_ty_or_base(&rustc_ty)?
.expect_base())
}
Expand All @@ -158,7 +146,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
return Ok((ty != orig_ty, ty));
}
if candidates.len() > 1 {
bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id);
bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
}
let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
Ok((true, ctor))
Expand Down Expand Up @@ -192,7 +180,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
impl_def_id: DefId,
) -> QueryResult {
let mut projection_preds: Vec<_> = self
.genv
.genv()
.predicates_of(impl_def_id)?
.skip_binder()
.predicates
Expand Down Expand Up @@ -241,7 +229,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
// => {T -> {v. i32[v] | v > 0}, A -> Global}

let impl_trait_ref = self
.genv
.genv()
.impl_trait_ref(impl_def_id)?
.unwrap()
.skip_binder();
Expand Down Expand Up @@ -269,7 +257,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {

let tcx = self.tcx();
Ok(self
.genv
.genv()
.type_of(assoc_type_id)?
.instantiate(tcx, &args, &[])
.expect_subset_ty_ctor())
Expand Down Expand Up @@ -299,7 +287,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
&& let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
{
debug_assert!(!alias_ty.has_escaping_bvars());
let bounds = self.genv.item_bounds(alias_ty.def_id)?.instantiate(
let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
self.tcx(),
&alias_ty.args,
&alias_ty.refine_args,
Expand Down Expand Up @@ -335,12 +323,20 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
Ok(())
}

fn def_id(&self) -> DefId {
self.infcx.def_id
}

fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
self.infcx.genv
}

fn tcx(&self) -> TyCtxt<'tcx> {
self.selcx.tcx()
}

fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
self.selcx.tcx().param_env(self.def_id)
self.selcx.tcx().param_env(self.def_id())
}
}

Expand All @@ -367,7 +363,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
match sort {
Sort::Alias(AliasKind::Weak, alias_ty) => {
self.genv
self.genv()
.normalize_weak_alias_sort(alias_ty)?
.try_fold_with(self)
}
Expand All @@ -393,9 +389,9 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
match ty.kind() {
TyKind::Indexed(BaseTy::Alias(AliasKind::Weak, alias_ty), idx) => {
Ok(self
.genv
.genv()
.type_of(alias_ty.def_id)?
.instantiate(self.genv.tcx(), &alias_ty.args, &alias_ty.refine_args)
.instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
.expect_ctor()
.replace_bound_reft(idx))
}
Expand Down Expand Up @@ -446,7 +442,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
c.to_rustc(self.tcx())
.normalize_internal(self.tcx(), self.rustc_param_env())
.lower(self.tcx())
.map_err(|e| QueryErr::unsupported(self.def_id, e.into_err()))
.map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
}
}

Expand Down
25 changes: 13 additions & 12 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ impl<'ck, 'genv, 'tcx> Checker<'ck, 'genv, 'tcx, ShapeMode> {
let inherited = Inherited::new(&mut mode, ghost_stmts)?;

let body = genv.mir(local_id).with_span(span)?;
let infcx = root_ctxt.infcx(def_id, &body.infcx);
let mut infcx = root_ctxt.infcx(def_id, &body.infcx);
let poly_sig = genv
.fn_sig(local_id)
.with_span(span)?
.instantiate_identity()
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)
.normalize_projections(&mut infcx)
.with_span(span)?;
Checker::run(infcx, local_id, inherited, poly_sig)?;

Expand All @@ -194,12 +194,12 @@ impl<'ck, 'genv, 'tcx> Checker<'ck, 'genv, 'tcx, RefineMode> {
let mut mode = RefineMode { bb_envs };
let inherited = Inherited::new(&mut mode, ghost_stmts)?;
let body = genv.mir(local_id).with_span(span)?;
let infcx = root_ctxt.infcx(def_id, &body.infcx);
let mut infcx = root_ctxt.infcx(def_id, &body.infcx);
let poly_sig = genv
.fn_sig(def_id)
.with_span(span)?
.instantiate_identity()
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)
.normalize_projections(&mut infcx)
.with_span(span)?;
Checker::run(infcx, local_id, inherited, poly_sig)?;

Expand Down Expand Up @@ -234,7 +234,7 @@ fn check_fn_subtyping(

let super_sig = super_sig
.replace_bound_vars(|_| rty::ReErased, |sort, _| infcx.define_vars(sort))
.normalize_projections(infcx.genv, infcx.region_infcx, *def_id)?;
.normalize_projections(&mut infcx)?;

// 1. Unpack `T_g` input types
let actuals = super_sig
Expand All @@ -253,7 +253,7 @@ fn check_fn_subtyping(
let sub_sig = sub_sig.instantiate(tcx, sub_args, &refine_args);
let sub_sig = sub_sig
.replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode))
.normalize_projections(infcx.genv, infcx.region_infcx, *def_id)?;
.normalize_projections(infcx)?;

// 3. INPUT subtyping (g-input <: f-input)
for requires in super_sig.requires() {
Expand Down Expand Up @@ -310,8 +310,9 @@ pub(crate) fn trait_impl_subtyping<'genv, 'tcx>(
let Some((impl_trait_ref, trait_method_id)) = find_trait_item(genv, def_id)? else {
return Ok(None);
};
let impl_method_id = def_id.to_def_id();
// Skip the check if either the trait-method or the impl-method are marked as `trusted_impl`
if genv.has_trusted_impl(trait_method_id) || genv.has_trusted_impl(def_id.to_def_id()) {
if genv.has_trusted_impl(trait_method_id) || genv.has_trusted_impl(impl_method_id) {
return Ok(None);
}

Expand All @@ -328,13 +329,13 @@ pub(crate) fn trait_impl_subtyping<'genv, 'tcx>(
.tcx()
.infer_ctxt()
.build(TypingMode::non_body_analysis());
let mut infcx = root_ctxt.infcx(trait_method_id, &rustc_infcx);
let mut infcx = root_ctxt.infcx(impl_method_id, &rustc_infcx);

let trait_fn_sig = genv.fn_sig(trait_method_id)?;
let impl_sig = genv.fn_sig(def_id)?;
let impl_sig = genv.fn_sig(impl_method_id)?;
check_fn_subtyping(
&mut infcx,
&def_id.to_def_id(),
&impl_method_id,
impl_sig,
&impl_args,
&trait_fn_sig.instantiate(tcx, &trait_args, &trait_refine_args),
Expand Down Expand Up @@ -422,7 +423,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {

let fn_sig = poly_sig
.replace_bound_vars(|_| rty::ReErased, |sort, _| infcx.define_vars(sort))
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)
.normalize_projections(&mut infcx)
.with_span(span)?;

let mut env = TypeEnv::new(&mut infcx, &body, &fn_sig);
Expand Down Expand Up @@ -782,7 +783,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
let fn_sig = fn_sig
.instantiate(tcx, &generic_args, &refine_args)
.replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode))
.normalize_projections(genv, infcx.region_infcx, infcx.def_id)
.normalize_projections(infcx)
.with_span(span)?;

let mut at = infcx.at(span);
Expand Down
Loading

0 comments on commit 6b63c42

Please sign in to comment.