Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: request cancellation #7054

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Lean/Server/CodeActions/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array
let caps ← names.mapM evalCodeActionProvider
return (← builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps
caps.flatMapM fun (providerName, cap) => do
RequestM.checkCancelled
let cas ← cap params snap
cas.mapIdxM fun i lca => do
if lca.lazy?.isNone then return lca.eager
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Server/Completion.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga
-/
prelude
import Lean.Server.Completion.CompletionCollectors
import Lean.Server.RequestCancellation
import Std.Data.HashMap

namespace Lean.Server.Completion
Expand Down Expand Up @@ -61,11 +62,12 @@ partial def find?
(cmdStx : Syntax)
(infoTree : InfoTree)
(caps : ClientCapabilities)
: IO CompletionList := do
: CancellableM CompletionList := do
let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree
let mut allCompletions := #[]
for partition in prioritizedPartitions do
for (i, completionInfoPos) in partition do
CancellableM.checkCancelled
let completions : Array ScoredCompletionItem ←
match i.info with
| .id stx id danglingDot lctx .. =>
Expand Down
39 changes: 27 additions & 12 deletions src/Lean/Server/Completion/CompletionCollectors.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching
import Lean.Elab.Tactic.Doc
import Lean.Server.Completion.CompletionResolution
import Lean.Server.Completion.EligibleHeaderDecls
import Lean.Server.RequestCancellation

namespace Lean.Server.Completion
open Elab
Expand Down Expand Up @@ -36,7 +37,7 @@ section Infrastructure
Monad used for completion computation that allows modifying a completion `State` and reading
`CompletionParams`.
-/
private abbrev M := ReaderT Context $ StateRefT State MetaM
private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM

/-- Adds a new completion item to the state in `M`. -/
private def addItem
Expand Down Expand Up @@ -114,10 +115,13 @@ section Infrastructure
(ctx : ContextInfo)
(lctx : LocalContext)
(x : M Unit)
: IO (Array ScoredCompletionItem) :=
ctx.runMetaM lctx do
let (_, s) ← x.run ⟨params, completionInfoPos⟩ |>.run {}
return s.items
: CancellableM (Array ScoredCompletionItem) := do
let tk ← read
let r ← ctx.runMetaM lctx do
x.run ⟨params, completionInfoPos⟩ |>.run {} |>.run tk
match r with
| .error _ => throw .requestCancelled
| .ok (_, s) => return s.items

end Infrastructure

Expand Down Expand Up @@ -161,6 +165,16 @@ section Utils
return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat)
return none

private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m]
[MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m]
(f : Name → ConstantInfo → m PUnit) : m PUnit := do
let _ ← StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do
modify (· + 1)
if (← get) >= 10000 then
RequestCancellation.check
set <| 0
f decl ci

end Utils

section IdCompletionUtils
Expand Down Expand Up @@ -349,7 +363,7 @@ private def idCompletionCore
addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score
-- search for matches in the environment
let env ← getEnv
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let bestMatch? ← (·.2) <$> StateT.run (s := none) do
let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do
let some (label, score) ← matchDecl? ns id danglingDot declName
Expand Down Expand Up @@ -380,6 +394,7 @@ private def idCompletionCore
matchUsingNamespace Name.anonymous
if let some (bestLabel, bestScore) := bestMatch? then
addUnresolvedCompletionItem bestLabel (.const declName) (← getCompletionKindForDecl c) bestScore
RequestCancellation.check
let matchAlias (ns : Name) (alias : Name) : Option Float :=
-- Recall that aliases may not be atomic and include the namespace where they were created.
if ns.isPrefixOf alias then
Expand Down Expand Up @@ -434,7 +449,7 @@ def idCompletion
(id : Name)
(hoverInfo : HoverInfo)
(danglingDot : Bool)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
idCompletionCore ctx stx id hoverInfo danglingDot

Expand All @@ -443,7 +458,7 @@ def dotCompletion
(completionInfoPos : Nat)
(ctx : ContextInfo)
(info : TermInfo)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx info.lctx do
let nameSet ← try
getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr))
Expand All @@ -452,7 +467,7 @@ def dotCompletion
if nameSet.isEmpty then
return

forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
Expand All @@ -471,7 +486,7 @@ def dotIdCompletion
(lctx : LocalContext)
(id : Name)
(expectedType? : Option Expr)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let some expectedType := expectedType?
| return ()
Expand All @@ -485,7 +500,7 @@ def dotIdCompletion
catch _ =>
pure RBTree.empty

forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
Expand Down Expand Up @@ -513,7 +528,7 @@ def fieldIdCompletion
(lctx : LocalContext)
(id : Option Name)
(structName : Name)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let idStr := id.map (·.toString) |>.getD ""
let fieldNames := getStructureFieldsFlattened (← getEnv) structName (includeSubobjectFields := false)
Expand Down
10 changes: 8 additions & 2 deletions src/Lean/Server/FileWorker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,14 @@ section NotificationHandling
let newDocText := foldDocumentChanges changes oldDoc.meta.text
updateDocument ⟨docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode⟩
for (_, r) in st.pendingRequests do
r.cancelTk.cancel .edit
r.cancelTk.cancelByEdit


def handleCancelRequest (p : CancelParams) : WorkerM Unit := do
let st ← get
let some r := st.pendingRequests.find? p.id
| return
r.cancelTk.cancel .cancelRequest
r.cancelTk.cancelByCancelRequest
set <| { st with pendingRequests := st.pendingRequests.erase p.id }

