Skip to content

Commit

Permalink
More refinement parameters out of Generics
Browse files Browse the repository at this point in the history
  • Loading branch information
nilehmann committed Jan 5, 2024
1 parent 387a538 commit db49182
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 67 deletions.
23 changes: 10 additions & 13 deletions crates/flux-fhir-analysis/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ pub(crate) fn conv_opaque_ty(
}

pub(crate) fn conv_generics(
genv: &GlobalEnv,
rust_generics: &rustc::ty::Generics,
generics: &fhir::Generics,
refine_params: &[fhir::RefineParam],
is_trait: Option<LocalDefId>,
) -> QueryResult<rty::Generics> {
let opt_self = is_trait.map(|def_id| {
Expand Down Expand Up @@ -177,24 +175,23 @@ pub(crate) fn conv_generics(
}))
.collect();

let refine_params = refine_params
.iter()
.map(|param| conv_refine_param(genv, param))
.collect();

Ok(rty::Generics {
params,
refine_params,
parent: rust_generics.parent(),
parent_count: rust_generics.parent_count(),
parent_refine_count: rust_generics
.parent()
.map(|parent| genv.generics_of(parent))
.transpose()?
.map_or(0, |g| g.refine_count()),
})
}

pub(crate) fn conv_refinement_generics(
genv: &GlobalEnv,
params: &[fhir::RefineParam],
) -> List<rty::RefineParam> {
params
.iter()
.map(|param| conv_refine_param(genv, param))
.collect()
}

fn sort_args_for_adt(genv: &GlobalEnv, def_id: impl Into<DefId>) -> List<fhir::Sort> {
let mut sort_args = vec![];
for param in &genv.tcx.generics_of(def_id.into()).params {
Expand Down
30 changes: 19 additions & 11 deletions crates/flux-fhir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub fn provide(providers: &mut Providers) {
variants_of,
fn_sig,
generics_of,
refinement_generics_of,
predicates_of,
item_bounds,
};
Expand Down Expand Up @@ -160,29 +161,36 @@ fn generics_of(genv: &GlobalEnv, local_id: LocalDefId) -> QueryResult<rty::Gener
.map()
.get_generics(local_id)
.unwrap_or_else(|| bug!("no generics for {:?}", def_id));
let refine_params = genv
.map()
.get_refine_params(genv.tcx, local_id)
.unwrap_or(&[]);
conv::conv_generics(genv, &rustc_generics, generics, refine_params, is_trait)
conv::conv_generics(&rustc_generics, generics, is_trait)
}
DefKind::Closure | DefKind::Coroutine => {
Ok(rty::Generics {
params: List::empty(),
refine_params: List::empty(),
parent: rustc_generics.parent(),
parent_count: rustc_generics.parent_count(),
parent_refine_count: rustc_generics
.parent()
.map(|parent| genv.generics_of(parent))
.transpose()?
.map_or(0, |g| g.refine_count()),
})
}
kind => bug!("generics_of called on `{def_id:?}` with kind `{kind:?}`"),
}
}

fn refinement_generics_of(
genv: &GlobalEnv,
local_id: LocalDefId,
) -> QueryResult<rty::RefinementGenerics> {
let parent = genv.tcx.generics_of(local_id).parent;
let parent_count =
if let Some(def_id) = parent { genv.refinement_generics_of(def_id)?.count() } else { 0 };
match genv.tcx.def_kind(local_id) {
DefKind::Fn | DefKind::AssocFn => {
let fn_sig = genv.map().get_fn_sig(local_id);
let params = conv::conv_refinement_generics(genv, &fn_sig.params);
Ok(rty::RefinementGenerics { parent, parent_count, params })
}
_ => Ok(rty::RefinementGenerics { parent, parent_count, params: List::empty() }),
}
}

