Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dynamic registration and diagnostic streams #2564

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions plugin/core/diagnostics_storage.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,86 @@
from __future__ import annotations
from .protocol import Diagnostic, DiagnosticSeverity, DocumentUri
from .url import parse_uri
from .protocol import Diagnostic
from .protocol import DiagnosticSeverity
from .protocol import DocumentUri
from .protocol import Point
from .url import normalize_uri
from .views import diagnostic_severity
from collections import OrderedDict
from collections.abc import MutableMapping
from typing import Callable, Iterator, Tuple, TypeVar
import itertools
import functools

ParsedUri = Tuple[str, str]
T = TypeVar('T')


# NOTE: OrderedDict can only be properly typed in Python >=3.8.
class DiagnosticsStorage(OrderedDict):
# From the specs:
#
# When a file changes it is the server’s responsibility to re-compute
# diagnostics and push them to the client. If the computed set is empty
# it has to push the empty array to clear former diagnostics. Newly
# pushed diagnostics always replace previously pushed diagnostics. There
# is no merging that happens on the client side.
#
# https://microsoft.github.io/language-server-protocol/specification#textDocument_publishDiagnostics
class DiagnosticsStorage(MutableMapping):

def add_diagnostics_async(self, document_uri: DocumentUri, diagnostics: list[Diagnostic]) -> None:
"""
Add `diagnostics` for `document_uri` to the store, replacing previously received `diagnoscis`
for this `document_uri`. If `diagnostics` is the empty list, `document_uri` is removed from
the store. The item received is moved to the end of the store.
"""
uri = parse_uri(document_uri)
if not diagnostics:
# received "clear diagnostics" message for this uri
self.pop(uri, None)
return
self[uri] = diagnostics
self.move_to_end(uri) # maintain incoming order
def __init__(self) -> None:
super().__init__()
self._d: dict[tuple[DocumentUri, str | None], list[Diagnostic]] = dict()
self._identifiers: set[str | None] = {None}
self._uris: set[DocumentUri] = set()

def __getitem__(self, key: DocumentUri, /) -> list[Diagnostic]:
uri = normalize_uri(key)
return sorted(
itertools.chain.from_iterable(self._d.get((uri, identifier), []) for identifier in self._identifiers),
key=lambda diagnostic: Point.from_lsp(diagnostic['range']['start'])
)

def __setitem__(self, key: DocumentUri | tuple[DocumentUri, str | None], value: list[Diagnostic], /) -> None:
uri, identifier = (normalize_uri(key), None) if isinstance(key, DocumentUri) else \
(normalize_uri(key[0]), key[1])
if identifier not in self._identifiers:
raise ValueError(f'identifier {identifier} must be registered first')
if value:
self._uris.add(uri)
self._d[(uri, identifier)] = value
else:
self._uris.discard(uri)
self._d.pop((uri, identifier), None)

def __delitem__(self, key: DocumentUri, /) -> None:
uri = normalize_uri(key)
self._uris.discard(uri)
for identifier in self._identifiers:
self._d.pop((uri, identifier), None)

def __iter__(self) -> Iterator[DocumentUri]:
return iter(self._uris)

def __len__(self) -> int:
return len(self._uris)

def register(self, identifier: str) -> None:
""" Register an identifier for a diagnostics stream. """
self._identifiers.add(identifier)

def unregister(self, identifier: str) -> None:
""" Unregister an identifier for a diagnostics stream. """
self._identifiers.discard(identifier)

def filter_map_diagnostics_async(
self, pred: Callable[[Diagnostic], bool], f: Callable[[ParsedUri, Diagnostic], T]
) -> Iterator[tuple[ParsedUri, list[T]]]:
self, pred: Callable[[Diagnostic], bool], f: Callable[[DocumentUri, Diagnostic], T]
) -> Iterator[tuple[DocumentUri, list[T]]]:
"""
Yields `(uri, results)` items with `results` being a list of `f(diagnostic)` for each
diagnostic for this `uri` with `pred(diagnostic) == True`, filtered by `bool(f(diagnostic))`.
Only `uri`s with non-empty `results` are returned. Each `uri` is guaranteed to be yielded
not more than once. Items and results are ordered as they came in from the server.
not more than once.
"""
for uri, diagnostics in self.items():
results: list[T] = list(filter(None, map(functools.partial(f, uri), filter(pred, diagnostics))))
if results:
yield uri, results

