Skip to content

Commit

Permalink
lsp-devtools: pyupgrade --py39-plus
Browse files Browse the repository at this point in the history
  • Loading branch information
alcarney committed Nov 6, 2024
1 parent a7abb52 commit f4558ad
Show file tree
Hide file tree
Showing 17 changed files with 102 additions and 135 deletions.
7 changes: 4 additions & 3 deletions lib/lsp-devtools/lsp_devtools/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import argparse
import asyncio
import subprocess
import sys
from typing import List

from .agent import Agent
from .agent import RPCMessage
Expand Down Expand Up @@ -31,7 +32,7 @@ async def forward_stderr(server: asyncio.subprocess.Process):
sys.stderr.buffer.write(line)


async def main(args, extra: List[str]):
async def main(args, extra: list[str]):
if extra is None:
print("Missing server start command", file=sys.stderr)
return 1
Expand All @@ -54,7 +55,7 @@ async def main(args, extra: List[str]):
)


def run_agent(args, extra: List[str]):
def run_agent(args, extra: list[str]):
asyncio.run(main(args, extra))


Expand Down
25 changes: 10 additions & 15 deletions lib/lsp-devtools/lsp_devtools/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
import attrs

if typing.TYPE_CHECKING:
from collections.abc import Coroutine
from typing import Any
from typing import BinaryIO
from typing import Callable
from typing import Coroutine
from typing import Dict
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

MessageHandler = Callable[[bytes], Union[None, Coroutine[Any, Any, None]]]
Expand All @@ -35,9 +31,9 @@
class RPCMessage:
"""A Json-RPC message."""

headers: Dict[str, str]
headers: dict[str, str]

body: Dict[str, Any]
body: dict[str, Any]

def __getitem__(self, key: str):
return self.headers[key]
Expand All @@ -46,8 +42,8 @@ def __getitem__(self, key: str):
def parse_rpc_message(data: bytes) -> RPCMessage:
"""Parse a JSON-RPC message from the given set of bytes."""

headers: Dict[str, str] = {}
body: Optional[Dict[str, Any]] = None
headers: dict[str, str] = {}
body: dict[str, Any] | None = None
headers_complete = False