fn type_of(genv: &GlobalEnv, def_id: LocalDefId) -> QueryResult<rty::EarlyBinder<rty::PolyTy>> {
let ty = match genv.tcx.def_kind(def_id) {
DefKind::TyAlias { .. } => {
Expand Down
7 changes: 7 additions & 0 deletions crates/flux-middle/src/global_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ impl<'sess, 'tcx> GlobalEnv<'sess, 'tcx> {
self.queries.generics_of(self, def_id.into())
}

pub fn refinement_generics_of(
&self,
def_id: impl Into<DefId>,
) -> QueryResult<rty::RefinementGenerics> {
self.queries.refinement_generics_of(self, def_id.into())
}

pub fn predicates_of(
&self,
def_id: impl Into<DefId>,
Expand Down
21 changes: 20 additions & 1 deletion crates/flux-middle/src/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub struct Providers {
) -> QueryResult<rty::Opaqueness<rty::EarlyBinder<rty::PolyVariants>>>,
pub fn_sig: fn(&GlobalEnv, LocalDefId) -> QueryResult<rty::EarlyBinder<rty::PolyFnSig>>,
pub generics_of: fn(&GlobalEnv, LocalDefId) -> QueryResult<rty::Generics>,
pub refinement_generics_of: fn(&GlobalEnv, LocalDefId) -> QueryResult<rty::RefinementGenerics>,
pub predicates_of:
fn(&GlobalEnv, LocalDefId) -> QueryResult<rty::EarlyBinder<rty::GenericPredicates>>,
pub item_bounds: fn(&GlobalEnv, LocalDefId) -> QueryResult<rty::EarlyBinder<List<rty::Clause>>>,
Expand All @@ -79,6 +80,7 @@ impl Default for Providers {
variants_of: |_, _| empty_query!(),
fn_sig: |_, _| empty_query!(),
generics_of: |_, _| empty_query!(),
refinement_generics_of: |_, _| empty_query!(),
predicates_of: |_, _| empty_query!(),
item_bounds: |_, _| empty_query!(),
}
Expand All @@ -99,6 +101,7 @@ pub struct Queries<'tcx> {
check_wf: Cache<FluxLocalDefId, QueryResult<Rc<fhir::WfckResults>>>,
adt_def: Cache<DefId, QueryResult<rty::AdtDef>>,
generics_of: Cache<DefId, QueryResult<rty::Generics>>,
refinement_generics_of: Cache<DefId, QueryResult<rty::RefinementGenerics>>,
predicates_of: Cache<DefId, QueryResult<rty::EarlyBinder<rty::GenericPredicates>>>,
item_bounds: Cache<DefId, QueryResult<rty::EarlyBinder<List<rty::Clause>>>>,
type_of: Cache<DefId, QueryResult<rty::EarlyBinder<rty::PolyTy>>>,
Expand Down Expand Up @@ -239,7 +242,23 @@ impl<'tcx> Queries<'tcx> {
(self.providers.generics_of)(genv, local_id)
} else {
let generics = genv.lower_generics_of(def_id)?;
refining::refine_generics(genv, &generics)
refining::refine_generics(&generics)
}
})
}

pub(crate) fn refinement_generics_of(
&self,
genv: &GlobalEnv,
def_id: DefId,
) -> QueryResult<rty::RefinementGenerics> {
run_with_cache(&self.refinement_generics_of, def_id, || {
let def_id = genv.lookup_extern(def_id).unwrap_or(def_id);
if let Some(local_id) = def_id.as_local() {
(self.providers.refinement_generics_of)(genv, local_id)
} else {
let parent = genv.tcx.generics_of(def_id).parent;
Ok(rty::RefinementGenerics { parent, parent_count: 0, params: List::empty() })
}
})
}
Expand Down
41 changes: 22 additions & 19 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,16 @@ pub use crate::{

#[derive(Debug, Clone)]
pub struct Generics {
pub parent: Option<DefId>,
pub parent_count: usize,
pub params: List<GenericParamDef>,
pub refine_params: List<RefineParam>,
}

#[derive(Debug, Clone)]
pub struct RefinementGenerics {
pub parent: Option<DefId>,
pub parent_count: usize,
pub parent_refine_count: usize,
pub params: List<RefineParam>,
}

#[derive(PartialEq, Eq, Debug, Clone, Hash)]
Expand Down Expand Up @@ -599,39 +604,37 @@ impl Generics {
Ok(self.params[index].clone())
} else {
let parent = self.parent.expect("parent_count > 0 but no parent?");
let parent_generics = genv.generics_of(parent)?;
parent_generics.param_at(param_index, genv)
genv.generics_of(parent)?.param_at(param_index, genv)
}
}
}

