Skip to content

Commit

Permalink
Make normalize_projections take an InferCtxt (#968)
Browse files Browse the repository at this point in the history
* Reorder some code in trait_impl_subtyping

* Move check_impl_against_trait to flux-refineck

* Move normalize_projection to extension trait

* Move projection to flux-infer

* Make normalize_projection take an InferCtxt
  • Loading branch information
nilehmann authored Jan 7, 2025
1 parent 561bc64 commit 5f3c207
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 178 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions crates/flux-driver/src/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use flux_common::{bug, cache::QueryCache, dbg, iter::IterExt, result::ResultExt};
use flux_config as config;
use flux_errors::FluxSession;
use flux_fhir_analysis::compare_impl_item;
use flux_infer::fixpoint_encoding::FixQueryCache;
use flux_metadata::CStore;
use flux_middle::{fhir, global_env::GlobalEnv, queries::Providers, Specs};
Expand Down Expand Up @@ -229,7 +228,8 @@ impl<'genv, 'tcx> CrateChecker<'genv, 'tcx> {
}
DefKind::Impl { of_trait } => {
if of_trait {
compare_impl_item::check_impl_against_trait(self.genv, def_id)?;
refineck::compare_impl_item::check_impl_against_trait(self.genv, def_id)
.emit(&self.genv)?;
}
Ok(())
}
Expand Down
9 changes: 0 additions & 9 deletions crates/flux-fhir-analysis/locales/en-US.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,5 @@ fhir_analysis_generics_on_ty_param =
fhir_analysis_generics_on_self_ty =
generic arguments are not allowed on self type
# Check impl against trait errors

fhir_analysis_incompatible_sort =
implemented associated refinement `{$name}` has an incompatible sort for trait
.label = expected `{$expected}`, found `{$found}`
fhir_analysis_invalid_assoc_reft =
associated refinement `{$name}` is not a member of trait `{$trait_}`
fhir_analysis_missing_assoc_reft =
associated refinement `{$name}` is not defined in implementation of trait `{$trait_}`
19 changes: 16 additions & 3 deletions crates/flux-fhir-analysis/src/conv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ use rustc_target::spec::abi;
use rustc_trait_selection::traits;
use rustc_type_ir::DebruijnIndex;

use crate::compare_impl_item::errors::InvalidAssocReft;

/// Wrapper over a type implementing [`ConvPhase`]. We have this to implement most functionality as
/// inherent methods instead of defining them as default implementation in the trait definition.
#[repr(transparent)]
Expand Down Expand Up @@ -1997,7 +1995,7 @@ impl<'genv, 'tcx: 'genv, P: ConvPhase<'genv, 'tcx>> ConvCtxt<P> {
rty::AliasReft { trait_id, name: alias.name, args: List::from_vec(generic_args) };

let Some(fsort) = alias_reft.fsort(self.genv())? else {
return Err(self.emit(InvalidAssocReft::new(
return Err(self.emit(errors::InvalidAssocReft::new(
alias.path.span,
alias_reft.name,
format!("{:?}", alias.path),
Expand Down Expand Up @@ -2520,4 +2518,19 @@ mod errors {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(fhir_analysis_invalid_assoc_reft, code = E0999)]
pub struct InvalidAssocReft {
#[primary_span]
span: Span,
trait_: String,
name: Symbol,
}

impl InvalidAssocReft {
pub(crate) fn new(span: Span, name: Symbol, trait_: String) -> Self {
Self { span, trait_, name }
}
}
}
2 changes: 0 additions & 2 deletions crates/flux-fhir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ extern crate rustc_data_structures;
extern crate rustc_errors;

extern crate rustc_hir;
extern crate rustc_infer;
extern crate rustc_middle;
extern crate rustc_span;
extern crate rustc_target;
extern crate rustc_trait_selection;
extern crate rustc_type_ir;

pub mod compare_impl_item;
mod conv;
mod wf;

Expand Down
1 change: 1 addition & 0 deletions crates/flux-infer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ flux-config.workspace = true
flux-errors.workspace = true
flux-macros.workspace = true
flux-middle.workspace = true
flux-rustc-bridge.workspace = true

itertools.workspace = true
liquid-fixpoint.workspace = true
Expand Down
19 changes: 8 additions & 11 deletions crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use rustc_span::Span;
use crate::{
evars::{EVarState, EVarStore},
fixpoint_encoding::{FixQueryCache, FixpointCtxt, KVarEncoding, KVarGen},
projections::NormalizeExt as _,
refine_tree::{AssumeInvariants, Cursor, Marker, RefineTree, Scope, Unpacker},
};

Expand Down Expand Up @@ -101,6 +102,7 @@ impl<'genv, 'tcx> InferCtxtRootBuilder<'genv, 'tcx> {
self
}

/// When provided use `generic_args` to instantiate sorts
pub fn with_generic_args(mut self, generic_args: &GenericArgs) -> Self {
self.generic_args = Some(generic_args.clone());
self
Expand Down Expand Up @@ -451,16 +453,11 @@ impl<'genv, 'tcx> InferCtxtAt<'_, '_, 'genv, 'tcx> {
if let rty::ClauseKind::Projection(projection_pred) = clause.kind_skipping_binder() {
let impl_elem = BaseTy::projection(projection_pred.projection_ty)
.to_ty()
.normalize_projections(
self.infcx.genv,
self.infcx.region_infcx,
self.infcx.def_id,
)?;
let term = projection_pred.term.to_ty().normalize_projections(
self.infcx.genv,
self.infcx.region_infcx,
self.infcx.def_id,
)?;
.normalize_projections(self.infcx)?;
let term = projection_pred
.term
.to_ty()
.normalize_projections(self.infcx)?;

// TODO: does this really need to be invariant? https://github.com/flux-rs/flux/pull/478#issuecomment-1654035374
self.subtyping(&impl_elem, &term, reason)?;
Expand Down Expand Up @@ -962,7 +959,7 @@ impl<'a, E: LocEnv> Sub<'a, E> {
let alias_ty = pred.projection_ty.with_self_ty(bty.to_subset_ty_ctor());
let ty1 = BaseTy::Alias(AliasKind::Projection, alias_ty)
.to_ty()
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)?;
.normalize_projections(infcx)?;
let ty2 = pred.term.to_ty();
self.tys(infcx, &ty1, &ty2)?;
}
Expand Down
2 changes: 2 additions & 0 deletions crates/flux-infer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ extern crate rustc_infer;
extern crate rustc_macros;
extern crate rustc_middle;
extern crate rustc_span;
extern crate rustc_trait_selection;
extern crate rustc_type_ir;

mod evars;
pub mod fixpoint_encoding;
pub mod infer;
pub mod projections;
pub mod refine_tree;
Original file line number Diff line number Diff line change
@@ -1,54 +1,61 @@
use std::iter;

use flux_arc_interner::List;
use flux_common::{bug, tracked_span_bug};
use flux_middle::{
global_env::GlobalEnv,
queries::{QueryErr, QueryResult},
rty::{
fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable, TypeVisitable},
refining::Refiner,
subst::{GenericsSubstDelegate, GenericsSubstFolder},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, ConstKind,
EarlyBinder, Expr, ExprKind, GenericArg, List, ProjectionPredicate, RefineArgs, Region,
Sort, SubsetTy, SubsetTyCtor, Ty, TyKind,
},
};
use flux_rustc_bridge::{lowering::Lower, ToRustc};
use rustc_hir::def_id::DefId;
use rustc_infer::{infer::InferCtxt, traits::Obligation};
use rustc_infer::traits::Obligation;
use rustc_middle::{
traits::{ImplSource, ObligationCause},
ty::TyCtxt,
};
use rustc_trait_selection::traits::SelectionContext;

use super::{
fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable},
subst::{GenericsSubstDelegate, GenericsSubstFolder},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, EarlyBinder, Expr,
ExprKind, GenericArg, ProjectionPredicate, RefineArgs, Region, Sort, SubsetTy, SubsetTyCtor,
Ty, TyKind,
};
use crate::{
global_env::GlobalEnv,
queries::{QueryErr, QueryResult},
rty::{fold::TypeVisitable, refining::Refiner},
};
use crate::infer::InferCtxt;

pub trait NormalizeExt: TypeFoldable {
fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self>;

Check warning on line 28 in crates/flux-infer/src/projections.rs

View workflow job for this annotation

GitHub Actions / clippy

this lifetime isn't used in the function definition

warning: this lifetime isn't used in the function definition --> crates/flux-infer/src/projections.rs:28:30 | 28 | fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self>; | ^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#extra_unused_lifetimes = note: `#[warn(clippy::extra_unused_lifetimes)]` on by default
}

pub(crate) struct Normalizer<'genv, 'tcx, 'cx> {
genv: GlobalEnv<'genv, 'tcx>,
selcx: SelectionContext<'cx, 'tcx>,
def_id: DefId,
impl<T: TypeFoldable> NormalizeExt for T {
fn normalize_projections<'tcx>(&self, infcx: &mut InferCtxt) -> QueryResult<Self> {
let mut normalizer = Normalizer::new(infcx.branch())?;
self.erase_regions().try_fold_with(&mut normalizer)
}
}

struct Normalizer<'infcx, 'genv, 'tcx> {
infcx: InferCtxt<'infcx, 'genv, 'tcx>,
selcx: SelectionContext<'infcx, 'tcx>,
param_env: List<Clause>,
}

impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
pub(crate) fn new(
genv: GlobalEnv<'genv, 'tcx>,
infcx: &'cx InferCtxt<'tcx>,
callsite_def_id: DefId,
) -> QueryResult<Self> {
let param_env = genv
.predicates_of(callsite_def_id)?
impl<'infcx, 'genv, 'tcx> Normalizer<'infcx, 'genv, 'tcx> {
fn new(infcx: InferCtxt<'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
let param_env = infcx
.genv
.predicates_of(infcx.def_id)?
.instantiate_identity()
.predicates
.clone();
let selcx = SelectionContext::new(infcx);
Ok(Normalizer { genv, selcx, def_id: callsite_def_id, param_env })
let selcx = SelectionContext::new(infcx.region_infcx);
Ok(Normalizer { infcx, selcx, param_env })
}

fn get_impl_id_of_alias_reft(&mut self, alias_reft: &AliasReft) -> QueryResult<Option<DefId>> {
let tcx = self.tcx();
let def_id = self.def_id;
let def_id = self.def_id();
let selcx = &mut self.selcx;

let trait_pred = Obligation::new(
Expand All @@ -71,7 +78,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
) -> QueryResult<Expr> {
if let Some(impl_def_id) = self.get_impl_id_of_alias_reft(alias_reft)? {
let impl_trait_ref = self
.genv
.genv()
.impl_trait_ref(impl_def_id)?
.unwrap()
.skip_binder();
Expand All @@ -86,7 +93,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
let tcx = self.tcx();

let pred = self
.genv
.genv()
.assoc_refinement_def(impl_def_id, alias_reft.name)?
.instantiate(tcx, &args, &[]);

Expand All @@ -106,7 +113,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
) -> QueryResult<SubsetTyCtor> {
let projection_ty = obligation.to_rustc(self.tcx());
let cause = ObligationCause::dummy();
let param_env = self.tcx().param_env(self.def_id);
let param_env = self.rustc_param_env();

let ty = rustc_trait_selection::traits::normalize_projection_ty(
&mut self.selcx,
Expand All @@ -118,7 +125,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
)
.expect_type();
let rustc_ty = ty.lower(self.tcx()).unwrap();
Ok(Refiner::default_for_item(self.genv, self.def_id)?
Ok(Refiner::default_for_item(self.genv(), self.def_id())?
.refine_ty_or_base(&rustc_ty)?
.expect_base())
}
Expand All @@ -139,7 +146,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
return Ok((ty != orig_ty, ty));
}
if candidates.len() > 1 {
bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id);
bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
}
let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
Ok((true, ctor))
Expand Down Expand Up @@ -173,7 +180,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
impl_def_id: DefId,
) -> QueryResult {
let mut projection_preds: Vec<_> = self
.genv
.genv()
.predicates_of(impl_def_id)?
.skip_binder()
.predicates
Expand Down Expand Up @@ -222,7 +229,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
// => {T -> {v. i32[v] | v > 0}, A -> Global}

let impl_trait_ref = self
.genv
.genv()
.impl_trait_ref(impl_def_id)?
.unwrap()
.skip_binder();
Expand Down Expand Up @@ -250,7 +257,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {

let tcx = self.tcx();
Ok(self
.genv
.genv()
.type_of(assoc_type_id)?
.instantiate(tcx, &args, &[])
.expect_subset_ty_ctor())
Expand Down Expand Up @@ -280,7 +287,7 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
&& let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
{
debug_assert!(!alias_ty.has_escaping_bvars());
let bounds = self.genv.item_bounds(alias_ty.def_id)?.instantiate(
let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
self.tcx(),
&alias_ty.args,
&alias_ty.refine_args,
Expand Down Expand Up @@ -316,12 +323,20 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
Ok(())
}

fn def_id(&self) -> DefId {
self.infcx.def_id
}

fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
self.infcx.genv
}

fn tcx(&self) -> TyCtxt<'tcx> {
self.selcx.tcx()
}

fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
self.selcx.tcx().param_env(self.def_id)
self.selcx.tcx().param_env(self.def_id())
}
}

Expand All @@ -348,7 +363,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
match sort {
Sort::Alias(AliasKind::Weak, alias_ty) => {
self.genv
self.genv()
.normalize_weak_alias_sort(alias_ty)?
.try_fold_with(self)
}
Expand All @@ -374,9 +389,9 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
match ty.kind() {
TyKind::Indexed(BaseTy::Alias(AliasKind::Weak, alias_ty), idx) => {
Ok(self
.genv
.genv()
.type_of(alias_ty.def_id)?
.instantiate(self.genv.tcx(), &alias_ty.args, &alias_ty.refine_args)
.instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
.expect_ctor()
.replace_bound_reft(idx))
}
Expand Down Expand Up @@ -427,7 +442,7 @@ impl FallibleTypeFolder for Normalizer<'_, '_, '_> {
c.to_rustc(self.tcx())
.normalize_internal(self.tcx(), self.rustc_param_env())
.lower(self.tcx())
.map_err(|e| QueryErr::unsupported(self.def_id, e.into_err()))
.map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
}
}

Expand Down Expand Up @@ -595,7 +610,7 @@ impl TVarSubst {
}

fn consts(&mut self, a: &Const, b: &Const) {
if let super::ConstKind::Param(param_const) = a.kind {
if let ConstKind::Param(param_const) = a.kind {
self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
}
}
Expand Down
Loading

0 comments on commit 5f3c207

Please sign in to comment.