Skip to content

Commit

Permalink
Attempt to fix workspace pull diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
jwortmann committed Dec 12, 2024
1 parent 3daf716 commit ca6c014
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 47 deletions.
9 changes: 5 additions & 4 deletions plugin/core/diagnostics_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class DiagnosticsStorage(MutableMapping):

def __init__(self) -> None:
super().__init__()
self._d: dict[tuple[DocumentUri, str], list[Diagnostic]] = dict()
self._identifiers = {''}
self._d: dict[tuple[DocumentUri, str | None], list[Diagnostic]] = dict()
self._identifiers: set[str | None] = {None}
self._uris: set[DocumentUri] = set()

def __getitem__(self, key: DocumentUri, /) -> list[Diagnostic]:
Expand All @@ -29,8 +29,9 @@ def __getitem__(self, key: DocumentUri, /) -> list[Diagnostic]:
key=lambda diagnostic: Point.from_lsp(diagnostic['range']['start'])
)

def __setitem__(self, key: DocumentUri | tuple[DocumentUri, str], value: list[Diagnostic], /) -> None:
uri, identifier = (normalize_uri(key), '') if isinstance(key, DocumentUri) else (normalize_uri(key[0]), key[1])
def __setitem__(self, key: DocumentUri | tuple[DocumentUri, str | None], value: list[Diagnostic], /) -> None:
uri, identifier = (normalize_uri(key), None) if isinstance(key, DocumentUri) else \
(normalize_uri(key[0]), key[1])
if identifier not in self._identifiers:
raise ValueError(f'identifier {identifier} must be registered first')
if value:
Expand Down
83 changes: 50 additions & 33 deletions plugin/core/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .protocol import DiagnosticTag
from .protocol import DidChangeWatchedFilesRegistrationOptions
from .protocol import DidChangeWorkspaceFoldersParams
from .protocol import DocumentDiagnosticReport
from .protocol import DocumentDiagnosticReportKind
from .protocol import DocumentLink
from .protocol import DocumentUri
Expand Down Expand Up @@ -663,6 +664,11 @@ def on_diagnostics_async(
) -> None:
...

def on_document_diagnostic_async(
self, identifier: str | None, version: int, response: DocumentDiagnosticReport
) -> None:
...

def get_document_link_at_point(self, view: sublime.View, point: int) -> DocumentLink | None:
...