def filter_map_diagnostics_flat_async(self, pred: Callable[[Diagnostic], bool],
f: Callable[[ParsedUri, Diagnostic], T]) -> Iterator[tuple[ParsedUri, T]]:
f: Callable[[DocumentUri, Diagnostic], T]) -> Iterator[tuple[DocumentUri, T]]:
"""
Flattened variant of `filter_map_diagnostics_async()`. Yields `(uri, result)` items for each
of the `result`s per `uri` instead. Each `uri` can be yielded more than once. Items are
grouped by `uri` and each `uri` group is guaranteed to appear not more than once. Items are
ordered as they came in from the server.
grouped by `uri` and each `uri` group is guaranteed to appear not more than once.
"""
for uri, results in self.filter_map_diagnostics_async(pred, f):
for result in results:
Expand All @@ -71,18 +95,6 @@ def sum_total_errors_and_warnings_async(self) -> tuple[int, int]:
sum(map(severity_count(DiagnosticSeverity.Warning), self.values())),
)

def diagnostics_by_document_uri(self, document_uri: DocumentUri) -> list[Diagnostic]:
"""
Returns possibly empty list of diagnostic for `document_uri`.
"""
return self.get(parse_uri(document_uri), [])

def diagnostics_by_parsed_uri(self, uri: ParsedUri) -> list[Diagnostic]:
"""
Returns possibly empty list of diagnostic for `uri`.
"""
return self.get(uri, [])