for line in data.split(b"\r\n"):
Expand Down Expand Up @@ -118,7 +114,7 @@ async def aio_readline(reader: asyncio.StreamReader, message_handler: MessageHan

async def get_streams(
stdin, stdout
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""Convert blocking stdin/stdout streams into async streams."""
loop = asyncio.get_running_loop()

Expand Down Expand Up @@ -150,9 +146,9 @@ def __init__(
self.handler = handler
self.session_id = str(uuid4())

self._tasks: Set[asyncio.Task] = set()
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self._tasks: set[asyncio.Task] = set()
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None

async def start(self):
# Get async versions of stdin/stdout
Expand Down Expand Up @@ -227,8 +223,7 @@ async def stop(self):
self.server.kill()

args = {}
if sys.version_info >= (3, 9):
args["msg"] = "lsp-devtools agent is stopping."
args["msg"] = "lsp-devtools agent is stopping."

# Cancel the tasks connecting client to server
for task in self._tasks:
Expand Down
6 changes: 2 additions & 4 deletions lib/lsp-devtools/lsp_devtools/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

if typing.TYPE_CHECKING:
from typing import Any
from typing import List
from typing import Optional

# from websockets.client import WebSocketClientProtocol

Expand Down Expand Up @@ -45,14 +43,14 @@ def __init__(self):
protocol_cls=AgentProtocol, converter_factory=default_converter
)
self.connected = False
self._buffer: List[bytes] = []
self._buffer: list[bytes] = []

def _report_server_error(self, error, source):
# Bail on error
# TODO: Report the actual error somehow
self._stop_event.set()

def feature(self, feature_name: str, options: Optional[Any] = None):
def feature(self, feature_name: str, options: Any | None = None):
return self.protocol.fm.feature(feature_name, options)

# TODO: Upstream this... or at least something equivalent.
Expand Down
16 changes: 7 additions & 9 deletions lib/lsp-devtools/lsp_devtools/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

if typing.TYPE_CHECKING:
from typing import Any
from typing import List
from typing import Optional

from lsp_devtools.agent.agent import MessageHandler

Expand All @@ -29,8 +27,8 @@ class AgentServer(Server):
def __init__(
self,
*args,
logger: Optional[logging.Logger] = None,
handler: Optional[MessageHandler] = None,
logger: logging.Logger | None = None,
handler: MessageHandler | None = None,
**kwargs,
):
if "protocol_cls" not in kwargs:
Expand All @@ -43,19 +41,19 @@ def __init__(

self.logger = logger or logging.getLogger(__name__)
self.handler = handler or self.lsp.data_received
self.db: Optional[Database] = None
self.db: Database | None = None

self._client_buffer: List[str] = []
self._server_buffer: List[str] = []
self._tcp_server: Optional[asyncio.Task] = None
self._client_buffer: list[str] = []
self._server_buffer: list[str] = []
self._tcp_server: asyncio.Task | None = None

def _report_server_error(self, exc: Exception, source):
"""Report internal server errors."""
tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
self.logger.error("%s: %s", type(exc).__name__, exc)
self.logger.debug("%s", tb)

def feature(self, feature_name: str, options: Optional[Any] = None):
def feature(self, feature_name: str, options: Any | None = None):
return self.lsp.fm.feature(feature_name, options)

async def start_tcp(self, host: str, port: int) -> None: # type: ignore[override]
Expand Down
9 changes: 5 additions & 4 deletions lib/lsp-devtools/lsp_devtools/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import argparse
import asyncio
import logging
import os
import pathlib
from typing import List
from uuid import uuid4

import platformdirs
Expand Down Expand Up @@ -54,7 +55,7 @@ class LSPClient(App):
]

def __init__(
self, db: Database, server_command: List[str], session: str, *args, **kwargs
self, db: Database, server_command: list[str], session: str, *args, **kwargs
):
super().__init__(*args, **kwargs)

Expand All @@ -65,7 +66,7 @@ def __init__(
self.server_command = server_command
self.lsp_client = LanguageClient()

self._async_tasks: List[asyncio.Task] = []
self._async_tasks: list[asyncio.Task] = []

def compose(self) -> ComposeResult:
message_viewer = MessageViewer("")
Expand Down Expand Up @@ -140,7 +141,7 @@ async def action_quit(self):
await super().action_quit()


def client(args, extra: List[str]):
def client(args, extra: list[str]):
if len(extra) == 0:
raise ValueError("Missing server command.")

Expand Down
3 changes: 1 addition & 2 deletions lib/lsp-devtools/lsp_devtools/client/editor/text_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import pathlib
import typing
from typing import List
from typing import Union

from lsprotocol import types
Expand All @@ -15,7 +14,7 @@
if typing.TYPE_CHECKING:
from lsp_devtools.client.lsp import LanguageClient

CompletionResult = Union[List[types.CompletionItem], types.CompletionList, None]
CompletionResult = Union[list[types.CompletionItem], types.CompletionList, None]


# TODO: Refactor to
Expand Down
18 changes: 5 additions & 13 deletions lib/lsp-devtools/lsp_devtools/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,17 @@
import json
import logging
import pathlib
import sys
from contextlib import asynccontextmanager
from importlib import resources
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Set

import aiosqlite
from textual.app import App
from textual.message import Message

from lsp_devtools.handlers import LspMessage

if sys.version_info < (3, 9):
import importlib_resources as resources
else:
from importlib import resources # type: ignore[no-redef]


class Database:
"""Controls access to the backing sqlite database."""
Expand All @@ -32,7 +24,7 @@ def __init__(self, dbpath: Optional[pathlib.Path] = None):
self.dbpath = dbpath or ":memory:"
self.db: Optional[aiosqlite.Connection] = None
self.app: Optional[App] = None
self._handlers: Dict[str, set] = {}
self._handlers: dict[str, set] = {}

async def close(self):
if self.db:
Expand Down Expand Up @@ -106,8 +98,8 @@ async def get_messages(
"""

base_query = "SELECT rowid, * FROM protocol"
where: List[str] = []
parameters: List[Any] = []
where: list[str] = []
parameters: list[Any] = []

if session:
where.append("session = ?")
Expand Down Expand Up @@ -151,7 +143,7 @@ class DatabaseLogHandler(logging.Handler):
def __init__(self, db: Database, *args, **kwargs):
super().__init__(*args, **kwargs)
self.db = db
self._tasks: Set[asyncio.Task] = set()
self._tasks: set[asyncio.Task] = set()

def emit(self, record: logging.LogRecord):
body = json.loads(record.args[0]) # type: ignore
Expand Down
13 changes: 6 additions & 7 deletions lib/lsp-devtools/lsp_devtools/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import attrs

if typing.TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any
from typing import Literal
from typing import Mapping
from typing import Optional

MessageSource = Literal["client", "server"]

Expand All @@ -37,19 +36,19 @@ class LspMessage:
source: MessageSource
"""Indicates if the message was sent by the client or the server."""

id: Optional[str]
id: str | None
"""The ``id`` field, if it exists."""

method: Optional[str]
method: str | None
"""The ``method`` field, if it exists."""

params: Optional[Any] = attrs.field(converter=maybe_json)
params: Any | None = attrs.field(converter=maybe_json)
"""The ``params`` field, if it exists."""

result: Optional[Any] = attrs.field(converter=maybe_json)
result: Any | None = attrs.field(converter=maybe_json)
"""The ``result`` field, if it exists."""

error: Optional[Any] = attrs.field(converter=maybe_json)
error: Any | None = attrs.field(converter=maybe_json)
"""The ``error`` field, if it exists."""

@classmethod
Expand Down
7 changes: 1 addition & 6 deletions lib/lsp-devtools/lsp_devtools/handlers/sql.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import json
import pathlib
import sqlite3
import sys
from contextlib import closing
from importlib import resources

from lsp_devtools.handlers import LspHandler
from lsp_devtools.handlers import LspMessage

if sys.version_info < (3, 9):
import importlib_resources as resources
else:
from importlib import resources # type: ignore[no-redef]


class SqlHandler(LspHandler):
"""A logging handler that sends log records to a SQL database"""
Expand Down
Loading

0 comments on commit f4558ad

Please sign in to comment.