Skip to content

Commit 618d2d6

Browse files
Properly drain pending obligations for coroutines
1 parent 7439a16 commit 618d2d6

File tree

19 files changed

+233
-56
lines changed

19 files changed

+233
-56
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+3
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
163163
// Resume type defaults to `()` if the coroutine has no argument.
164164
let resume_ty = liberated_sig.inputs().get(0).copied().unwrap_or(tcx.types.unit);
165165

166+
// TODO: In the new solver, we can just instantiate this eagerly
167+
// with the witness. This will ensure that goals that don't need
168+
// to stall on interior types will get processed eagerly.
166169
let interior = self.next_ty_var(expr_span);
167170
self.deferred_coroutine_interiors.borrow_mut().push((expr_def_id, interior));
168171

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
659659
obligations.extend(ok.obligations);
660660
}
661661

662-
// FIXME: Use a real visitor for unstalled obligations in the new solver.
663662
if !coroutines.is_empty() {
664-
obligations
665-
.extend(self.fulfillment_cx.borrow_mut().drain_unstalled_obligations(&self.infcx));
663+
obligations.extend(
664+
self.fulfillment_cx
665+
.borrow_mut()
666+
.drain_stalled_obligations_for_coroutines(&self.infcx),
667+
);
666668
}
667669

668670
self.typeck_results