def severity_count(severity: int) -> Callable[[list[Diagnostic]], int]:
def severity_count(diagnostics: list[Diagnostic]) -> int:
Expand Down
2 changes: 1 addition & 1 deletion plugin/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def navigate_diagnostics(view: sublime.View, point: int | None, forward: bool =
return
diagnostics: list[Diagnostic] = []
for session in wm.get_sessions():
diagnostics.extend(session.diagnostics.diagnostics_by_document_uri(uri))
diagnostics.extend(session.diagnostics[uri])
if not diagnostics:
return
# Sort diagnostics by location
Expand Down
103 changes: 66 additions & 37 deletions plugin/core/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .protocol import DiagnosticTag
from .protocol import DidChangeWatchedFilesRegistrationOptions
from .protocol import DidChangeWorkspaceFoldersParams
from .protocol import DocumentDiagnosticReport
from .protocol import DocumentDiagnosticReportKind
from .protocol import DocumentLink
from .protocol import DocumentUri
Expand Down Expand Up @@ -663,6 +664,11 @@ def on_diagnostics_async(
) -> None:
...

def on_document_diagnostic_async(
self, identifier: str | None, version: int, response: DocumentDiagnosticReport
) -> None:
...

def get_document_link_at_point(self, view: sublime.View, point: int) -> DocumentLink | None:
...

Expand Down Expand Up @@ -1282,8 +1288,8 @@ def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[Wor
self.state = ClientStates.STARTING
self.capabilities = Capabilities()
self.diagnostics = DiagnosticsStorage()
self.diagnostics_result_ids: dict[DocumentUri, str | None] = {}
self.workspace_diagnostics_pending_response: int | None = None
self.diagnostics_result_ids: dict[tuple[DocumentUri, str | None], str | None] = {}
self.workspace_diagnostics_pending_responses: dict[str | None, int | None] = {}
self.exiting = False
self._registrations: dict[str, _RegistrationData] = {}
self._init_callback: InitCallback | None = None
Expand Down Expand Up @@ -1377,7 +1383,7 @@ def register_session_buffer_async(self, sb: SessionBufferProtocol) -> None:
data.check_applicable(sb)
uri = sb.get_uri()
if uri:
diagnostics = self.diagnostics.diagnostics_by_document_uri(uri)
diagnostics = self.diagnostics[uri]
if diagnostics:
self._publish_diagnostics_to_session_buffer_async(sb, diagnostics, version=None)

Expand Down Expand Up @@ -1451,15 +1457,23 @@ def can_handle(self, view: sublime.View, scheme: str, capability: str | None, in
return self.has_capability(capability)
return False

@deprecated("Use has_provider instead")
def has_capability(self, capability: str) -> bool:
value = self.get_capability(capability)
return value is not False and value is not None

@deprecated("Use get_providers instead")
def get_capability(self, capability: str) -> Any | None:
if self.config.is_disabled_capability(capability):
return None
return self.capabilities.get(capability)

def get_providers(self, capability_name: str) -> list[Any]:
return self.capabilities.get_all(capability_name)

def has_provider(self, capability_name: str) -> bool:
return bool(self.get_providers(capability_name))

def should_notify_did_open(self) -> bool:
return self.capabilities.should_notify_did_open()

Expand Down Expand Up @@ -1547,7 +1561,11 @@ def initialize_async(
Request.initialize(params), self._handle_initialize_success, self._handle_initialize_error)

def _handle_initialize_success(self, result: InitializeResult) -> None:
self.capabilities.assign(result.get('capabilities', dict()))
capabilities = result.get('capabilities', dict())
self.capabilities.assign(capabilities)
if diagnostic_provider := capabilities.get('diagnosticProvider'):
if identifier := diagnostic_provider.get('identifier'):
self.diagnostics.register(identifier)
if self._workspace_folders and not self._supports_workspace_folders():
self._workspace_folders = self._workspace_folders[:1]
self.state = ClientStates.READY
Expand Down Expand Up @@ -1940,28 +1958,29 @@ def session_views_by_visibility(self) -> tuple[set[SessionViewProtocol], set[Ses
# --- Workspace Pull Diagnostics -----------------------------------------------------------------------------------

def do_workspace_diagnostics_async(self) -> None:
if self.workspace_diagnostics_pending_response:
# The server is probably leaving the request open intentionally, in order to continuously stream updates via
# $/progress notifications.
return
previous_result_ids: list[PreviousResultId] = [
{'uri': uri, 'value': result_id} for uri, result_id in self.diagnostics_result_ids.items()
if result_id is not None
]
params: WorkspaceDiagnosticParams = {'previousResultIds': previous_result_ids}
identifier = self.get_capability("diagnosticProvider.identifier")
if identifier:
params['identifier'] = identifier
self.workspace_diagnostics_pending_response = self.send_request_async(
Request.workspaceDiagnostic(params),
self._on_workspace_diagnostics_async,
self._on_workspace_diagnostics_error_async)
for provider in self.get_providers('diagnosticProvider'):
identifier = provider.get('identifier')
if self.workspace_diagnostics_pending_responses[identifier]:
# The server is probably leaving the request open intentionally, in order to continuously stream updates
# via $/progress notifications.
return
previous_result_ids: list[PreviousResultId] = [
{'uri': uri, 'value': result_id} for (uri, id_), result_id in self.diagnostics_result_ids.items()
if id_ == identifier and result_id is not None
]
params: WorkspaceDiagnosticParams = {'previousResultIds': previous_result_ids}
if identifier:
params['identifier'] = identifier
self.workspace_diagnostics_pending_responses[identifier] = self.send_request_async(
Request.workspaceDiagnostic(params),
functools.partial(self._on_workspace_diagnostics_async, identifier),
functools.partial(self._on_workspace_diagnostics_error_async, identifier))

def _on_workspace_diagnostics_async(
self, response: WorkspaceDiagnosticReport, reset_pending_response: bool = True
self, identifier: str | None, response: WorkspaceDiagnosticReport, reset_pending_response: bool = True
) -> None:
if reset_pending_response:
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None
if not response['items']:
return
window = sublime.active_window()
Expand All @@ -1978,20 +1997,23 @@ def _on_workspace_diagnostics_async(
uri = unparse_uri((scheme, path))
# Note: 'version' is a mandatory field, but some language servers have serialization bugs with null values.
version = diagnostic_report.get('version')
# Skip if outdated
# Note: this is just a necessary, but not a sufficient condition to decide whether the diagnostics for this
# file are likely not accurate anymore, because changes in another file in the meanwhile could have affected
# the diagnostics in this file. If this is the case, a new request is already queued, or updated partial
# results are expected to be streamed by the server.
if isinstance(version, int):
if version is not None:
sb = self.get_session_buffer_for_uri_async(uri)
if sb and sb.version != version:
if not sb:
# There should always be a SessionBuffer if version != None
continue
self.diagnostics_result_ids[uri] = diagnostic_report.get('resultId')
if is_workspace_full_document_diagnostic_report(diagnostic_report):
self.m_textDocument_publishDiagnostics({'uri': uri, 'diagnostics': diagnostic_report['items']})
if sb.version != version:
# Skip if outdated
continue
if is_workspace_full_document_diagnostic_report(diagnostic_report):
diagnostic_report = cast(DocumentDiagnosticReport, diagnostic_report)
sb.on_document_diagnostic_async(identifier, version, diagnostic_report)
else:
# TODO support diagnostics for unopened docuements (version == None)
pass
self.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId')

def _on_workspace_diagnostics_error_async(self, error: ResponseError) -> None:
def _on_workspace_diagnostics_error_async(self, identifier: str | None, error: ResponseError) -> None:
if error['code'] == LSPErrorCodes.ServerCancelled:
data = error.get('data')
if is_diagnostic_server_cancellation_data(data) and data['retriggerRequest']:
Expand All @@ -2000,12 +2022,12 @@ def _on_workspace_diagnostics_error_async(self, error: ResponseError) -> None:
# infinite cycles of cancel -> retrigger, in case the server is busy.

def _retrigger_request() -> None:
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None
self.do_workspace_diagnostics_async()

sublime.set_timeout_async(_retrigger_request, WORKSPACE_DIAGNOSTICS_TIMEOUT)
return
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None

# --- server request handlers --------------------------------------------------------------------------------------

Expand Down Expand Up @@ -2089,7 +2111,7 @@ def m_textDocument_publishDiagnostics(self, params: PublishDiagnosticsParams) ->
if isinstance(reason, str):
return debug("ignoring unsuitable diagnostics for", uri, "reason:", reason)
diagnostics = params["diagnostics"]
self.diagnostics.add_diagnostics_async(uri, diagnostics)
self.diagnostics[uri] = diagnostics
mgr.on_diagnostics_updated()
sb = self.get_session_buffer_for_uri_async(uri)
if sb:
Expand Down Expand Up @@ -2138,6 +2160,9 @@ def m_client_registerCapability(self, params: RegistrationParams, request_id: An
watcher = self._watcher_impl.create(folder.path, [pattern], kind, ignores, self)
file_watchers.append(watcher)
self._dynamic_file_watchers[registration_id] = file_watchers
elif capability_path == 'diagnosticProvider':
if (identifier := options.get('identifier')) is not None:
self.diagnostics.register(identifier)
self.send_response(Response(request_id, None))

def m_client_unregisterCapability(self, params: UnregistrationParams, request_id: Any) -> None:
Expand All @@ -2160,6 +2185,9 @@ def m_client_unregisterCapability(self, params: UnregistrationParams, request_id
if isinstance(discarded, dict):
for sv in self.session_views_async():
sv.on_capability_removed_async(registration_id, discarded)
if capability_path == 'diagnosticProvider':
if data and (identifier := data.options.get('identifier')) is not None:
self.diagnostics.unregister(identifier)
self.send_response(Response(request_id, None))

def m_window_showDocument(self, params: Any, request_id: Any) -> None:
Expand Down Expand Up @@ -2213,8 +2241,9 @@ def m___progress(self, params: ProgressParams) -> None:
request_id = int(token[len(_PARTIAL_RESULT_PROGRESS_PREFIX):])
request = self._response_handlers[request_id][0]
if request.method == "workspace/diagnostic":
# TODO somehow get the identifier (probably needs to be stored based on progress token)
self._on_workspace_diagnostics_async(
cast(WorkspaceDiagnosticReport, value), reset_pending_response=False)
'', cast(WorkspaceDiagnosticReport, value), reset_pending_response=False)
return
# Work Done Progress
# https://microsoft.github.io/language-server-protocol/specifications/specification-current/#workDoneProgress
Expand Down
Loading
Loading