diff --git a/src/Lean/Server/CodeActions/Basic.lean b/src/Lean/Server/CodeActions/Basic.lean index 43963cac6dc3..2ca55171fd17 100644 --- a/src/Lean/Server/CodeActions/Basic.lean +++ b/src/Lean/Server/CodeActions/Basic.lean @@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array let caps ← names.mapM evalCodeActionProvider return (← builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps caps.flatMapM fun (providerName, cap) => do + RequestM.checkCancelled let cas ← cap params snap cas.mapIdxM fun i lca => do if lca.lazy?.isNone then return lca.eager diff --git a/src/Lean/Server/Completion.lean b/src/Lean/Server/Completion.lean index 9ffb290958a4..61446b366b34 100644 --- a/src/Lean/Server/Completion.lean +++ b/src/Lean/Server/Completion.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga -/ prelude import Lean.Server.Completion.CompletionCollectors +import Lean.Server.RequestCancellation import Std.Data.HashMap namespace Lean.Server.Completion @@ -61,11 +62,12 @@ partial def find? (cmdStx : Syntax) (infoTree : InfoTree) (caps : ClientCapabilities) - : IO CompletionList := do + : CancellableM CompletionList := do let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree let mut allCompletions := #[] for partition in prioritizedPartitions do for (i, completionInfoPos) in partition do + CancellableM.checkCancelled let completions : Array ScoredCompletionItem ← match i.info with | .id stx id danglingDot lctx .. => diff --git a/src/Lean/Server/Completion/CompletionCollectors.lean b/src/Lean/Server/Completion/CompletionCollectors.lean index 97ff51ce9ddf..4498432614a1 100644 --- a/src/Lean/Server/Completion/CompletionCollectors.lean +++ b/src/Lean/Server/Completion/CompletionCollectors.lean @@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching import Lean.Elab.Tactic.Doc import Lean.Server.Completion.CompletionResolution import Lean.Server.Completion.EligibleHeaderDecls +import Lean.Server.RequestCancellation namespace Lean.Server.Completion open Elab @@ -36,7 +37,7 @@ section Infrastructure Monad used for completion computation that allows modifying a completion `State` and reading `CompletionParams`. -/ - private abbrev M := ReaderT Context $ StateRefT State MetaM + private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM /-- Adds a new completion item to the state in `M`. -/ private def addItem @@ -114,10 +115,13 @@ section Infrastructure (ctx : ContextInfo) (lctx : LocalContext) (x : M Unit) - : IO (Array ScoredCompletionItem) := - ctx.runMetaM lctx do - let (_, s) ← x.run ⟨params, completionInfoPos⟩ |>.run {} - return s.items + : CancellableM (Array ScoredCompletionItem) := do + let tk ← read + let r ← ctx.runMetaM lctx do + x.run ⟨params, completionInfoPos⟩ |>.run {} |>.run tk + match r with + | .error _ => throw .requestCancelled + | .ok (_, s) => return s.items end Infrastructure @@ -161,6 +165,16 @@ section Utils return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat) return none + private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m] + [MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m] + (f : Name → ConstantInfo → m PUnit) : m PUnit := do + let _ ← StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do + modify (· + 1) + if (← get) >= 10000 then + RequestCancellation.check + set <| 0 + f decl ci + end Utils section IdCompletionUtils @@ -349,7 +363,7 @@ private def idCompletionCore addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score -- search for matches in the environment let env ← getEnv - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let bestMatch? ← (·.2) <$> StateT.run (s := none) do let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do let some (label, score) ← matchDecl? ns id danglingDot declName @@ -380,6 +394,7 @@ private def idCompletionCore matchUsingNamespace Name.anonymous if let some (bestLabel, bestScore) := bestMatch? then addUnresolvedCompletionItem bestLabel (.const declName) (← getCompletionKindForDecl c) bestScore + RequestCancellation.check let matchAlias (ns : Name) (alias : Name) : Option Float := -- Recall that aliases may not be atomic and include the namespace where they were created. if ns.isPrefixOf alias then @@ -434,7 +449,7 @@ def idCompletion (id : Name) (hoverInfo : HoverInfo) (danglingDot : Bool) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do idCompletionCore ctx stx id hoverInfo danglingDot @@ -443,7 +458,7 @@ def dotCompletion (completionInfoPos : Nat) (ctx : ContextInfo) (info : TermInfo) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx info.lctx do let nameSet ← try getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr)) @@ -452,7 +467,7 @@ def dotCompletion if nameSet.isEmpty then return - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let unnormedTypeName := declName.getPrefix if ! nameSet.contains unnormedTypeName then return @@ -471,7 +486,7 @@ def dotIdCompletion (lctx : LocalContext) (id : Name) (expectedType? : Option Expr) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do let some expectedType := expectedType? | return () @@ -485,7 +500,7 @@ def dotIdCompletion catch _ => pure RBTree.empty - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let unnormedTypeName := declName.getPrefix if ! nameSet.contains unnormedTypeName then return @@ -513,7 +528,7 @@ def fieldIdCompletion (lctx : LocalContext) (id : Option Name) (structName : Name) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do let idStr := id.map (·.toString) |>.getD "" let fieldNames := getStructureFieldsFlattened (← getEnv) structName (includeSubobjectFields := false) diff --git a/src/Lean/Server/FileWorker.lean b/src/Lean/Server/FileWorker.lean index bdc625661b98..ac303281af55 100644 --- a/src/Lean/Server/FileWorker.lean +++ b/src/Lean/Server/FileWorker.lean @@ -543,14 +543,14 @@ section NotificationHandling let newDocText := foldDocumentChanges changes oldDoc.meta.text updateDocument ⟨docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode⟩ for (_, r) in st.pendingRequests do - r.cancelTk.cancel .edit + r.cancelTk.cancelByEdit def handleCancelRequest (p : CancelParams) : WorkerM Unit := do let st ← get let some r := st.pendingRequests.find? p.id | return - r.cancelTk.cancel .cancelRequest + r.cancelTk.cancelByCancelRequest set <| { st with pendingRequests := st.pendingRequests.erase p.id } /-- @@ -741,6 +741,12 @@ section MessageHandling pure <| Task.pure <| .ok () | Except.ok t => (IO.mapTask · t) fun | Except.ok r => do + if ← cancelTk.wasCancelledByCancelRequest then + -- Try not to emit a partial response if this request was cancelled. + -- Clients usually discard responses for requests that they cancelled anyways, + -- but it's still good to send less over the wire in this case. + emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id + return emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response) | Except.error e => emitResponse ctx (isComplete := false) <| e.toLspResponseError id diff --git a/src/Lean/Server/FileWorker/InlayHints.lean b/src/Lean/Server/FileWorker/InlayHints.lean index 339b468e3e91..bdc3216b3962 100644 --- a/src/Lean/Server/FileWorker/InlayHints.lean +++ b/src/Lean/Server/FileWorker/InlayHints.lean @@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) : | some lastEditTimestamp => let timeSinceLastEditMs := timestamp - lastEditTimestamp inlayHintEditDelayMs - timeSinceLastEditMs - let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask) + let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask) let finishedRange? : Option String.Range := do return ⟨⟨0⟩, ← List.max? <| snaps.map (fun s => s.endPos)⟩ let oldInlayHints := @@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) : let lspInlayHints ← inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text) let r := { response := lspInlayHints, isComplete } let s := { s with oldInlayHints := inlayHints } - RequestM.checkCanceled return (r, s) def handleInlayHintsDidChange (p : DidChangeTextDocumentParams) diff --git a/src/Lean/Server/FileWorker/RequestHandling.lean b/src/Lean/Server/FileWorker/RequestHandling.lean index d272c3aadb75..0e02345908ec 100644 --- a/src/Lean/Server/FileWorker/RequestHandling.lean +++ b/src/Lean/Server/FileWorker/RequestHandling.lean @@ -382,13 +382,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams) let t := doc.cmdSnaps.waitAll mapTask t fun (snaps, _) => do let mut stxs := snaps.map (·.stx) - return { syms := toDocumentSymbols doc.meta.text stxs #[] [] } + return { syms := ← toDocumentSymbols doc.meta.text stxs #[] [] } where toDocumentSymbols (text : FileMap) (stxs : List Syntax) (syms : Array DocumentSymbol) (stack : List NamespaceEntry) : - Array DocumentSymbol := + RequestM (Array DocumentSymbol) := do + RequestM.checkCancelled match stxs with - | [] => stack.foldl (fun syms entry => entry.finish text syms none) syms + | [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms | stx::stxs => match stx with | `(namespace $id) => let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms } @@ -411,9 +412,9 @@ where let syms := entry.finish text syms stx popStack (n - entry.name.length) syms stack popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack - | _ => Id.run do + | _ => do unless stx.isOfKind ``Lean.Parser.Command.declaration do - return toDocumentSymbols text stxs syms stack + return ← toDocumentSymbols text stxs syms stack if let some stxRange := stx.getRange? then let (name, selection) := match stx with | `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) => @@ -431,7 +432,7 @@ where range := stxRange.toLspRange text selectionRange := selRange.toLspRange text } - return toDocumentSymbols text stxs (syms.push sym) stack + return ← toDocumentSymbols text stxs (syms.push sym) stack toDocumentSymbols text stxs syms stack partial def handleFoldingRange (_ : FoldingRangeParams) @@ -450,7 +451,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams) if let (_, start)::rest := sections then addRange text FoldingRangeKind.region start text.source.endPos addRanges text rest [] - | stx::stxs => match stx with + | stx::stxs => do + RequestM.checkCancelled + match stx with | `(namespace $id) => addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs | `(section $(id)?) => diff --git a/src/Lean/Server/FileWorker/SemanticHighlighting.lean b/src/Lean/Server/FileWorker/SemanticHighlighting.lean index fc12536b39c5..d9d666b66e4b 100644 --- a/src/Lean/Server/FileWorker/SemanticHighlighting.lean +++ b/src/Lean/Server/FileWorker/SemanticHighlighting.lean @@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos) -- for the full file before sending a response. This means that the response will be incomplete, -- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the -- `FileWorker` to tell the client to re-compute the semantic tokens. - let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask) + let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask) asTask <| do return { response := ← run doc snaps, isComplete } | some endPos => let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos) mapTask t fun (snaps, _) => do - RequestM.checkCanceled return { response := ← run doc snaps, isComplete := true } where run doc snaps : RequestM SemanticTokens := do @@ -164,8 +163,11 @@ where let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens + RequestM.checkCancelled let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens + RequestM.checkCancelled let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens + RequestM.checkCancelled let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens return semanticTokens diff --git a/src/Lean/Server/RequestCancellation.lean b/src/Lean/Server/RequestCancellation.lean new file mode 100644 index 000000000000..72f68d1422dc --- /dev/null +++ b/src/Lean/Server/RequestCancellation.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Marc Huisinga +-/ +prelude +import Init.System.Promise + +namespace Lean.Server + +structure RequestCancellationToken where + cancelledByCancelRequest : IO.Ref Bool + cancelledByEdit : IO.Ref Bool + cancellationPromise : IO.Promise Unit + +namespace RequestCancellationToken + +def new : IO RequestCancellationToken := do + return { + cancelledByCancelRequest := ← IO.mkRef false + cancelledByEdit := ← IO.mkRef false + cancellationPromise := ← IO.Promise.new + } + +def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do + tk.cancelledByCancelRequest.set true + tk.cancellationPromise.resolve () + +def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do + tk.cancelledByEdit.set true + tk.cancellationPromise.resolve () + +def cancellationTask (tk : RequestCancellationToken) : Task Unit := + tk.cancellationPromise.result! + +def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := + tk.cancelledByCancelRequest.get + +def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do + tk.cancelledByEdit.get + +end RequestCancellationToken + +structure RequestCancellation where + +def RequestCancellation.requestCancelled : RequestCancellation := {} + +abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m) +abbrev CancellableM := CancellableT IO + +def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) : + m (Except RequestCancellation α) := + x tk + +def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) : + IO (Except RequestCancellation α) := + CancellableT.run tk x + +def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do + let tk ← read + if ← tk.wasCancelledByCancelRequest then + throw .requestCancelled + +def CancellableM.checkCancelled : CancellableM Unit := + CancellableT.checkCancelled + +class MonadCancellable (m : Type → Type v) where + checkCancelled : m PUnit + +instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where + checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit) + +instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where + checkCancelled := CancellableT.checkCancelled + +def RequestCancellation.check [MonadCancellable m] : m Unit := + MonadCancellable.checkCancelled diff --git a/src/Lean/Server/Requests.lean b/src/Lean/Server/Requests.lean index 81ae49d50a2d..d1f9c8461cd3 100644 --- a/src/Lean/Server/Requests.lean +++ b/src/Lean/Server/Requests.lean @@ -11,6 +11,8 @@ import Lean.Data.Json import Lean.Data.Lsp import Lean.Elab.Command +import Lean.Server.RequestCancellation + import Lean.Server.FileSource import Lean.Server.FileWorker.Utils @@ -84,47 +86,6 @@ def toLspResponseError (id : RequestID) (e : RequestError) : ResponseError Unit end RequestError -inductive RequestCancellationCause where - | cancelRequest - | edit - deriving Inhabited, BEq - -structure RequestCancellationToken where - promise : IO.Promise RequestCancellationCause - -namespace RequestCancellationToken - -def new : IO RequestCancellationToken := do - return { promise := ← IO.Promise.new } - -def cancel (tk : RequestCancellationToken) (cause : RequestCancellationCause) : IO Unit := - tk.promise.resolve cause - -def task (tk : RequestCancellationToken) : Task RequestCancellationCause := - tk.promise.result! - -def truncatedTask (tk : RequestCancellationToken) : Task Unit := - tk.task.map (sync := true) fun _ => () - -def cancelled? (tk : RequestCancellationToken) : IO (Option RequestCancellationCause) := do - let t := tk.task - if ← IO.hasFinished t then - return some t.get - else - return none - -def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := do - let some c ← tk.cancelled? - | return false - return c matches .cancelRequest - -def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do - let some c ← tk.cancelled? - | return false - return c matches .edit - -end RequestCancellationToken - def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json) : Except RequestError paramType := fromJson? params |>.mapError fun inner => @@ -158,6 +119,14 @@ instance : MonadLift (EIO Exception) RequestM where | .error e => throw <| ← RequestError.ofException e | .ok v => return v +instance : MonadLift CancellableM RequestM where + monadLift x := do + let ctx ← read + let r ← x.run ctx.cancelTk + match r with + | .error _ => throw RequestError.requestCancelled + | .ok v => return v + namespace RequestM open FileWorker open Snapshots @@ -181,7 +150,7 @@ def bindTask (t : Task α) (f : α → RequestM (RequestTask β)) : RequestM (Re let rc ← readThe RequestContext EIO.bindTask t (f · rc) -def checkCanceled : RequestM Unit := do +def checkCancelled : RequestM Unit := do let rc ← readThe RequestContext if ← rc.cancelTk.wasCancelledByCancelRequest then throw .requestCancelled