compiler/rustc_hir_typeck/src/typeck_root_ctxt.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<'tcx> TypeckRootCtxt<'tcx> {
8484
let hir_owner = tcx.local_def_id_to_hir_id(def_id).owner;
8585

8686
let infcx =
87-
tcx.infer_ctxt().ignoring_regions().build(TypingMode::analysis_in_body(tcx, def_id));
87+
tcx.infer_ctxt().ignoring_regions().build(TypingMode::typeck_for_body(tcx, def_id));
8888
let typeck_results = RefCell::new(ty::TypeckResults::new(hir_owner));
8989

9090
TypeckRootCtxt {

compiler/rustc_infer/src/infer/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ impl<'tcx> InferCtxt<'tcx> {
966966
pub fn can_define_opaque_ty(&self, id: impl Into<DefId>) -> bool {
967967
debug_assert!(!self.next_trait_solver());
968968
match self.typing_mode() {
969-
TypingMode::Analysis { defining_opaque_types } => {
969+
TypingMode::Analysis { defining_opaque_types, stalled_generators: _ } => {
970970
id.into().as_local().is_some_and(|def_id| defining_opaque_types.contains(&def_id))
971971
}
972972
// FIXME(#132279): This function is quite weird in post-analysis
@@ -1260,7 +1260,7 @@ impl<'tcx> InferCtxt<'tcx> {
12601260
// to handle them without proper canonicalization. This means we may cause cycle
12611261
// errors and fail to reveal opaques while inside of bodies. We should rename this
12621262
// function and require explicit comments on all use-sites in the future.
1263-
ty::TypingMode::Analysis { defining_opaque_types: _ } => {
1263+
ty::TypingMode::Analysis { defining_opaque_types: _, stalled_generators: _ } => {
12641264
TypingMode::non_body_analysis()
12651265
}
12661266
mode @ (ty::TypingMode::Coherence

compiler/rustc_infer/src/traits/engine.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub trait TraitEngine<'tcx, E: 'tcx>: 'tcx {
9494
/// Among all pending obligations, collect those are stalled on a inference variable which has
9595
/// changed since the last call to `select_where_possible`. Those obligations are marked as
9696
/// successful and returned.
97-
fn drain_unstalled_obligations(
97+
fn drain_stalled_obligations_for_coroutines(
9898
&mut self,
9999
infcx: &InferCtxt<'tcx>,
100100
) -> PredicateObligations<'tcx>;

compiler/rustc_middle/src/query/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,15 @@ rustc_queries! {
379379
}
380380
}
381381

382+
query stalled_generators_within(
383+
key: LocalDefId
384+
) -> &'tcx ty::List<LocalDefId> {
385+
desc {
386+
|tcx| "computing the opaque types defined by `{}`",
387+
tcx.def_path_str(key.to_def_id())
388+
}
389+
}
390+
382391
/// Returns the explicitly user-written *bounds* on the associated or opaque type given by `DefId`
383392
/// that must be proven true at definition site (and which can be assumed at usage sites).
384393
///

compiler/rustc_middle/src/query/plumbing.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,11 @@ macro_rules! define_callbacks {
366366

367367
pub type Storage<'tcx> = <$($K)* as keys::Key>::Cache<Erase<$V>>;
368368

369-
// Ensure that keys grow no larger than 80 bytes by accident.
369+
// Ensure that keys grow no larger than 96 bytes by accident.
370370
// Increase this limit if necessary, but do try to keep the size low if possible
371371
#[cfg(target_pointer_width = "64")]
372372
const _: () = {
373-
if size_of::<Key<'static>>() > 88 {
373+
if size_of::<Key<'static>>() > 96 {
374374
panic!("{}", concat!(
375375
"the query `",
376376
stringify!($name),

compiler/rustc_middle/src/ty/context.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
106106
) -> Self::PredefinedOpaques {
107107
self.mk_predefined_opaques_in_body(data)
108108
}
109-
type DefiningOpaqueTypes = &'tcx ty::List<LocalDefId>;
109+
type LocalDefIds = &'tcx ty::List<LocalDefId>;
110110
type CanonicalVars = CanonicalVarInfos<'tcx>;
111111
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
112112
self.mk_canonical_var_infos(infos)
@@ -663,9 +663,13 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
663663
self.anonymize_bound_vars(binder)
664664
}
665665

666-
fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::DefiningOpaqueTypes {
666+
fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::LocalDefIds {
667667
self.opaque_types_defined_by(defining_anchor)
668668
}
669+
670+
fn stalled_generators_within(self, defining_anchor: Self::LocalDefId) -> Self::LocalDefIds {
671+
self.stalled_generators_within(defining_anchor)
672+
}
669673
}
670674

671675
macro_rules! bidirectional_lang_item_map {

compiler/rustc_next_trait_solver/src/solve/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,10 @@ where
329329
TypingMode::Coherence | TypingMode::PostAnalysis => false,
330330
// During analysis, opaques are rigid unless they may be defined by
331331
// the current body.
332-
TypingMode::Analysis { defining_opaque_types: non_rigid_opaques }
332+
TypingMode::Analysis {
333+
defining_opaque_types: non_rigid_opaques,
334+
stalled_generators: _,
335+
}
333336
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: non_rigid_opaques } => {
334337
!def_id.as_local().is_some_and(|def_id| non_rigid_opaques.contains(&def_id))
335338
}

compiler/rustc_next_trait_solver/src/solve/normalizes_to/opaque_types.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ where
3333
);
3434
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
3535
}
36-
TypingMode::Analysis { defining_opaque_types } => {
36+
TypingMode::Analysis { defining_opaque_types, stalled_generators: _ } => {
3737
let Some(def_id) = opaque_ty
3838
.def_id
3939
.as_local()

compiler/rustc_next_trait_solver/src/solve/trait_goals.rs

+14
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,20 @@ where
189189
debug_assert!(ecx.opaque_type_is_rigid(opaque_ty.def_id));
190190
}
191191

192+
if let ty::CoroutineWitness(def_id, _) = goal.predicate.self_ty().kind() {
193+
match ecx.typing_mode() {
194+
TypingMode::Analysis { stalled_generators, defining_opaque_types: _ } => {
195+
if def_id.as_local().is_some_and(|def_id| stalled_generators.contains(&def_id))
196+
{
197+
return ecx.forced_ambiguity(MaybeCause::Ambiguity);
198+
}
199+
}
200+
TypingMode::Coherence
201+
| TypingMode::PostAnalysis
202+
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ } => {}
203+
}
204+
}
205+
192206
ecx.probe_and_evaluate_goal_for_constituent_tys(
193207
CandidateSource::BuiltinImpl(BuiltinImplSource::Misc),
194208
goal,

compiler/rustc_trait_selection/src/solve/fulfill.rs

+86-13
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
use std::marker::PhantomData;
22
use std::mem;
3+
use std::ops::ControlFlow;
34

45
use rustc_data_structures::thinvec::ExtractIf;
6+
use rustc_hir::def_id::LocalDefId;
57
use rustc_infer::infer::InferCtxt;
68
use rustc_infer::traits::query::NoSolution;
79
use rustc_infer::traits::{
810
FromSolverError, PredicateObligation, PredicateObligations, TraitEngine,
911
};
12+
use rustc_middle::ty::{
13+
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitor, TypingMode,
14+
};
1015
use rustc_next_trait_solver::solve::{GenerateProofTree, HasChanged, SolverDelegateEvalExt as _};
16+
use rustc_span::Span;
1117
use tracing::instrument;
1218

1319
use self::derive_errors::*;
1420
use super::Certainty;
1521
use super::delegate::SolverDelegate;
22+
use super::inspect::{self, ProofTreeInferCtxtExt};
1623
use crate::traits::{FulfillmentError, ScrubbedTraitError};
1724

1825
mod derive_errors;
@@ -39,7 +46,7 @@ pub struct FulfillmentCtxt<'tcx, E: 'tcx> {
3946
_errors: PhantomData<E>,
4047
}
4148

42-
#[derive(Default)]
49+
#[derive(Default, Debug)]
4350
struct ObligationStorage<'tcx> {
4451
/// Obligations which resulted in an overflow in fulfillment itself.
4552
///
@@ -55,20 +62,23 @@ impl<'tcx> ObligationStorage<'tcx> {
5562
self.pending.push(obligation);
5663
}
5764

65+
fn has_pending_obligations(&self) -> bool {
66+
!self.pending.is_empty() || !self.overflowed.is_empty()
67+
}
68+
5869
fn clone_pending(&self) -> PredicateObligations<'tcx> {
5970
let mut obligations = self.pending.clone();
6071
obligations.extend(self.overflowed.iter().cloned());
6172
obligations
6273
}
6374

64-
fn take_pending(&mut self) -> PredicateObligations<'tcx> {
65-
let mut obligations = mem::take(&mut self.pending);
66-
obligations.append(&mut self.overflowed);
67-
obligations
68-
}
69-
70-
fn unstalled_for_select(&mut self) -> impl Iterator<Item = PredicateObligation<'tcx>> + 'tcx {
71-
mem::take(&mut self.pending).into_iter()
75+
fn drain_pending(
76+
&mut self,
77+
cond: impl Fn(&PredicateObligation<'tcx>) -> bool,
78+
) -> PredicateObligations<'tcx> {
79+
let (unstalled, pending) = mem::take(&mut self.pending).into_iter().partition(cond);
80+
self.pending = pending;
81+
unstalled
7282
}
7383

7484
fn on_fulfillment_overflow(&mut self, infcx: &InferCtxt<'tcx>) {
@@ -160,7 +170,7 @@ where
160170
}
161171

162172
let mut has_changed = false;
163-
for obligation in self.obligations.unstalled_for_select() {
173+
for obligation in self.obligations.drain_pending(|_| true) {
164174
let goal = obligation.as_goal();
165175
let result = <&SolverDelegate<'tcx>>::from(infcx)
166176
.evaluate_root_goal(goal, GenerateProofTree::No, obligation.cause.span)
@@ -196,15 +206,78 @@ where
196206
}
197207

198208
fn has_pending_obligations(&self) -> bool {
199-
!self.obligations.pending.is_empty() || !self.obligations.overflowed.is_empty()
209+
self.obligations.has_pending_obligations()
200210
}
201211

202212
fn pending_obligations(&self) -> PredicateObligations<'tcx> {
203213
self.obligations.clone_pending()
204214
}
205215

206-
fn drain_unstalled_obligations(&mut self, _: &InferCtxt<'tcx>) -> PredicateObligations<'tcx> {
207-
self.obligations.take_pending()
216+
fn drain_stalled_obligations_for_coroutines(
217+
&mut self,
218+
infcx: &InferCtxt<'tcx>,
219+
) -> PredicateObligations<'tcx> {
220+
self.obligations.drain_pending(|obl| {
221+
let stalled_generators = match infcx.typing_mode() {
222+
TypingMode::Analysis { defining_opaque_types: _, stalled_generators } => {
223+
stalled_generators
224+
}
225+
TypingMode::Coherence
226+
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
227+
| TypingMode::PostAnalysis => return false,
228+
};
229+
230+
if stalled_generators.is_empty() {
231+
return false;
232+
}
233+
234+
infcx.probe(|_| {
235+
infcx
236+
.visit_proof_tree(
237+
obl.as_goal(),
238+
&mut StalledOnCoroutines { stalled_generators, span: obl.cause.span },
239+
)
240+
.is_break()
241+
})
242+
})
243+
}
244+
}
245+
246+
struct StalledOnCoroutines<'tcx> {
247+
stalled_generators: &'tcx ty::List<LocalDefId>,
248+
span: Span,
249+
// TODO: Cache
250+
}
251+
252+
impl<'tcx> inspect::ProofTreeVisitor<'tcx> for StalledOnCoroutines<'tcx> {
253+
type Result = ControlFlow<()>;
254+
255+
fn span(&self) -> rustc_span::Span {
256+
self.span
257+
}
258+
259+
fn visit_goal(&mut self, inspect_goal: &super::inspect::InspectGoal<'_, 'tcx>) -> Self::Result {
260+
inspect_goal.goal().predicate.visit_with(self)?;
261+
262+
if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
263+
candidate.visit_nested_no_probe(self)
264+
} else {
265+
ControlFlow::Continue(())
266+
}
267+
}
268+
}
269+
270+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for StalledOnCoroutines<'tcx> {
271+
type Result = ControlFlow<()>;
272+
273+
fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
274+
if let ty::CoroutineWitness(def_id, _) = *ty.kind()
275+
&& def_id.as_local().is_some_and(|def_id| self.stalled_generators.contains(&def_id))
276+
{
277+
return ControlFlow::Break(());
278+
}
279+
280+
ty.super_visit_with(self)
208281
}
209282
}
210283

compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,16 @@ pub(super) fn fulfillment_error_for_stalled<'tcx>(
109109
false,
110110
),
111111
Ok((_, Certainty::Yes)) => {
112-
bug!("did not expect successful goal when collecting ambiguity errors")
112+
bug!(
113+
"did not expect successful goal when collecting ambiguity errors for `{:?}`",
114+
infcx.resolve_vars_if_possible(root_obligation.predicate),
115+
)
113116
}
114117
Err(_) => {
115-
bug!("did not expect selection error when collecting ambiguity errors")
118+
bug!(
119+
"did not expect selection error when collecting ambiguity errors for `{:?}`",
120+
infcx.resolve_vars_if_possible(root_obligation.predicate),
121+
)
116122
}
117123
}
118124
});

