From b25360a7e552ca446d503856e8d059762fe17e40 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 09:16:03 -0500 Subject: [PATCH 1/6] nixos/test-driver: apply ruff check suggestions --- nixos/lib/test-driver/test_driver/driver.py | 21 +++++----- nixos/lib/test-driver/test_driver/logger.py | 43 ++++++++++---------- nixos/lib/test-driver/test_driver/machine.py | 29 ++++++------- 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index 0f01bd6d0ab49..6f37af954bc52 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -3,9 +3,10 @@ import signal import tempfile import threading -from contextlib import contextmanager +from collections.abc import Iterator +from contextlib import AbstractContextManager, contextmanager from pathlib import Path -from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union +from typing import Any, Callable, Optional, Union from colorama import Fore, Style @@ -44,17 +45,17 @@ class Driver: and runs the tests""" tests: str - vlans: List[VLan] - machines: List[Machine] - polling_conditions: List[PollingCondition] + vlans: list[VLan] + machines: list[Machine] + polling_conditions: list[PollingCondition] global_timeout: int race_timer: threading.Timer logger: AbstractLogger def __init__( self, - start_scripts: List[str], - vlans: List[int], + start_scripts: list[str], + vlans: list[int], tests: str, out_dir: Path, logger: AbstractLogger, @@ -73,7 +74,7 @@ def __init__( vlans = list(set(vlans)) self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans] - def cmd(scripts: List[str]) -> Iterator[NixStartScript]: + def cmd(scripts: list[str]) -> Iterator[NixStartScript]: for s in scripts: yield NixStartScript(s) @@ -119,7 +120,7 @@ def subtest(self, name: str) -> Iterator[None]: self.logger.error(f'Test "{name}" failed with error: "{e}"') raise e - def test_symbols(self) -> Dict[str, Any]: + def test_symbols(self) -> dict[str, Any]: @contextmanager def subtest(name: str) -> Iterator[None]: return self.subtest(name) @@ -277,7 +278,7 @@ def polling_condition( *, seconds_interval: float = 2.0, description: Optional[str] = None, - ) -> Union[Callable[[Callable], ContextManager], ContextManager]: + ) -> Union[Callable[[Callable], AbstractContextManager], AbstractContextManager]: driver = self class Poll: diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py index 484829254b812..564d39f4f055c 100644 --- a/nixos/lib/test-driver/test_driver/logger.py +++ b/nixos/lib/test-driver/test_driver/logger.py @@ -5,10 +5,11 @@ import time import unicodedata from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import ExitStack, contextmanager from pathlib import Path from queue import Empty, Queue -from typing import Any, Dict, Iterator, List +from typing import Any from xml.sax.saxutils import XMLGenerator from xml.sax.xmlreader import AttributesImpl @@ -18,17 +19,17 @@ class AbstractLogger(ABC): @abstractmethod - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + def log(self, message: str, attributes: dict[str, str] = {}) -> None: pass @abstractmethod @contextmanager - def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]: pass @abstractmethod @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]: pass @abstractmethod @@ -68,11 +69,11 @@ def __init__(self, outfile: Path) -> None: self._print_serial_logs = True atexit.register(self.close) - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + def log(self, message: str, attributes: dict[str, str] = {}) -> None: self.tests[self.currentSubtest].stdout += message + os.linesep @contextmanager - def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]: old_test = self.currentSubtest self.tests.setdefault(name, self.TestCaseState()) self.currentSubtest = name @@ -82,7 +83,7 @@ def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: self.currentSubtest = old_test @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]: self.log(message) yield @@ -123,25 +124,25 @@ def close(self) -> None: class CompositeLogger(AbstractLogger): - def __init__(self, logger_list: List[AbstractLogger]) -> None: + def __init__(self, logger_list: list[AbstractLogger]) -> None: self.logger_list = logger_list def add_logger(self, logger: AbstractLogger) -> None: self.logger_list.append(logger) - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + def log(self, message: str, attributes: dict[str, str] = {}) -> None: for logger in self.logger_list: logger.log(message, attributes) @contextmanager - def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]: with ExitStack() as stack: for logger in self.logger_list: stack.enter_context(logger.subtest(name, attributes)) yield @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]: with ExitStack() as stack: for logger in self.logger_list: stack.enter_context(logger.nested(message, attributes)) @@ -173,7 +174,7 @@ class TerminalLogger(AbstractLogger): def __init__(self) -> None: self._print_serial_logs = True - def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: + def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str: if "machine" in attributes: return f"{attributes['machine']}: {message}" return message @@ -182,16 +183,16 @@ def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: def _eprint(*args: object, **kwargs: Any) -> None: print(*args, file=sys.stderr, **kwargs) - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + def log(self, message: str, attributes: dict[str, str] = {}) -> None: self._eprint(self.maybe_prefix(message, attributes)) @contextmanager - def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]: with self.nested("subtest: " + name, attributes): yield @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]: self._eprint( self.maybe_prefix( Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes @@ -241,12 +242,12 @@ def close(self) -> None: def sanitise(self, message: str) -> str: return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") - def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: + def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str: if "machine" in attributes: return f"{attributes['machine']}: {message}" return message - def log_line(self, message: str, attributes: Dict[str, str]) -> None: + def log_line(self, message: str, attributes: dict[str, str]) -> None: self.xml.startElement("line", attrs=AttributesImpl(attributes)) self.xml.characters(message) self.xml.endElement("line") @@ -260,7 +261,7 @@ def warning(self, *args, **kwargs) -> None: # type: ignore def error(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) - def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + def log(self, message: str, attributes: dict[str, str] = {}) -> None: self.drain_log_queue() self.log_line(message, attributes) @@ -273,7 +274,7 @@ def log_serial(self, message: str, machine: str) -> None: self.enqueue({"msg": message, "machine": machine, "type": "serial"}) - def enqueue(self, item: Dict[str, str]) -> None: + def enqueue(self, item: dict[str, str]) -> None: self.queue.put(item) def drain_log_queue(self) -> None: @@ -287,12 +288,12 @@ def drain_log_queue(self) -> None: pass @contextmanager - def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]: with self.nested("subtest: " + name, attributes): yield @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]: self.xml.startElement("nest", attrs=AttributesImpl({})) self.xml.startElement("head", attrs=AttributesImpl(attributes)) self.xml.characters(message) diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py index 7a602ce6608fa..f4ec494beee23 100644 --- a/nixos/lib/test-driver/test_driver/machine.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -12,10 +12,11 @@ import tempfile import threading import time +from collections.abc import Iterable from contextlib import _GeneratorContextManager, nullcontext from pathlib import Path from queue import Queue -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Optional from test_driver.logger import AbstractLogger @@ -91,7 +92,7 @@ def make_command(args: list) -> str: def _perform_ocr_on_screenshot( screenshot_path: str, model_ids: Iterable[int] -) -> List[str]: +) -> list[str]: if shutil.which("tesseract") is None: raise Exception("OCR requested but enableOCR is false") @@ -260,7 +261,7 @@ class Machine: # Store last serial console lines for use # of wait_for_console_text last_lines: Queue = Queue() - callbacks: List[Callable] + callbacks: list[Callable] def __repr__(self) -> str: return f"" @@ -273,7 +274,7 @@ def __init__( logger: AbstractLogger, name: str = "machine", keep_vm_state: bool = False, - callbacks: Optional[List[Callable]] = None, + callbacks: Optional[list[Callable]] = None, ) -> None: self.out_dir = out_dir self.tmp_dir = tmp_dir @@ -314,7 +315,7 @@ def log(self, msg: str) -> None: def log_serial(self, msg: str) -> None: self.logger.log_serial(msg, self.name) - def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: + def nested(self, msg: str, attrs: dict[str, str] = {}) -> _GeneratorContextManager: my_attrs = {"machine": self.name} my_attrs.update(attrs) return self.logger.nested(msg, my_attrs) @@ -373,7 +374,7 @@ def check_active(_: Any) -> bool: ): retry(check_active, timeout) - def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]: + def get_unit_info(self, unit: str, user: Optional[str] = None) -> dict[str, str]: status, lines = self.systemctl(f'--no-pager show "{unit}"', user) if status != 0: raise Exception( @@ -384,7 +385,7 @@ def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str] line_pattern = re.compile(r"^([^=]+)=(.*)$") - def tuple_from_line(line: str) -> Tuple[str, str]: + def tuple_from_line(line: str) -> tuple[str, str]: match = line_pattern.match(line) assert match is not None return match[1], match[2] @@ -424,7 +425,7 @@ def get_unit_property( assert match[1] == property, invalid_output_message return match[2] - def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]: + def systemctl(self, q: str, user: Optional[str] = None) -> tuple[int, str]: """ Runs `systemctl` commands with optional support for `systemctl --user` @@ -481,7 +482,7 @@ def execute( check_return: bool = True, check_output: bool = True, timeout: Optional[int] = 900, - ) -> Tuple[int, str]: + ) -> tuple[int, str]: """ Execute a shell command, returning a list `(status, stdout)`. @@ -798,10 +799,10 @@ def port_is_closed(_: Any) -> bool: with self.nested(f"waiting for TCP port {port} on {addr} to be closed"): retry(port_is_closed, timeout) - def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: + def start_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]: return self.systemctl(f"start {jobname}", user) - def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: + def stop_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]: return self.systemctl(f"stop {jobname}", user) def wait_for_job(self, jobname: str) -> None: @@ -942,13 +943,13 @@ def dump_tty_contents(self, tty: str) -> None: """Debugging: Dump the contents of the TTY""" self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat") - def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]: + def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]: with tempfile.TemporaryDirectory() as tmpdir: screenshot_path = os.path.join(tmpdir, "ppm") self.send_monitor_command(f"screendump {screenshot_path}") return _perform_ocr_on_screenshot(screenshot_path, model_ids) - def get_screen_text_variants(self) -> List[str]: + def get_screen_text_variants(self) -> list[str]: """ Return a list of different interpretations of what is currently visible on the machine's screen using optical character @@ -1168,7 +1169,7 @@ def check_x(_: Any) -> bool: with self.nested("waiting for the X11 server"): retry(check_x, timeout) - def get_window_names(self) -> List[str]: + def get_window_names(self) -> list[str]: return self.succeed( r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'" ).splitlines() From 42d4046e94dcb653eddc8e21556f72af28c207d5 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 10:02:43 -0500 Subject: [PATCH 2/6] nixos/test-driver: format with nixfmt --- nixos/lib/test-driver/default.nix | 61 ++++++++++++++++++------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix index 26652db6016e6..b518a25dab289 100644 --- a/nixos/lib/test-driver/default.nix +++ b/nixos/lib/test-driver/default.nix @@ -1,17 +1,18 @@ -{ lib -, python3Packages -, enableOCR ? false -, qemu_pkg ? qemu_test -, coreutils -, imagemagick_light -, netpbm -, qemu_test -, socat -, ruff -, tesseract4 -, vde2 -, extraPythonPackages ? (_ : []) -, nixosTests +{ + lib, + python3Packages, + enableOCR ? false, + qemu_pkg ? qemu_test, + coreutils, + imagemagick_light, + netpbm, + qemu_test, + socat, + ruff, + tesseract4, + vde2, + extraPythonPackages ? (_: [ ]), + nixosTests, }: let fs = lib.fileset; @@ -29,17 +30,21 @@ python3Packages.buildPythonApplication { }; pyproject = true; - propagatedBuildInputs = [ - coreutils - netpbm - python3Packages.colorama - python3Packages.junit-xml - python3Packages.ptpython - qemu_pkg - socat - vde2 - ] - ++ (lib.optionals enableOCR [ imagemagick_light tesseract4 ]) + propagatedBuildInputs = + [ + coreutils + netpbm + python3Packages.colorama + python3Packages.junit-xml + python3Packages.ptpython + qemu_pkg + socat + vde2 + ] + ++ (lib.optionals enableOCR [ + imagemagick_light + tesseract4 + ]) ++ extraPythonPackages python3Packages; nativeBuildInputs = [ @@ -51,7 +56,11 @@ python3Packages.buildPythonApplication { }; doCheck = true; - nativeCheckInputs = with python3Packages; [ mypy ruff black ]; + nativeCheckInputs = with python3Packages; [ + mypy + ruff + black + ]; checkPhase = '' echo -e "\x1b[32m## run mypy\x1b[0m" mypy test_driver extract-docstrings.py From ef2d3c542a89f3100ef12f145972b30c74548b0c Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 10:05:19 -0500 Subject: [PATCH 3/6] nixos/test-driver: modernize --- nixos/lib/test-driver/default.nix | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix index b518a25dab289..cf60ebf9b0637 100644 --- a/nixos/lib/test-driver/default.nix +++ b/nixos/lib/test-driver/default.nix @@ -20,6 +20,8 @@ in python3Packages.buildPythonApplication { pname = "nixos-test-driver"; version = "1.1"; + pyproject = true; + src = fs.toSource { root = ./.; fileset = fs.unions [ @@ -28,39 +30,45 @@ python3Packages.buildPythonApplication { ./extract-docstrings.py ]; }; - pyproject = true; + + build-system = with python3Packages; [ + setuptools + ]; + + dependencies = + with python3Packages; + [ + colorama + junit-xml + ptpython + ] + ++ extraPythonPackages python3Packages; propagatedBuildInputs = [ coreutils netpbm - python3Packages.colorama - python3Packages.junit-xml - python3Packages.ptpython qemu_pkg socat vde2 ] - ++ (lib.optionals enableOCR [ + ++ lib.optionals enableOCR [ imagemagick_light tesseract4 - ]) - ++ extraPythonPackages python3Packages; - - nativeBuildInputs = [ - python3Packages.setuptools - ]; + ]; passthru.tests = { inherit (nixosTests.nixos-test-driver) driver-timeout; }; doCheck = true; + nativeCheckInputs = with python3Packages; [ mypy ruff black ]; + checkPhase = '' echo -e "\x1b[32m## run mypy\x1b[0m" mypy test_driver extract-docstrings.py From e23f1733c6fc1b2fecc5e78eb14fa8cb7666daf0 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 10:07:29 -0500 Subject: [PATCH 4/6] nixos/test-driver: use ruff format in place of black --- nixos/lib/test-driver/default.nix | 7 +++---- nixos/lib/test-driver/pyproject.toml | 5 ----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix index cf60ebf9b0637..35471f74949c4 100644 --- a/nixos/lib/test-driver/default.nix +++ b/nixos/lib/test-driver/default.nix @@ -66,15 +66,14 @@ python3Packages.buildPythonApplication { nativeCheckInputs = with python3Packages; [ mypy ruff - black ]; checkPhase = '' echo -e "\x1b[32m## run mypy\x1b[0m" mypy test_driver extract-docstrings.py - echo -e "\x1b[32m## run ruff\x1b[0m" + echo -e "\x1b[32m## run ruff check\x1b[0m" ruff check . - echo -e "\x1b[32m## run black\x1b[0m" - black --check --diff . + echo -e "\x1b[32m## run ruff format\x1b[0m" + ruff format --check --diff . ''; } diff --git a/nixos/lib/test-driver/pyproject.toml b/nixos/lib/test-driver/pyproject.toml index 714139bc1b25c..fe2ce75fd632c 100644 --- a/nixos/lib/test-driver/pyproject.toml +++ b/nixos/lib/test-driver/pyproject.toml @@ -35,11 +35,6 @@ ignore_missing_imports = true module = "junit_xml.*" ignore_missing_imports = true -[tool.black] -line-length = 88 -target-version = ['py39'] -include = '\.pyi?$' - [tool.mypy] warn_redundant_casts = true disallow_untyped_calls = true From 5b5f018586e760f5fd26b7269dadb78fe7095bf4 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 10:09:09 -0500 Subject: [PATCH 5/6] ruff: add nixosTests.nixos-test-driver.busybox to passthru.tests --- pkgs/by-name/ru/ruff/package.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkgs/by-name/ru/ruff/package.nix b/pkgs/by-name/ru/ruff/package.nix index 8811d14822935..2449165caeb0e 100644 --- a/pkgs/by-name/ru/ruff/package.nix +++ b/pkgs/by-name/ru/ruff/package.nix @@ -11,6 +11,7 @@ nix-update-script, versionCheckHook, libiconv, + nixosTests, }: python3Packages.buildPythonPackage rec { @@ -76,6 +77,7 @@ python3Packages.buildPythonPackage rec { passthru = { tests = { inherit ruff-lsp; + nixos-test-driver-busybox = nixosTests.nixos-test-driver.busybox; }; updateScript = nix-update-script { }; }; From 172a35f8ce8d474f4e9e6ef57c369d660fb15ee4 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 22 Nov 2024 10:26:04 -0500 Subject: [PATCH 6/6] nixos/test-driver: target python 3.12 --- nixos/lib/test-driver/pyproject.toml | 1 + nixos/lib/test-driver/test_driver/driver.py | 12 +++--- nixos/lib/test-driver/test_driver/machine.py | 42 +++++++++---------- .../test_driver/polling_condition.py | 8 ++-- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/nixos/lib/test-driver/pyproject.toml b/nixos/lib/test-driver/pyproject.toml index fe2ce75fd632c..ac83eed268d92 100644 --- a/nixos/lib/test-driver/pyproject.toml +++ b/nixos/lib/test-driver/pyproject.toml @@ -17,6 +17,7 @@ find = {} test_driver = ["py.typed"] [tool.ruff] +target-version = "py312" line-length = 88 lint.select = ["E", "F", "I", "U", "N"] diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py index 6f37af954bc52..ca778a576f722 100644 --- a/nixos/lib/test-driver/test_driver/driver.py +++ b/nixos/lib/test-driver/test_driver/driver.py @@ -3,10 +3,10 @@ import signal import tempfile import threading -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import AbstractContextManager, contextmanager from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any from colorama import Fore, Style @@ -208,7 +208,7 @@ def create_machine( self, start_command: str | dict, *, - name: Optional[str] = None, + name: str | None = None, keep_vm_state: bool = False, ) -> Machine: # Legacy args handling @@ -274,11 +274,11 @@ def check_polling_conditions(self) -> None: def polling_condition( self, - fun_: Optional[Callable] = None, + fun_: Callable | None = None, *, seconds_interval: float = 2.0, - description: Optional[str] = None, - ) -> Union[Callable[[Callable], AbstractContextManager], AbstractContextManager]: + description: str | None = None, + ) -> Callable[[Callable], AbstractContextManager] | AbstractContextManager: driver = self class Poll: diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py index f4ec494beee23..c423ad8a3fc07 100644 --- a/nixos/lib/test-driver/test_driver/machine.py +++ b/nixos/lib/test-driver/test_driver/machine.py @@ -12,11 +12,11 @@ import tempfile import threading import time -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import _GeneratorContextManager, nullcontext from pathlib import Path from queue import Queue -from typing import Any, Callable, Optional +from typing import Any from test_driver.logger import AbstractLogger @@ -249,12 +249,12 @@ class Machine: start_command: StartCommand keep_vm_state: bool - process: Optional[subprocess.Popen] - pid: Optional[int] - monitor: Optional[socket.socket] - qmp_client: Optional[QMPSession] - shell: Optional[socket.socket] - serial_thread: Optional[threading.Thread] + process: subprocess.Popen | None + pid: int | None + monitor: socket.socket | None + qmp_client: QMPSession | None + shell: socket.socket | None + serial_thread: threading.Thread | None booted: bool connected: bool @@ -274,7 +274,7 @@ def __init__( logger: AbstractLogger, name: str = "machine", keep_vm_state: bool = False, - callbacks: Optional[list[Callable]] = None, + callbacks: list[Callable] | None = None, ) -> None: self.out_dir = out_dir self.tmp_dir = tmp_dir @@ -344,7 +344,7 @@ def send_monitor_command(self, command: str) -> str: return self.wait_for_monitor_prompt() def wait_for_unit( - self, unit: str, user: Optional[str] = None, timeout: int = 900 + self, unit: str, user: str | None = None, timeout: int = 900 ) -> None: """ Wait for a systemd unit to get into "active" state. @@ -374,7 +374,7 @@ def check_active(_: Any) -> bool: ): retry(check_active, timeout) - def get_unit_info(self, unit: str, user: Optional[str] = None) -> dict[str, str]: + def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]: status, lines = self.systemctl(f'--no-pager show "{unit}"', user) if status != 0: raise Exception( @@ -400,7 +400,7 @@ def get_unit_property( self, unit: str, property: str, - user: Optional[str] = None, + user: str | None = None, ) -> str: status, lines = self.systemctl( f'--no-pager show "{unit}" --property="{property}"', @@ -425,7 +425,7 @@ def get_unit_property( assert match[1] == property, invalid_output_message return match[2] - def systemctl(self, q: str, user: Optional[str] = None) -> tuple[int, str]: + def systemctl(self, q: str, user: str | None = None) -> tuple[int, str]: """ Runs `systemctl` commands with optional support for `systemctl --user` @@ -481,7 +481,7 @@ def execute( command: str, check_return: bool = True, check_output: bool = True, - timeout: Optional[int] = 900, + timeout: int | None = 900, ) -> tuple[int, str]: """ Execute a shell command, returning a list `(status, stdout)`. @@ -549,7 +549,7 @@ def execute( return (rc, output.decode(errors="replace")) - def shell_interact(self, address: Optional[str] = None) -> None: + def shell_interact(self, address: str | None = None) -> None: """ Allows you to directly interact with the guest shell. This should only be used during test development, not in production tests. @@ -596,7 +596,7 @@ def console_interact(self) -> None: break self.send_console(char.decode()) - def succeed(self, *commands: str, timeout: Optional[int] = None) -> str: + def succeed(self, *commands: str, timeout: int | None = None) -> str: """ Execute a shell command, raising an exception if the exit status is not zero, otherwise returning the standard output. Similar to `execute`, @@ -613,7 +613,7 @@ def succeed(self, *commands: str, timeout: Optional[int] = None) -> str: output += out return output - def fail(self, *commands: str, timeout: Optional[int] = None) -> str: + def fail(self, *commands: str, timeout: int | None = None) -> str: """ Like `succeed`, but raising an exception if the command returns a zero status. @@ -725,7 +725,7 @@ def tty_matches(last: bool) -> bool: with self.nested(f"waiting for {regexp} to appear on tty {tty}"): retry(tty_matches, timeout) - def send_chars(self, chars: str, delay: Optional[float] = 0.01) -> None: + def send_chars(self, chars: str, delay: float | None = 0.01) -> None: """ Simulate typing a sequence of characters on the virtual keyboard, e.g., `send_chars("foobar\n")` will type the string `foobar` @@ -799,10 +799,10 @@ def port_is_closed(_: Any) -> bool: with self.nested(f"waiting for TCP port {port} on {addr} to be closed"): retry(port_is_closed, timeout) - def start_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]: + def start_job(self, jobname: str, user: str | None = None) -> tuple[int, str]: return self.systemctl(f"start {jobname}", user) - def stop_job(self, jobname: str, user: Optional[str] = None) -> tuple[int, str]: + def stop_job(self, jobname: str, user: str | None = None) -> tuple[int, str]: return self.systemctl(f"stop {jobname}", user) def wait_for_job(self, jobname: str) -> None: @@ -1029,7 +1029,7 @@ def console_matches(_: Any) -> bool: pass def send_key( - self, key: str, delay: Optional[float] = 0.01, log: Optional[bool] = True + self, key: str, delay: float | None = 0.01, log: bool | None = True ) -> None: """ Simulate pressing keys on the virtual keyboard, e.g., diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py index 1cccaf2c71e74..1a8091cf44719 100644 --- a/nixos/lib/test-driver/test_driver/polling_condition.py +++ b/nixos/lib/test-driver/test_driver/polling_condition.py @@ -1,6 +1,6 @@ import time +from collections.abc import Callable from math import isfinite -from typing import Callable, Optional from test_driver.logger import AbstractLogger @@ -12,7 +12,7 @@ class PollingConditionError(Exception): class PollingCondition: condition: Callable[[], bool] seconds_interval: float - description: Optional[str] + description: str | None logger: AbstractLogger last_called: float @@ -20,10 +20,10 @@ class PollingCondition: def __init__( self, - condition: Callable[[], Optional[bool]], + condition: Callable[[], bool | None], logger: AbstractLogger, seconds_interval: float = 2.0, - description: Optional[str] = None, + description: str | None = None, ): self.condition = condition # type: ignore self.seconds_interval = seconds_interval