Expand Down Expand Up @@ -1282,8 +1288,8 @@ def __init__(self, manager: Manager, logger: Logger, workspace_folders: list[Wor
self.state = ClientStates.STARTING
self.capabilities = Capabilities()
self.diagnostics = DiagnosticsStorage()
self.diagnostics_result_ids: dict[tuple[DocumentUri, str], str | None] = {}
self.workspace_diagnostics_pending_response: int | None = None
self.diagnostics_result_ids: dict[tuple[DocumentUri, str | None], str | None] = {}
self.workspace_diagnostics_pending_responses: dict[str | None, int | None] = {}
self.exiting = False
self._registrations: dict[str, _RegistrationData] = {}
self._init_callback: InitCallback | None = None
Expand Down Expand Up @@ -1451,15 +1457,23 @@ def can_handle(self, view: sublime.View, scheme: str, capability: str | None, in
return self.has_capability(capability)
return False

@deprecated("Use has_provider instead")
def has_capability(self, capability: str) -> bool:
value = self.get_capability(capability)
return value is not False and value is not None

@deprecated("Use get_providers instead")
def get_capability(self, capability: str) -> Any | None:
if self.config.is_disabled_capability(capability):
return None
return self.capabilities.get(capability)

def get_providers(self, capability_name: str) -> list[Any]:
return self.capabilities.get_all(capability_name)

def has_provider(self, capability_name: str) -> bool:
return bool(self.get_providers(capability_name))

def should_notify_did_open(self) -> bool:
return self.capabilities.should_notify_did_open()

Expand Down Expand Up @@ -1944,29 +1958,29 @@ def session_views_by_visibility(self) -> tuple[set[SessionViewProtocol], set[Ses
# --- Workspace Pull Diagnostics -----------------------------------------------------------------------------------

def do_workspace_diagnostics_async(self) -> None:
# TODO consider separate diagnostic streams (identifiers)
if self.workspace_diagnostics_pending_response:
# The server is probably leaving the request open intentionally, in order to continuously stream updates via
# $/progress notifications.
return
previous_result_ids: list[PreviousResultId] = [
{'uri': uri, 'value': result_id} for (uri, _), result_id in self.diagnostics_result_ids.items()
if result_id is not None
]
params: WorkspaceDiagnosticParams = {'previousResultIds': previous_result_ids}
identifier = self.get_capability("diagnosticProvider.identifier")
if identifier:
params['identifier'] = identifier
self.workspace_diagnostics_pending_response = self.send_request_async(
Request.workspaceDiagnostic(params),
functools.partial(self._on_workspace_diagnostics_async, identifier or ''),
self._on_workspace_diagnostics_error_async)
for provider in self.get_providers('diagnosticProvider'):
identifier = provider.get('identifier')
if self.workspace_diagnostics_pending_responses[identifier]:
# The server is probably leaving the request open intentionally, in order to continuously stream updates
# via $/progress notifications.
return
previous_result_ids: list[PreviousResultId] = [
{'uri': uri, 'value': result_id} for (uri, id_), result_id in self.diagnostics_result_ids.items()
if id_ == identifier and result_id is not None
]
params: WorkspaceDiagnosticParams = {'previousResultIds': previous_result_ids}
if identifier:
params['identifier'] = identifier
self.workspace_diagnostics_pending_responses[identifier] = self.send_request_async(
Request.workspaceDiagnostic(params),
functools.partial(self._on_workspace_diagnostics_async, identifier),
functools.partial(self._on_workspace_diagnostics_error_async, identifier))

def _on_workspace_diagnostics_async(
self, identifier: str, response: WorkspaceDiagnosticReport, reset_pending_response: bool = True
self, identifier: str | None, response: WorkspaceDiagnosticReport, reset_pending_response: bool = True
) -> None:
if reset_pending_response:
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None
if not response['items']:
return
window = sublime.active_window()
Expand All @@ -1983,20 +1997,23 @@ def _on_workspace_diagnostics_async(
uri = unparse_uri((scheme, path))
# Note: 'version' is a mandatory field, but some language servers have serialization bugs with null values.
version = diagnostic_report.get('version')
# Skip if outdated
# Note: this is just a necessary, but not a sufficient condition to decide whether the diagnostics for this
# file are likely not accurate anymore, because changes in another file in the meanwhile could have affected
# the diagnostics in this file. If this is the case, a new request is already queued, or updated partial
# results are expected to be streamed by the server.
if isinstance(version, int):
if version is not None:
sb = self.get_session_buffer_for_uri_async(uri)
if sb and sb.version != version:
if not sb:
# There should always be a SessionBuffer if version != None
continue
if sb.version != version:
# Skip if outdated
continue
if is_workspace_full_document_diagnostic_report(diagnostic_report):
diagnostic_report = cast(DocumentDiagnosticReport, diagnostic_report)
sb.on_document_diagnostic_async(identifier, version, diagnostic_report)
else:
# TODO support diagnostics for unopened docuements (version == None)
pass
self.diagnostics_result_ids[(uri, identifier)] = diagnostic_report.get('resultId')
if is_workspace_full_document_diagnostic_report(diagnostic_report):
self.m_textDocument_publishDiagnostics({'uri': uri, 'diagnostics': diagnostic_report['items']})

def _on_workspace_diagnostics_error_async(self, error: ResponseError) -> None:
def _on_workspace_diagnostics_error_async(self, identifier: str | None, error: ResponseError) -> None:
if error['code'] == LSPErrorCodes.ServerCancelled:
data = error.get('data')
if is_diagnostic_server_cancellation_data(data) and data['retriggerRequest']:
Expand All @@ -2005,12 +2022,12 @@ def _on_workspace_diagnostics_error_async(self, error: ResponseError) -> None:
# infinite cycles of cancel -> retrigger, in case the server is busy.

def _retrigger_request() -> None:
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None
self.do_workspace_diagnostics_async()

sublime.set_timeout_async(_retrigger_request, WORKSPACE_DIAGNOSTICS_TIMEOUT)
return
self.workspace_diagnostics_pending_response = None
self.workspace_diagnostics_pending_responses[identifier] = None

# --- server request handlers --------------------------------------------------------------------------------------

Expand Down
27 changes: 17 additions & 10 deletions plugin/session_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, session_view: SessionViewProtocol, buffer_id: int, uri: Docum
self.diagnostics_flags = 0
self._diagnostics_are_visible = False
self.document_diagnostic_needs_refresh = False
self._document_diagnostic_pending_requests: dict[str, PendingDocumentDiagnosticRequest | None] = {}
self._document_diagnostic_pending_requests: dict[str | None, PendingDocumentDiagnosticRequest | None] = {}
self._last_synced_version = 0
self._last_text_change_time = 0.0
self._diagnostics_debouncer_async = DebouncerNonThreadSafe(async_thread=True)
Expand Down Expand Up @@ -265,21 +265,26 @@ def unregister_capability_async(
for sv in self.session_views:
sv.on_capability_removed_async(registration_id, discarded)

@deprecated("Use get_providers instead")
def get_capability(self, capability_path: str) -> Any | None:
if self.session.config.is_disabled_capability(capability_path):
return None
value = self.capabilities.get(capability_path)
return value if value is not None else self.session.capabilities.get(capability_path)

def get_capability_2(self, capability_path: str) -> list[Any]:
if self.session.config.is_disabled_capability(capability_path):
def get_providers(self, capability_name: str) -> list[Any]:
if self.session.config.is_disabled_capability(capability_name):
return []
return self.capabilities.get_all(capability_path) + self.session.capabilities.get_all(capability_path)
return self.capabilities.get_all(capability_name) + self.session.capabilities.get_all(capability_name)

@deprecated("Use has_provider instead")
def has_capability(self, capability_path: str) -> bool:
value = self.get_capability(capability_path)
return value is not False and value is not None

def has_provider(self, capability_name: str) -> bool:
return bool(self.get_providers(capability_name))

def text_sync_kind(self) -> TextDocumentSyncKind:
value = self.capabilities.text_sync_kind()
return value if value != TextDocumentSyncKind.None_ else self.session.text_sync_kind()
Expand Down Expand Up @@ -514,8 +519,8 @@ def do_document_diagnostic_async(
self._document_diagnostic_pending_requests[identifier] = None
_params: DocumentDiagnosticParams = {'textDocument': text_document_identifier(view)}
identifiers = set()
for registration in self.get_capability_2('diagnosticProvider'):
identifiers.add(registration.get('identifier', ''))
for provider in self.get_providers('diagnosticProvider'):
identifiers.add(provider.get('identifier', ''))
for identifier in identifiers:
params = _params.copy()
if identifier:
Expand All @@ -525,13 +530,15 @@ def do_document_diagnostic_async(
params['previousResultId'] = result_id
request_id = self.session.send_request_async(
Request.documentDiagnostic(params, view),
partial(self._on_document_diagnostic_async, identifier, version),
partial(self.on_document_diagnostic_async, identifier, version),
partial(self._on_document_diagnostic_error_async, identifier, version)
)
self._document_diagnostic_pending_requests[identifier] = \
PendingDocumentDiagnosticRequest(version, request_id)

def _on_document_diagnostic_async(self, identifier: str, version: int, response: DocumentDiagnosticReport) -> None:
def on_document_diagnostic_async(
self, identifier: str | None, version: int, response: DocumentDiagnosticReport
) -> None:
self._document_diagnostic_pending_requests[identifier] = None
view = self.some_view()
if view and view.change_count() == version:
Expand All @@ -541,7 +548,7 @@ def _on_document_diagnostic_async(self, identifier: str, version: int, response:
mgr.on_diagnostics_updated()

def _apply_document_diagnostic_async(
self, identifier: str, version: int, response: DocumentDiagnosticReport
self, identifier: str | None, version: int, response: DocumentDiagnosticReport
) -> None:
self.session.diagnostics_result_ids[(self._last_known_uri, identifier)] = response.get('resultId')
if is_full_document_diagnostic_report(response):
Expand All @@ -554,7 +561,7 @@ def _apply_document_diagnostic_async(
cast(SessionBuffer, sb)._apply_document_diagnostic_async(
identifier, version, cast(DocumentDiagnosticReport, diagnostic_report))

def _on_document_diagnostic_error_async(self, identifier: str, version: int, error: ResponseError) -> None:
def _on_document_diagnostic_error_async(self, identifier: str | None, version: int, error: ResponseError) -> None:
self._document_diagnostic_pending_requests[identifier] = None
if error['code'] == LSPErrorCodes.ServerCancelled:
data = error.get('data')
Expand Down

0 comments on commit ca6c014

Please sign in to comment.