compiler/rustc_trait_selection/src/solve/normalize.rs

+19-15
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ where
7676
let value = value.try_fold_with(&mut folder)?;
7777
let goals = folder
7878
.fulfill_cx
79-
.drain_unstalled_obligations(at.infcx)
79+
.drain_stalled_obligations_for_coroutines(at.infcx)
8080
.into_iter()
8181
.map(|obl| obl.as_goal())
8282
.collect();
@@ -130,7 +130,7 @@ where
130130
);
131131

132132
self.fulfill_cx.register_predicate_obligation(infcx, obligation);
133-
let errors = self.fulfill_cx.select_all_or_error(infcx);
133+
let errors = self.fulfill_cx.select_where_possible(infcx);
134134
if !errors.is_empty() {
135135
return Err(errors);
136136
}
@@ -171,7 +171,7 @@ where
171171

172172
let result = if infcx.predicate_may_hold(&obligation) {
173173
self.fulfill_cx.register_predicate_obligation(infcx, obligation);
174-
let errors = self.fulfill_cx.select_all_or_error(infcx);
174+
let errors = self.fulfill_cx.select_where_possible(infcx);
175175
if !errors.is_empty() {
176176
return Err(errors);
177177
}
@@ -285,20 +285,24 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for DeeplyNormalizeForDiagnosticsFolder<'_,
285285
}
286286

