1
+ use std:: iter;
2
+
1
3
use either:: Either ;
2
4
use rustc_data_structures:: captures:: Captures ;
3
5
use rustc_data_structures:: fx:: FxIndexSet ;
@@ -10,27 +12,26 @@ use rustc_hir::intravisit::{walk_block, walk_expr, Visitor};
10
12
use rustc_hir:: { AsyncGeneratorKind , GeneratorKind , LangItem } ;
11
13
use rustc_infer:: infer:: TyCtxtInferExt ;
12
14
use rustc_infer:: traits:: ObligationCause ;
15
+ use rustc_middle:: hir:: nested_filter:: OnlyBodies ;
13
16
use rustc_middle:: mir:: tcx:: PlaceTy ;
14
17
use rustc_middle:: mir:: {
15
18
self , AggregateKind , BindingForm , BorrowKind , ClearCrossCrate , ConstraintCategory ,
16
19
FakeReadCause , LocalDecl , LocalInfo , LocalKind , Location , Operand , Place , PlaceRef ,
17
20
ProjectionElem , Rvalue , Statement , StatementKind , Terminator , TerminatorKind , VarBindingForm ,
18
21
} ;
19
- use rustc_middle:: ty:: { self , suggest_constraining_type_params, PredicateKind , Ty } ;
22
+ use rustc_middle:: ty:: { self , suggest_constraining_type_params, PredicateKind , Ty , TypeckResults } ;
20
23
use rustc_middle:: util:: CallKind ;
21
24
use rustc_mir_dataflow:: move_paths:: { InitKind , MoveOutIndex , MovePathIndex } ;
22
25
use rustc_span:: def_id:: LocalDefId ;
23
26
use rustc_span:: hygiene:: DesugaringKind ;
24
- use rustc_span:: symbol:: { kw, sym} ;
27
+ use rustc_span:: symbol:: { kw, sym, Ident } ;
25
28
use rustc_span:: { BytePos , Span , Symbol } ;
26
29
use rustc_trait_selection:: infer:: InferCtxtExt ;
27
30
use rustc_trait_selection:: traits:: ObligationCtxt ;
28
31
29
32
use crate :: borrow_set:: TwoPhaseActivation ;
30
33
use crate :: borrowck_errors;
31
-
32
34
use crate :: diagnostics:: conflict_errors:: StorageDeadOrDrop :: LocalStorageDead ;
33
- use crate :: diagnostics:: mutability_errors:: mut_borrow_of_mutable_ref;
34
35
use crate :: diagnostics:: { find_all_local_uses, CapturedMessageOpt } ;
35
36
use crate :: {
36
37
borrow_set:: BorrowData , diagnostics:: Instance , prefixes:: IsPrefixOf ,
@@ -959,7 +960,8 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
959
960
& msg_borrow,
960
961
None ,
961
962
) ;
962
- self . suggest_binding_for_closure_capture_self (
963
+ self . suggest_binding_for_closure_capture_self ( & mut err, & issued_spans) ;
964
+ self . suggest_using_closure_argument_instead_of_capture (
963
965
& mut err,
964
966
issued_borrow. borrowed_place ,
965
967
& issued_spans,
@@ -982,6 +984,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
982
984
place,
983
985
issued_borrow. borrowed_place ,
984
986
) ;
987
+ self . suggest_using_closure_argument_instead_of_capture (
988
+ & mut err,
989
+ issued_borrow. borrowed_place ,
990
+ & issued_spans,
991
+ ) ;
985
992
err
986
993
}
987
994
@@ -1268,22 +1275,161 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
1268
1275
}
1269
1276
}
1270
1277
1271
- fn suggest_binding_for_closure_capture_self (
1278
+ /// Suggest using closure argument instead of capture.
1279
+ ///
1280
+ /// For example:
1281
+ /// ```ignore (illustrative)
1282
+ /// struct S;
1283
+ ///
1284
+ /// impl S {
1285
+ /// fn call(&mut self, f: impl Fn(&mut Self)) { /* ... */ }
1286
+ /// fn x(&self) {}
1287
+ /// }
1288
+ ///
1289
+ /// let mut v = S;
1290
+ /// v.call(|this: &mut S| v.x());
1291
+ /// // ^\ ^-- help: try using the closure argument: `this`
1292
+ /// // *-- error: cannot borrow `v` as mutable because it is also borrowed as immutable
1293
+ /// ```
1294
+ fn suggest_using_closure_argument_instead_of_capture (
1272
1295
& self ,
1273
1296
err : & mut Diagnostic ,
1274
1297
borrowed_place : Place < ' tcx > ,
1275
1298
issued_spans : & UseSpans < ' tcx > ,
1276
1299
) {
1277
- let UseSpans :: ClosureUse { capture_kind_span, .. } = issued_spans else { return } ;
1278
- let hir = self . infcx . tcx . hir ( ) ;
1300
+ let & UseSpans :: ClosureUse { capture_kind_span, .. } = issued_spans else { return } ;
1301
+ let tcx = self . infcx . tcx ;
1302
+ let hir = tcx. hir ( ) ;
1279
1303
1280
- // check whether the borrowed place is capturing `self` by mut reference
1304
+ // Get the type of the local that we are trying to borrow
1281
1305
let local = borrowed_place. local ;
1282
- let Some ( _) = self
1283
- . body
1284
- . local_decls
1285
- . get ( local)
1286
- . map ( |l| mut_borrow_of_mutable_ref ( l, self . local_names [ local] ) ) else { return } ;
1306
+ let local_ty = self . body . local_decls [ local] . ty ;
1307
+
1308
+ // Get the body the error happens in
1309
+ let Some ( body_id) = hir. get ( self . mir_hir_id ( ) ) . body_id ( ) else { return } ;
1310
+
1311
+ let body_expr = hir. body ( body_id) . value ;
1312
+
1313
+ struct ClosureFinder < ' hir > {
1314
+ hir : rustc_middle:: hir:: map:: Map < ' hir > ,
1315
+ borrow_span : Span ,
1316
+ res : Option < ( & ' hir hir:: Expr < ' hir > , & ' hir hir:: Closure < ' hir > ) > ,
1317
+ /// The path expression with the `borrow_span` span
1318
+ error_path : Option < ( & ' hir hir:: Expr < ' hir > , & ' hir hir:: QPath < ' hir > ) > ,
1319
+ }
1320
+ impl < ' hir > Visitor < ' hir > for ClosureFinder < ' hir > {
1321
+ type NestedFilter = OnlyBodies ;
1322
+
1323
+ fn nested_visit_map ( & mut self ) -> Self :: Map {
1324
+ self . hir
1325
+ }
1326
+
1327
+ fn visit_expr ( & mut self , ex : & ' hir hir:: Expr < ' hir > ) {
1328
+ if let hir:: ExprKind :: Path ( qpath) = & ex. kind
1329
+ && ex. span == self . borrow_span
1330
+ {
1331
+ self . error_path = Some ( ( ex, qpath) ) ;
1332
+ }
1333
+
1334
+ if let hir:: ExprKind :: Closure ( closure) = ex. kind
1335
+ && ex. span . contains ( self . borrow_span )
1336
+ // To support cases like `|| { v.call(|this| v.get()) }`
1337
+ // FIXME: actually support such cases (need to figure out how to move from the capture place to original local)
1338
+ && self . res . as_ref ( ) . map_or ( true , |( prev_res, _) | prev_res. span . contains ( ex. span ) )
1339
+ {
1340
+ self . res = Some ( ( ex, closure) ) ;
1341
+ }
1342
+
1343
+ hir:: intravisit:: walk_expr ( self , ex) ;
1344
+ }
1345
+ }
1346
+
1347
+ // Find the closure that most tightly wraps `capture_kind_span`
1348
+ let mut finder =
1349
+ ClosureFinder { hir, borrow_span : capture_kind_span, res : None , error_path : None } ;
1350
+ finder. visit_expr ( body_expr) ;
1351
+ let Some ( ( closure_expr, closure) ) = finder. res else { return } ;
1352
+
1353
+ let typeck_results: & TypeckResults < ' _ > =
1354
+ tcx. typeck_opt_const_arg ( self . body . source . with_opt_param ( ) . as_local ( ) . unwrap ( ) ) ;
1355
+
1356
+ // Check that the parent of the closure is a method call,
1357
+ // with receiver matching with local's type (modulo refs)
1358
+ let parent = hir. parent_id ( closure_expr. hir_id ) ;
1359
+ if let hir:: Node :: Expr ( parent) = hir. get ( parent) {
1360
+ if let hir:: ExprKind :: MethodCall ( _, recv, ..) = parent. kind {
1361
+ let recv_ty = typeck_results. expr_ty ( recv) ;
1362
+
1363
+ if recv_ty. peel_refs ( ) != local_ty {
1364
+ return ;
1365
+ }
1366
+ }
1367
+ }
1368
+
1369
+ // Get closure's arguments
1370
+ let ty:: Closure ( _, substs) = typeck_results. expr_ty ( closure_expr) . kind ( ) else { unreachable ! ( ) } ;
1371
+ let sig = substs. as_closure ( ) . sig ( ) ;
1372
+ let tupled_params =
1373
+ tcx. erase_late_bound_regions ( sig. inputs ( ) . iter ( ) . next ( ) . unwrap ( ) . map_bound ( |& b| b) ) ;
1374
+ let ty:: Tuple ( params) = tupled_params. kind ( ) else { return } ;
1375
+
1376
+ // Find the first argument with a matching type, get its name
1377
+ let Some ( ( _, this_name) ) = params
1378
+ . iter ( )
1379
+ . zip ( hir. body_param_names ( closure. body ) )
1380
+ . find ( |( param_ty, name) |{
1381
+ // FIXME: also support deref for stuff like `Rc` arguments
1382
+ param_ty. peel_refs ( ) == local_ty && name != & Ident :: empty ( )
1383
+ } )
1384
+ else { return } ;
1385
+
1386
+ let spans;
1387
+ if let Some ( ( _path_expr, qpath) ) = finder. error_path
1388
+ && let hir:: QPath :: Resolved ( _, path) = qpath
1389
+ && let hir:: def:: Res :: Local ( local_id) = path. res
1390
+ {
1391
+ // Find all references to the problematic variable in this closure body
1392
+
1393
+ struct VariableUseFinder {
1394
+ local_id : hir:: HirId ,
1395
+ spans : Vec < Span > ,
1396
+ }
1397
+ impl < ' hir > Visitor < ' hir > for VariableUseFinder {
1398
+ fn visit_expr ( & mut self , ex : & ' hir hir:: Expr < ' hir > ) {
1399
+ if let hir:: ExprKind :: Path ( qpath) = & ex. kind
1400
+ && let hir:: QPath :: Resolved ( _, path) = qpath
1401
+ && let hir:: def:: Res :: Local ( local_id) = path. res
1402
+ && local_id == self . local_id
1403
+ {
1404
+ self . spans . push ( ex. span ) ;
1405
+ }
1406
+
1407
+ hir:: intravisit:: walk_expr ( self , ex) ;
1408
+ }
1409
+ }
1410
+
1411
+ let mut finder = VariableUseFinder { local_id, spans : Vec :: new ( ) } ;
1412
+ finder. visit_expr ( hir. body ( closure. body ) . value ) ;
1413
+
1414
+ spans = finder. spans ;
1415
+ } else {
1416
+ spans = vec ! [ capture_kind_span] ;
1417
+ }
1418
+
1419
+ err. multipart_suggestion (
1420
+ "try using the closure argument" ,
1421
+ iter:: zip ( spans, iter:: repeat ( this_name. to_string ( ) ) ) . collect ( ) ,
1422
+ Applicability :: MaybeIncorrect ,
1423
+ ) ;
1424
+ }
1425
+
1426
+ fn suggest_binding_for_closure_capture_self (
1427
+ & self ,
1428
+ err : & mut Diagnostic ,
1429
+ issued_spans : & UseSpans < ' tcx > ,
1430
+ ) {
1431
+ let UseSpans :: ClosureUse { capture_kind_span, .. } = issued_spans else { return } ;
1432
+ let hir = self . infcx . tcx . hir ( ) ;
1287
1433
1288
1434
struct ExpressionFinder < ' hir > {
1289
1435
capture_span : Span ,
0 commit comments