/--
Expand Down Expand Up @@ -741,6 +741,12 @@ section MessageHandling
pure <| Task.pure <| .ok ()
| Except.ok t => (IO.mapTask · t) fun
| Except.ok r => do
if ← cancelTk.wasCancelledByCancelRequest then
-- Try not to emit a partial response if this request was cancelled.
-- Clients usually discard responses for requests that they cancelled anyways,
-- but it's still good to send less over the wire in this case.
emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id
return
emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response)
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
Expand Down
3 changes: 1 addition & 2 deletions src/Lean/Server/FileWorker/InlayHints.lean
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
| some lastEditTimestamp =>
let timeSinceLastEditMs := timestamp - lastEditTimestamp
inlayHintEditDelayMs - timeSinceLastEditMs
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask)
let finishedRange? : Option String.Range := do
return ⟨⟨0⟩, ← List.max? <| snaps.map (fun s => s.endPos)⟩
let oldInlayHints :=
Expand All @@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
let lspInlayHints ← inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text)
let r := { response := lspInlayHints, isComplete }
let s := { s with oldInlayHints := inlayHints }
RequestM.checkCanceled
return (r, s)

def handleInlayHintsDidChange (p : DidChangeTextDocumentParams)
Expand Down
17 changes: 10 additions & 7 deletions src/Lean/Server/FileWorker/RequestHandling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
let t := doc.cmdSnaps.waitAll
mapTask t fun (snaps, _) => do
let mut stxs := snaps.map (·.stx)
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
where
toDocumentSymbols (text : FileMap) (stxs : List Syntax)
(syms : Array DocumentSymbol) (stack : List NamespaceEntry) :
Array DocumentSymbol :=
RequestM (Array DocumentSymbol) := do
RequestM.checkCancelled
match stxs with
| [] => stack.foldl (fun syms entry => entry.finish text syms none) syms
| [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms
| stx::stxs => match stx with
| `(namespace $id) =>
let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms }
Expand All @@ -411,9 +412,9 @@ where
let syms := entry.finish text syms stx
popStack (n - entry.name.length) syms stack
popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack
| _ => Id.run do
| _ => do
unless stx.isOfKind ``Lean.Parser.Command.declaration do
return toDocumentSymbols text stxs syms stack
return toDocumentSymbols text stxs syms stack
if let some stxRange := stx.getRange? then
let (name, selection) := match stx with
| `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) =>
Expand All @@ -431,7 +432,7 @@ where
range := stxRange.toLspRange text
selectionRange := selRange.toLspRange text
}
return toDocumentSymbols text stxs (syms.push sym) stack
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack

partial def handleFoldingRange (_ : FoldingRangeParams)
Expand All @@ -450,7 +451,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams)
if let (_, start)::rest := sections then
addRange text FoldingRangeKind.region start text.source.endPos
addRanges text rest []
| stx::stxs => match stx with
| stx::stxs => do
RequestM.checkCancelled
match stx with
| `(namespace $id) =>
addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs
| `(section $(id)?) =>
Expand Down
6 changes: 4 additions & 2 deletions src/Lean/Server/FileWorker/SemanticHighlighting.lean
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask)
asTask <| do
return { response := ← run doc snaps, isComplete }
| some endPos =>
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => do
RequestM.checkCanceled
return { response := ← run doc snaps, isComplete := true }
where
run doc snaps : RequestM SemanticTokens := do
Expand All @@ -164,8 +163,11 @@ where
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
RequestM.checkCancelled
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

Expand Down
77 changes: 77 additions & 0 deletions src/Lean/Server/RequestCancellation.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Marc Huisinga
-/
prelude
import Init.System.Promise

namespace Lean.Server

structure RequestCancellationToken where
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
cancellationPromise : IO.Promise Unit

namespace RequestCancellationToken

def new : IO RequestCancellationToken := do
return {
cancelledByCancelRequest := ← IO.mkRef false
cancelledByEdit := ← IO.mkRef false
cancellationPromise := ← IO.Promise.new
}

def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByCancelRequest.set true
tk.cancellationPromise.resolve ()

def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByEdit.set true
tk.cancellationPromise.resolve ()

def cancellationTask (tk : RequestCancellationToken) : Task Unit :=
tk.cancellationPromise.result!

def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool :=
tk.cancelledByCancelRequest.get

def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
tk.cancelledByEdit.get

end RequestCancellationToken

structure RequestCancellation where

def RequestCancellation.requestCancelled : RequestCancellation := {}

abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m)
abbrev CancellableM := CancellableT IO

def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) :
m (Except RequestCancellation α) :=
x tk

def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) :
IO (Except RequestCancellation α) :=
CancellableT.run tk x

def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do
let tk ← read
if ← tk.wasCancelledByCancelRequest then
throw .requestCancelled

def CancellableM.checkCancelled : CancellableM Unit :=
CancellableT.checkCancelled

class MonadCancellable (m : Type → Type v) where
checkCancelled : m PUnit

instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where
checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit)

instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where
checkCancelled := CancellableT.checkCancelled

def RequestCancellation.check [MonadCancellable m] : m Unit :=
MonadCancellable.checkCancelled
Loading
Loading