287287
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
288-
deeply_normalize_with_skipped_universes(
289-
self.at,
290-
ty,
291-
vec![None; ty.outer_exclusive_binder().as_usize()],
292-
)
293-
.unwrap_or_else(|_: Vec<ScrubbedTraitError<'tcx>>| ty.super_fold_with(self))
288+
match deeply_normalize_with_skipped_universes_and_ambiguous_goals::<
289+
_,
290+
ScrubbedTraitError<'tcx>,
291+
>(self.at, ty, vec![None; ty.outer_exclusive_binder().as_usize()])
292+
{
293+
Ok((value, _)) => value,
294+
Err(_) => ty.super_fold_with(self),
295+
}
294296
}
295297

296298
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
297-
deeply_normalize_with_skipped_universes(
298-
self.at,
299-
ct,
300-
vec![None; ct.outer_exclusive_binder().as_usize()],
301-
)
302-
.unwrap_or_else(|_: Vec<ScrubbedTraitError<'tcx>>| ct.super_fold_with(self))
299+
match deeply_normalize_with_skipped_universes_and_ambiguous_goals::<
300+
_,
301+
ScrubbedTraitError<'tcx>,
302+
>(self.at, ct, vec![None; ct.outer_exclusive_binder().as_usize()])
303+
{
304+
Ok((value, _)) => value,
305+
Err(_) => ct.super_fold_with(self),
306+
}
303307
}
304308
}

0 commit comments

Comments
 (0)