pub fn refine_count(&self) -> usize {
self.parent_refine_count + self.refine_params.len()
impl RefinementGenerics {
pub fn count(&self) -> usize {
self.parent_count + self.params.len()
}

pub fn refine_param_at(
&self,
param_index: usize,
genv: &GlobalEnv,
) -> QueryResult<RefineParam> {
if let Some(index) = param_index.checked_sub(self.parent_refine_count) {
Ok(self.refine_params[index].clone())
pub fn param_at(&self, param_index: usize, genv: &GlobalEnv) -> QueryResult<RefineParam> {
if let Some(index) = param_index.checked_sub(self.parent_count) {
Ok(self.params[index].clone())
} else {
genv.generics_of(self.parent.expect("parent_count > 0 but no parent?"))?
.refine_param_at(param_index, genv)
let parent = self.parent.expect("parent_count > 0 but no parent?");
genv.refinement_generics_of(parent)?
.param_at(param_index, genv)
}
}

/// Iterate and collect all refinement parameters in this item including parents
pub fn collect_all_refine_params<T, S>(
/// Iterate and collect all parameters in this item including parents
pub fn collect_all_params<T, S>(
&self,
genv: &GlobalEnv,
mut f: impl FnMut(RefineParam) -> T,
) -> QueryResult<S>
where
S: FromIterator<T>,
{
(0..self.refine_count())
.map(|i| Ok(f(self.refine_param_at(i, genv)?)))
(0..self.count())
.map(|i| Ok(f(self.param_at(i, genv)?)))
.try_collect()
}
}
Expand Down
21 changes: 4 additions & 17 deletions crates/flux-middle/src/rty/refining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ use rustc_middle::ty::{ClosureKind, ParamTy};
use super::fold::TypeFoldable;
use crate::{global_env::GlobalEnv, intern::List, queries::QueryResult, rty, rustc};

pub(crate) fn refine_generics(
genv: &GlobalEnv,
generics: &rustc::ty::Generics,
) -> QueryResult<rty::Generics> {
pub(crate) fn refine_generics(generics: &rustc::ty::Generics) -> QueryResult<rty::Generics> {
let params = generics
.params
.iter()
Expand All @@ -35,17 +32,7 @@ pub(crate) fn refine_generics(
})
.collect();

Ok(rty::Generics {
params,
refine_params: List::empty(),
parent: generics.parent(),
parent_count: generics.parent_count(),
parent_refine_count: generics
.parent()
.map(|parent| genv.generics_of(parent))
.transpose()?
.map_or(0, |g| g.refine_count()),
})
Ok(rty::Generics { params, parent: generics.parent(), parent_count: generics.parent_count() })
}

pub struct Refiner<'a, 'tcx> {
Expand Down Expand Up @@ -386,8 +373,8 @@ impl<'a, 'tcx> Refiner<'a, 'tcx> {
) -> QueryResult<rty::RefineArgs> {
if let rustc::ty::AliasKind::Opaque = alias_kind {
self.genv
.generics_of(def_id)?
.collect_all_refine_params(self.genv, |param| {
.refinement_generics_of(def_id)?
.collect_all_params(self.genv, |param| {
rty::Expr::hole(rty::HoleKind::Expr(param.sort.clone()))
})
} else {
Expand Down
5 changes: 3 additions & 2 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ impl<'a, 'tcx, M: Mode> Checker<'a, 'tcx, M> {
let refine_params = if let Some(refine_params) = refine_params {
refine_params
} else {
generics
.collect_all_refine_params(genv, |param| rcx.define_vars(&param.sort))
genv.refinement_generics_of(def_id)
.with_span(span)?
.collect_all_params(genv, |param| rcx.define_vars(&param.sort))
.with_span(span)?
};

Expand Down
6 changes: 2 additions & 4 deletions crates/flux-refineck/src/constraint_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,8 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
) -> Result<Vec<Expr>, CheckerErrKind> {
if let Some(callee_id) = callee_def_id {
Ok(genv
.generics_of(callee_id)?
.collect_all_refine_params(genv, |param| {
self.fresh_infer_var(&param.sort, param.mode)
})?)
.refinement_generics_of(callee_id)?
.collect_all_params(genv, |param| self.fresh_infer_var(&param.sort, param.mode))?)
} else {
Ok(vec![])
}
Expand Down

0 comments on commit db49182

Please sign in to comment.