diff --git a/Makefile b/Makefile index 0c556e85..757df8a6 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,9 @@ serve-docs: clean-docs: uv run --isolated --all-packages --group docs $(MAKE) -C docs clean +doctest: + uv run --isolated --all-packages --group docs $(MAKE) -C docs doctest + test-%: packages/% uv run --isolated --directory $< pytest @@ -33,7 +36,7 @@ clean-test: sync: uv sync --all-packages --all-extras -test: test-packages +test: test-packages doctest generate: buf generate diff --git a/docs/source/api-reference/adapters/network.md b/docs/source/api-reference/adapters/network.md index 8788b863..0bd8b149 100644 --- a/docs/source/api-reference/adapters/network.md +++ b/docs/source/api-reference/adapters/network.md @@ -41,48 +41,50 @@ export: path: /tmp/test.sock ``` -Forward a remote TCP port to a local TCP port - -```{testcode} -# random port on localhost -with TcpPortforwardAdapter(client.tcp_port) as addr: - print(addr[0], addr[1]) # 127.0.0.1 38406 - -# specific address and port -with TcpPortforwardAdapter(client.tcp_port, local_host="192.0.2.1", local_port=8080) as addr: - print(addr[0], addr[1]) # 192.0.2.1 8080 +### Forward a remote TCP port to a local TCP port + +```{doctest} +>>> # random port on localhost +>>> with TcpPortforwardAdapter(client=client.tcp_port) as addr: +... print(addr[0], addr[1]) +127.0.0.1 ... +>>> +>>> # specific address and port +>>> with TcpPortforwardAdapter(client=client.tcp_port, local_host="127.0.0.2", local_port=8080) as addr: +... print(addr[0], addr[1]) +127.0.0.2 8080 ``` -Forward a remote Unix domain socket to a local socket - -```{testcode} -with UnixPortforwardAdapter(client.unix_socket) as addr: - print(addr) # /tmp/jumpstarter-w30wxu64/socket - -# the type of the remote socket and the local one doesn't have to match -# e.g. forward a remote Unix domain socket to a local TCP port -with TcpPortforwardAdapter(client.unix_socket) as addr: - print(addr[0], addr[1]) # 127.0.0.1 38406 +### Forward a remote Unix domain socket to a local socket + +```{doctest} +>>> with UnixPortforwardAdapter(client=client.unix_socket) as addr: +... print(addr) +/tmp/jumpstarter-.../socket +>>> # the type of the remote socket and the local one doesn't have to match +>>> # e.g. forward a remote Unix domain socket to a local TCP port +>>> with TcpPortforwardAdapter(client=client.unix_socket) as addr: +... print(addr[0], addr[1]) +127.0.0.1 ... ``` Connect to a remote TCP port with a web-based VNC client -```{testcode} -with NovncAdapter(client.tcp_port) as url: - print(url) # https://novnc.com/noVNC/vnc.html?autoconnect=1&reconnect=1&host=127.0.0.1&port=36459 - # open the url in browser to access the VNC client +```{doctest} +>>> with NovncAdapter(client=client.tcp_port) as url: +... print(url) # open the url in browser to access the VNC client +https://novnc.com/noVNC/vnc.html?autoconnect=1&reconnect=1&host=127.0.0.1&port=... ``` Interact with a remote TCP port as if it's a serial console See [pexpect](https://pexpect.readthedocs.io/en/stable/api/fdpexpect.html) for API documentation -```{testcode} -with PexpectAdapter(client.tcp_port) as expect: - expect.expect("localhost login:") - expect.send("root\n") - expect.expect("Password:") - expect.send("secret\n") +```{doctest} +>>> # the server echos all inputs +>>> with PexpectAdapter(client=client.tcp_port) as expect: +... assert expect.send("hello") == 5 # written 5 bytes +... assert expect.expect(["hi", "hello"]) == 1 # found string at index 1 ``` Connect to a remote TCP port with the fabric SSH client @@ -90,6 +92,22 @@ Connect to a remote TCP port with the fabric SSH client See [fabric](https://docs.fabfile.org/en/latest/api/connection.html#fabric.connection.Connection) for API documentation ```{testcode} +:skipif: True with FabricAdapter(client=client.tcp_port, connect_kwargs={"password": "secret"}) as conn: conn.run("uname") ``` + +```{testsetup} * +from jumpstarter_driver_network.adapters import * +from jumpstarter_driver_network.driver import * +from jumpstarter_driver_composite.driver import Composite +from jumpstarter.common.utils import serve + +instance = serve(Composite(children={"tcp_port": EchoNetwork(), "unix_socket": EchoNetwork()})) + +client = instance.__enter__() +``` + +```{testcleanup} * +instance.__exit__(None, None, None) +``` diff --git a/docs/source/api-reference/drivers.md b/docs/source/api-reference/drivers.md index 2f4498f6..91eaace0 100644 --- a/docs/source/api-reference/drivers.md +++ b/docs/source/api-reference/drivers.md @@ -19,16 +19,9 @@ This project is still evolving, so these docs may be incomplete or out-of-date. ``` ## Example -```{testsetup} * -import jumpstarter.common.importlib - -def import_class(class_path, allow, unsafe): - return globals()["ExampleClient"] - -jumpstarter.common.importlib.import_class = import_class -``` - ```{testcode} +from sys import modules +from types import SimpleNamespace from anyio import connect_tcp, sleep from contextlib import asynccontextmanager from collections.abc import Generator @@ -74,6 +67,8 @@ class ExampleClient(DriverClient): def echo_generator(self, message) -> Generator[str, None, None]: yield from self.streamingcall("echo_generator", message) +modules["example"] = SimpleNamespace(ExampleClient=ExampleClient) + with serve(ExampleDriver()) as client: print(client.echo("hello")) assert list(client.echo_generator("hello")) == ["hello"] * 10 diff --git a/docs/source/api-reference/drivers/can.md b/docs/source/api-reference/drivers/can.md index 0a24424b..8632aec4 100644 --- a/docs/source/api-reference/drivers/can.md +++ b/docs/source/api-reference/drivers/can.md @@ -3,11 +3,11 @@ The CAN driver is a driver for using CAN bus connections. ```{eval-rst} -.. autoclass:: jumpstarter_driver_can.client.CanClient +.. autoclass:: jumpstarter_driver_can.client.CanClient() :members: ``` ```{eval-rst} -.. autoclass:: jumpstarter_driver_can.client.IsoTpClient +.. autoclass:: jumpstarter_driver_can.client.IsoTpClient() :members: -``` \ No newline at end of file +``` diff --git a/docs/source/api-reference/drivers/index.md b/docs/source/api-reference/drivers/index.md index 50940776..63829bae 100644 --- a/docs/source/api-reference/drivers/index.md +++ b/docs/source/api-reference/drivers/index.md @@ -11,6 +11,6 @@ can.md pyserial.md sdwire.md snmp.md +tftp.md ustreamer.md yepkit.md -``` diff --git a/docs/source/api-reference/drivers/pyserial.md b/docs/source/api-reference/drivers/pyserial.md index 0ad32c10..56ce5961 100644 --- a/docs/source/api-reference/drivers/pyserial.md +++ b/docs/source/api-reference/drivers/pyserial.md @@ -24,7 +24,7 @@ export: ## PySerialClient API ```{eval-rst} -.. autoclass:: jumpstarter_driver_pyserial.client.PySerialClient +.. autoclass:: jumpstarter_driver_pyserial.client.PySerialClient() :members: pexpect, open, stream, open_stream, close ``` @@ -41,20 +41,32 @@ Using expect without a context manager session = pyserialclient.open() session.sendline("Hello, world!") session.expect("Hello, world!") -session.close() +pyserialclient.close() ``` Using a simple BlockingStream with a context manager ```{testcode} with pyserialclient.stream() as stream: - stream.write(b"Hello, world!") - data = stream.read(13) + stream.send(b"Hello, world!") + data = stream.receive() ``` Using a simple BlockingStream without a context manager ```{testcode} stream = pyserialclient.open_stream() -stream.write(b"Hello, world!") -data = stream.read(13) -stream.close() +stream.send(b"Hello, world!") +data = stream.receive() +``` + +```{testsetup} * +from jumpstarter_driver_pyserial.driver import PySerial +from jumpstarter.common.utils import serve + +instance = serve(PySerial(url="loop://")) + +pyserialclient = instance.__enter__() +``` + +```{testcleanup} * +instance.__exit__(None, None, None) ``` diff --git a/docs/source/api-reference/drivers/sdwire.md b/docs/source/api-reference/drivers/sdwire.md index 1c0da174..85b1f5ae 100644 --- a/docs/source/api-reference/drivers/sdwire.md +++ b/docs/source/api-reference/drivers/sdwire.md @@ -4,10 +4,27 @@ The SDWire driver is an storgate multiplexer driver for using the SDWire multiplexer. This device multiplexes an SD card between the DUT and the exporter host. +## Driver Configuration + +```{literalinclude} sdwire.yaml +:language: yaml +``` + +```{doctest} +:hide: +>>> from jumpstarter.config import ExporterConfigV1Alpha1DriverInstance +>>> ExporterConfigV1Alpha1DriverInstance.from_path("source/api-reference/drivers/sdwire.yaml").instantiate() +Traceback (most recent call last): +... +FileNotFoundError: failed to find sd-wire device +``` + +## Client API + The SDWire driver implements the `StorageMuxClient` class, which is a generic storage class. ```{eval-rst} -.. autoclass:: jumpstarter_driver_opendal.client.StorageMuxClient +.. autoclass:: jumpstarter_driver_opendal.client.StorageMuxClient() :members: ``` diff --git a/docs/source/api-reference/drivers/sdwire.yaml b/docs/source/api-reference/drivers/sdwire.yaml new file mode 100644 index 00000000..a966b538 --- /dev/null +++ b/docs/source/api-reference/drivers/sdwire.yaml @@ -0,0 +1,8 @@ +type: "jumpstarter_driver_sdwire.driver.SDWire" +config: + # optional serial number of the sd-wire device + # the first one found would be used if unset + serial: "sdw-00001" + # optional path to the block device exposed by sd-wire + # automatically detected if unset + storage_device: "/dev/disk/by-diskseq/1" diff --git a/docs/source/api-reference/drivers/tftp.md b/docs/source/api-reference/drivers/tftp.md new file mode 100644 index 00000000..68cc8c26 --- /dev/null +++ b/docs/source/api-reference/drivers/tftp.md @@ -0,0 +1,89 @@ +# TFTP Driver + +**driver**: `jumpstarter_driver_tftp.driver.Tftp` + +The TFTP driver provides a read-only TFTP server that can be used to serve files. + +## Driver Configuration +```yaml +export: + tftp: + type: jumpstarter_driver_tftp.driver.Tftp + config: + root_dir: /var/lib/tftpboot # Directory to serve files from + host: 192.168.1.100 # Host IP to bind to (optional) + port: 69 # Port to listen on (optional) +``` + +### Config parameters + +| Parameter | Description | Type | Required | Default | +|-----------|-------------|------|----------|---------| +| root_dir | Root directory for the TFTP server | str | no | "/var/lib/tftpboot" | +| host | IP address to bind the server to | str | no | auto-detect | +| port | Port number to listen on | int | no | 69 | + +## TftpServerClient API + +```{eval-rst} +.. autoclass:: jumpstarter_driver_tftp.client.TftpServerClient() + :members: + :show-inheritance: +``` + +## Exception Classes + +```{eval-rst} +.. autoclass:: jumpstarter_driver_tftp.driver.TftpError + :members: + :show-inheritance: + +.. autoclass:: jumpstarter_driver_tftp.driver.ServerNotRunning + :members: + :show-inheritance: + +.. autoclass:: jumpstarter_driver_tftp.driver.FileNotFound + :members: + :show-inheritance: +``` + +## Examples + +```{doctest} +>>> import tempfile +>>> import os +>>> from jumpstarter_driver_tftp.driver import Tftp +>>> with tempfile.TemporaryDirectory() as tmp_dir: +... # Create a test file +... test_file = os.path.join(tmp_dir, "test.txt") +... with open(test_file, "w") as f: +... _ = f.write("hello") +... +... # Start TFTP server +... tftp = Tftp(root_dir=tmp_dir, host="127.0.0.1", port=6969) +... tftp.start() +... +... # List files +... files = tftp.list_files() +... assert "test.txt" in files +... +... tftp.stop() +``` + +```{testsetup} * +import tempfile +import os +from jumpstarter_driver_tftp.driver import Tftp +from jumpstarter.common.utils import serve + +# Create a persistent temp dir that won't be removed by the example +TEST_DIR = tempfile.mkdtemp(prefix='tftp-test-') +instance = serve(Tftp(root_dir=TEST_DIR, host="127.0.0.1")) +client = instance.__enter__() +``` + +```{testcleanup} * +instance.__exit__(None, None, None) +import shutil +shutil.rmtree(TEST_DIR, ignore_errors=True) +``` diff --git a/docs/source/api-reference/drivers/ustreamer.md b/docs/source/api-reference/drivers/ustreamer.md index 01ee79d8..073730cb 100644 --- a/docs/source/api-reference/drivers/ustreamer.md +++ b/docs/source/api-reference/drivers/ustreamer.md @@ -4,7 +4,24 @@ The Ustreamer driver is a driver for using the ustreamer video streaming server driven by the jumpstarter exporter. This driver takes a video device and exposes both snapshot and streaming interfaces. +## Driver configuration + +```{literalinclude} ustreamer.yaml +:language: yaml +``` + +```{doctest} +:hide: +>>> from jumpstarter.config import ExporterConfigV1Alpha1DriverInstance +>>> ExporterConfigV1Alpha1DriverInstance.from_path("source/api-reference/drivers/ustreamer.yaml").instantiate() +Traceback (most recent call last): +... +io.UnsupportedOperation: fileno +``` + +## Client API + ```{eval-rst} -.. autoclass:: jumpstarter_driver_ustreamer.client.UStreamerClient +.. autoclass:: jumpstarter_driver_ustreamer.client.UStreamerClient() :members: ``` diff --git a/docs/source/api-reference/drivers/ustreamer.yaml b/docs/source/api-reference/drivers/ustreamer.yaml new file mode 100644 index 00000000..a3230a38 --- /dev/null +++ b/docs/source/api-reference/drivers/ustreamer.yaml @@ -0,0 +1,8 @@ +type: "jumpstarter_driver_ustreamer.driver.UStreamer" +config: + # name or path of the ustreamer executable + # defaults to finding ustreamer from path + executable: "ustreamer" + args: # extra arguments to pass to ustreamer + brightness: auto # --brightness=auto + contrast: default # --contract=default diff --git a/docs/source/api-reference/drivers/yepkit.md b/docs/source/api-reference/drivers/yepkit.md index 2d42f34b..b8c03219 100644 --- a/docs/source/api-reference/drivers/yepkit.md +++ b/docs/source/api-reference/drivers/yepkit.md @@ -36,13 +36,14 @@ export: The yepkit ykush driver provides a `PowerClient` with the following API: ```{eval-rst} -.. autoclass:: jumpstarter_driver_power.client.PowerClient +.. autoclass:: jumpstarter_driver_power.client.PowerClient() :members: on, off ``` ### Examples Powering on and off a device ```{testcode} +:skipif: True client.power.on() time.sleep(1) client.power.off() diff --git a/docs/source/getting-started/setup-local-exporter.md b/docs/source/getting-started/setup-local-exporter.md index 2dd41b37..a805e425 100644 --- a/docs/source/getting-started/setup-local-exporter.md +++ b/docs/source/getting-started/setup-local-exporter.md @@ -164,10 +164,10 @@ from jumpstarter_testing.pytest import JumpstarterTest class MyTest(JumpstarterTest): def test_power_on(self, client): - assert client.power.on() == "ok" + client.power.on() def test_power_off(self, client): - assert client.power.off() == "ok" + client.power.off() ``` ```shell diff --git a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py index 20ec6acf..35d8de22 100644 --- a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py +++ b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py @@ -1,23 +1,21 @@ -from jumpstarter_driver_composite.driver import Composite +from jumpstarter_driver_power.driver import MockPower -from jumpstarter.driver import Driver +from .driver import Composite +from jumpstarter.common.utils import serve - -def test_composite_basic(): - class SimpleDriver(Driver): - @classmethod - def client(cls) -> str: - return "test.client.SimpleClient" - - child1 = SimpleDriver() - child2 = SimpleDriver() - - composite = Composite(children={ - "child1": child1, - "child2": child2 - }) - - assert len(composite.children) == 2 - assert composite.children["child1"] == child1 - assert composite.children["child2"] == child2 +def test_drivers_composite(): + with serve( + Composite( + children={ + "power0": MockPower(), + "composite1": Composite( + children={ + "power1": MockPower(), + }, + ), + }, + ) + ) as client: + client.power0.on() + client.composite1.power1.on() diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 39de02db..4926aec6 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -122,12 +122,12 @@ def close(self): self.off() @export - def on(self): - return self.control("on") + def on(self) -> None: + self.control("on") @export - def off(self): - return self.control("off") + def off(self) -> None: + self.control("off") @export async def read(self) -> AsyncGenerator[PowerReading, None]: diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py index 96f64ffd..9af2b281 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py @@ -8,10 +8,10 @@ class PowerClient(DriverClient): - def on(self): + def on(self) -> None: self.call("on") - def off(self): + def off(self) -> None: self.call("off") def cycle(self, wait: int = 2): @@ -37,13 +37,11 @@ def base(): def on(): """Power on""" self.on() - click.echo("Powered on") @base.command() def off(): """Power off""" self.off() - click.echo("Powered off") @base.command() @click.option('--wait', '-w', default=2, help='Wait time in seconds between off and on') diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py index 1be8a470..cd23c2ba 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py @@ -11,27 +11,23 @@ def client(cls) -> str: return "jumpstarter_driver_power.client.PowerClient" @abstractmethod - async def on(self): ... + async def on(self) -> None: ... @abstractmethod - async def off(self): ... + async def off(self) -> None: ... @abstractmethod async def read(self) -> AsyncGenerator[PowerReading, None]: ... class MockPower(PowerInterface, Driver): - def __init__(self, children=None): - self._power_state = None - super().__init__() - @export - async def on(self): - self._power_state = "on" + async def on(self) -> None: + pass @export - async def off(self): - self._power_state = "off" + async def off(self) -> None: + pass @export async def read(self) -> AsyncGenerator[PowerReading, None]: @@ -40,17 +36,13 @@ async def read(self) -> AsyncGenerator[PowerReading, None]: class SyncMockPower(PowerInterface, Driver): - def __init__(self, children=None): - self._power_state = None - super().__init__() - @export - def on(self): - self._power_state = "on" + def on(self) -> None: + pass @export - def off(self): - self._power_state = "off" + def off(self) -> None: + pass @export def read(self) -> Generator[PowerReading, None]: diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py index 63a2b506..9f766a1f 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py @@ -10,10 +10,7 @@ async def test_driver_mock_power(): driver = MockPower() await driver.on() - assert driver._power_state == "on" - await driver.off() - assert driver._power_state == "off" assert [v async for v in driver.read()] == [ PowerReading(voltage=0.0, current=0.0), @@ -25,10 +22,7 @@ def test_driver_sync_mock_power(): driver = SyncMockPower() driver.on() - assert driver._power_state == "on" - driver.off() - assert driver._power_state == "off" assert list(driver.read()) == [ PowerReading(voltage=0.0, current=0.0), diff --git a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py index cf9419d4..fca7550e 100644 --- a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py +++ b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py @@ -25,6 +25,10 @@ def open(self) -> fdspawn: self._context_manager = self.pexpect() return self._context_manager.__enter__() + def close(self): + if hasattr(self, "_context_manager"): + self._context_manager.__exit__(None, None, None) + @contextmanager def pexpect(self): """ diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py index 3242c78c..e31325d0 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py @@ -5,10 +5,10 @@ @dataclass(kw_only=True) class DigitalOutputClient(DriverClient): - def off(self): + def off(self) -> None: self.call("off") - def on(self): + def on(self) -> None: self.call("on") diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py index f159d234..0805bddf 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py @@ -26,14 +26,14 @@ def close(self): super().close() @export - def off(self): + def off(self) -> None: if not isinstance(self.device, DigitalOutputDevice): self.device.close() self.device = DigitalOutputDevice(pin=self.pin, initial_value=None) self.device.off() @export - def on(self): + def on(self) -> None: if not isinstance(self.device, DigitalOutputDevice): self.device.close() self.device = DigitalOutputDevice(pin=self.pin, initial_value=None) diff --git a/packages/jumpstarter-driver-tftp/examples/tftp_test.py b/packages/jumpstarter-driver-tftp/examples/tftp_test.py index 735fcc14..c5aa221c 100644 --- a/packages/jumpstarter-driver-tftp/examples/tftp_test.py +++ b/packages/jumpstarter-driver-tftp/examples/tftp_test.py @@ -6,6 +6,7 @@ log = logging.getLogger(__name__) + class TestResource(JumpstarterTest): filter_labels = {"board": "rpi4"} diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py index fc318846..e70eab33 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py @@ -1 +1 @@ -CHUNK_SIZE = 1024 * 1024 * 4 # 4MB +CHUNK_SIZE = 1024 * 1024 * 4 # 4MB diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 5ad98809..d90ea677 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -17,22 +17,36 @@ class TftpError(Exception): """Base exception for TFTP server errors""" + pass + class ServerNotRunning(TftpError): """Server is not running""" + pass + class FileNotFound(TftpError): """File not found""" + pass + @dataclass(kw_only=True) class Tftp(Driver): - """TFTP Server driver for Jumpstarter""" + """TFTP Server driver for Jumpstarter + + This driver implements a TFTP read-only server. + + Attributes: + root_dir (str): Root directory for the TFTP server. Defaults to "/var/lib/tftpboot" + host (str): IP address to bind the server to. If empty, will use the default route interface + port (int): Port number to listen on. Defaults to 69 (standard TFTP port) + """ root_dir: str = "/var/lib/tftpboot" - host: str = field(default='') + host: str = field(default="") port: int = 69 server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) @@ -45,7 +59,7 @@ def __post_init__(self): super().__post_init__() os.makedirs(self.root_dir, exist_ok=True) - if self.host == '': + if self.host == "": self.host = self.get_default_ip() def get_default_ip(self): @@ -97,6 +111,14 @@ async def _wait_for_shutdown(self): @export def start(self): + """Start the TFTP server. + + The server will start listening for incoming TFTP requests on the configured + host and port. If the server is already running, a warning will be logged. + + Raises: + TftpError: If the server fails to start or times out during initialization + """ if self.server_thread is not None and self.server_thread.is_alive(): self.logger.warning("TFTP server is already running") return @@ -116,6 +138,11 @@ def start(self): @export def stop(self): + """Stop the TFTP server. + + Initiates a graceful shutdown of the server and waits for all active transfers + to complete. If the server is not running, a warning will be logged. + """ if self.server_thread is None or not self.server_thread.is_alive(): self.logger.warning("stop called - TFTP server is not running") return @@ -131,10 +158,28 @@ def stop(self): @export def list_files(self) -> list[str]: + """List all files available in the TFTP server root directory. + + Returns: + list[str]: A list of filenames present in the root directory + """ return os.listdir(self.root_dir) @export async def put_file(self, filename: str, src_stream, client_checksum: str): + """Upload a file to the TFTP server. + + Args: + filename (str): Name of the file to create + src_stream: Source stream to read the file data from + client_checksum (str): SHA256 checksum of the file for verification + + Returns: + str: The filename that was uploaded + + Raises: + TftpError: If the file upload fails or path validation fails + """ file_path = os.path.join(self.root_dir, filename) try: @@ -152,6 +197,18 @@ async def put_file(self, filename: str, src_stream, client_checksum: str): @export def delete_file(self, filename: str): + """Delete a file from the TFTP server. + + Args: + filename (str): Name of the file to delete + + Returns: + str: The filename that was deleted + + Raises: + FileNotFound: If the specified file does not exist + TftpError: If the deletion operation fails + """ file_path = os.path.join(self.root_dir, filename) if not os.path.exists(file_path): @@ -165,6 +222,15 @@ def delete_file(self, filename: str): @export def check_file_checksum(self, filename: str, client_checksum: str) -> bool: + """Check if a file matches the expected checksum. + + Args: + filename (str): Name of the file to check + client_checksum (str): Expected SHA256 checksum + + Returns: + bool: True if the file exists and matches the checksum, False otherwise + """ file_path = os.path.join(self.root_dir, filename) self.logger.debug(f"checking checksum for file: {filename}") self.logger.debug(f"file path: {file_path}") @@ -181,10 +247,20 @@ def check_file_checksum(self, filename: str, client_checksum: str) -> bool: @export def get_host(self) -> str: + """Get the host address the server is bound to. + + Returns: + str: The IP address or hostname + """ return self.host @export def get_port(self) -> int: + """Get the port number the server is listening on. + + Returns: + int: The port number + """ return self.port def close(self): diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index 3f0f6911..f74e8ec0 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -21,12 +21,14 @@ def temp_dir(): with tempfile.TemporaryDirectory() as tmpdir: yield tmpdir + @pytest.fixture def server(temp_dir): server = Tftp(root_dir=temp_dir, host="127.0.0.1") yield server server.close() + @pytest.mark.anyio async def test_tftp_file_operations(server): filename = "test.txt" @@ -60,17 +62,20 @@ async def send_data(): with pytest.raises(FileNotFound): server.delete_file("nonexistent.txt") + def test_tftp_host_config(temp_dir): custom_host = "192.168.1.1" server = Tftp(root_dir=temp_dir, host=custom_host) assert server.get_host() == custom_host + def test_tftp_root_directory_creation(temp_dir): new_dir = os.path.join(temp_dir, "new_tftp_root") server = Tftp(root_dir=new_dir) assert os.path.exists(new_dir) server.close() + @pytest.mark.anyio async def test_tftp_detect_corrupted_file(server): filename = "corrupted.txt" @@ -86,10 +91,12 @@ async def test_tftp_detect_corrupted_file(server): assert not server.check_file_checksum(filename, client_checksum) + @pytest.fixture def anyio_backend(): return "asyncio" + async def _upload_file(server, filename: str, data: bytes) -> str: send_stream, receive_stream = create_memory_object_stream() resource_uuid = uuid4() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index 37e83a3a..1374df08 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -33,18 +33,19 @@ class TftpServer: TFTP Server that handles read requests (RRQ). """ - def __init__(self, host: str, port: int, root_dir: str, - block_size: int = 512, timeout: float = 5.0, retries: int = 3): + def __init__( + self, host: str, port: int, root_dir: str, block_size: int = 512, timeout: float = 5.0, retries: int = 3 + ): self.host = host self.port = port self.root_dir = pathlib.Path(os.path.abspath(root_dir)) self.block_size = block_size self.timeout = timeout self.retries = retries - self.active_transfers: Set['TftpTransfer'] = set() + self.active_transfers: Set["TftpTransfer"] = set() self.shutdown_event = asyncio.Event() self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional['TftpServerProtocol'] = None + self.protocol: Optional["TftpServerProtocol"] = None self.logger = logging.getLogger(self.__class__.__name__) self.ready_event = asyncio.Event() @@ -52,7 +53,7 @@ def __init__(self, host: str, port: int, root_dir: str, def address(self) -> Optional[Tuple[str, int]]: """Get the server's bound address and port.""" if self.transport: - return self.transport.get_extra_info('socket').getsockname() + return self.transport.get_extra_info("socket").getsockname() return None async def start(self): @@ -61,8 +62,7 @@ async def start(self): self.ready_event.set() self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpServerProtocol(self), - local_addr=(self.host, self.port) + lambda: TftpServerProtocol(self), local_addr=(self.host, self.port) ) try: @@ -92,11 +92,11 @@ async def shutdown(self): self.logger.info("Shutdown signal received for TFTP server") self.shutdown_event.set() - def register_transfer(self, transfer: 'TftpTransfer'): + def register_transfer(self, transfer: "TftpTransfer"): self.active_transfers.add(transfer) self.logger.debug(f"Registered transfer: {transfer}") - def unregister_transfer(self, transfer: 'TftpTransfer'): + def unregister_transfer(self, transfer: "TftpTransfer"): self.active_transfers.discard(transfer) self.logger.debug(f"Unregistered transfer: {transfer}") @@ -130,7 +130,7 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return try: - opcode = Opcode(int.from_bytes(data[0:2], 'big')) + opcode = Opcode(int.from_bytes(data[0:2], "big")) except ValueError: self.logger.error(f"Unknown opcode from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") @@ -166,9 +166,9 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]): def _send_oack(self, addr: Tuple[str, int], options: dict): """Send Option Acknowledgment (OACK) packet.""" - oack_data = Opcode.OACK.to_bytes(2, 'big') + oack_data = Opcode.OACK.to_bytes(2, "big") for opt_name, opt_value in options.items(): - oack_data += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + oack_data += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8") if self.transport: self.transport.sendto(oack_data, addr) @@ -176,39 +176,36 @@ def _send_oack(self, addr: Tuple[str, int], options: dict): def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, 'big') + - error_code.to_bytes(2, 'big') + - message.encode('utf-8') + b'\x00' + Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" ) if self.transport: self.transport.sendto(error_packet, addr) self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}") def _parse_request(self, data: bytes) -> Tuple[str, str, dict]: - parts = data[2:].split(b'\x00') + parts = data[2:].split(b"\x00") if len(parts) < 2: raise ValueError("Invalid RRQ format") - filename = parts[0].decode('utf-8') + filename = parts[0].decode("utf-8") if len(filename) > 255: # RFC 1350 doesn't specify a limit raise ValueError("Filename too long") if not all(c.isprintable() and c not in '<>:"/\\|?*' for c in filename): raise ValueError("Invalid characters in filename") - if '\x00' in filename: + if "\x00" in filename: raise ValueError("Null byte in filename") - mode = parts[1].decode('utf-8').lower() + mode = parts[1].decode("utf-8").lower() options = self._parse_options(parts[2:]) return filename, mode, options - def _parse_options(self, option_parts: list) -> dict: options = {} i = 0 while i < len(option_parts) - 1: try: - opt_name = option_parts[i].decode('utf-8').lower() - opt_value = option_parts[i + 1].decode('utf-8') + opt_name = option_parts[i].decode("utf-8").lower() + opt_value = option_parts[i + 1].decode("utf-8") options[opt_name] = opt_value i += 2 except Exception: @@ -216,7 +213,7 @@ def _parse_options(self, option_parts: list) -> dict: return options def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: - if mode not in ('netascii', 'octet'): + if mode not in ("netascii", "octet"): self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode") return False @@ -248,14 +245,12 @@ def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int: return blksize else: self.logger.warning( - f"Requested block size {blksize} out of range (512-65464), " - f"using default: {self.server.block_size}" + f"Requested block size {blksize} out of range (512-65464), using default: {self.server.block_size}" ) return self.server.block_size except ValueError: self.logger.warning( - f"Invalid block size value '{requested_blksize}', " - f"using default: {self.server.block_size}" + f"Invalid block size value '{requested_blksize}', using default: {self.server.block_size}" ) return self.server.block_size @@ -269,15 +264,11 @@ def _negotiate_timeout(self, requested_timeout: Optional[str]) -> float: return float(timeout) else: self.logger.warning( - f"Timeout value {timeout} out of range (1-255), " - f"using default: {self.server.timeout}" + f"Timeout value {timeout} out of range (1-255), using default: {self.server.timeout}" ) return self.server.timeout except ValueError: - self.logger.warning( - f"Invalid timeout value '{requested_timeout}', " - f"using default: {self.server.timeout}" - ) + self.logger.warning(f"Invalid timeout value '{requested_timeout}', using default: {self.server.timeout}") return self.server.timeout def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: @@ -285,21 +276,21 @@ def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: blksize = self.server.block_size timeout = self.server.timeout - if 'blksize' in options: - requested = options['blksize'] + if "blksize" in options: + requested = options["blksize"] blksize = self._negotiate_block_size(requested) - negotiated['blksize'] = blksize + negotiated["blksize"] = blksize - if 'timeout' in options: - requested = options['timeout'] + if "timeout" in options: + requested = options["timeout"] timeout = self._negotiate_timeout(requested) - negotiated['timeout'] = int(timeout) + negotiated["timeout"] = int(timeout) return negotiated, blksize, timeout - - async def _start_transfer(self, filepath: pathlib.Path, addr: Tuple[str, int], - blksize: int, timeout: float, negotiated_options: dict): + async def _start_transfer( + self, filepath: pathlib.Path, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict + ): transfer = TftpReadTransfer( server=self.server, filepath=filepath, @@ -307,11 +298,12 @@ async def _start_transfer(self, filepath: pathlib.Path, addr: Tuple[str, int], block_size=blksize, timeout=timeout, retries=self.server.retries, - negotiated_options=negotiated_options + negotiated_options=negotiated_options, ) self.server.register_transfer(transfer) asyncio.create_task(transfer.start()) + def is_subpath(path: pathlib.Path, root: pathlib.Path) -> bool: try: path.relative_to(root) @@ -325,8 +317,15 @@ class TftpTransfer: Base class for TFTP transfers. """ - def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], - block_size: int, timeout: float, retries: int): + def __init__( + self, + server: TftpServer, + filepath: pathlib.Path, + client_addr: Tuple[str, int], + block_size: int, + timeout: float, + retries: int, + ): self.server = server self.filepath = filepath self.client_addr = client_addr @@ -334,7 +333,7 @@ def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tupl self.timeout = timeout self.retries = retries self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional['TftpTransferProtocol'] = None + self.protocol: Optional["TftpTransferProtocol"] = None self.cleanup_task: Optional[asyncio.Task] = None self.logger = logging.getLogger(self.__class__.__name__) @@ -352,15 +351,23 @@ async def cleanup(self): class TftpReadTransfer(TftpTransfer): - def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], - block_size: int, timeout: float, retries: int, negotiated_options: Optional[dict] = None): + def __init__( + self, + server: TftpServer, + filepath: pathlib.Path, + client_addr: Tuple[str, int], + block_size: int, + timeout: float, + retries: int, + negotiated_options: Optional[dict] = None, + ): super().__init__( server=server, filepath=filepath, client_addr=client_addr, block_size=block_size, timeout=timeout, - retries=retries + retries=retries, ) self.block_num = 0 self.ack_received = asyncio.Event() @@ -390,17 +397,14 @@ async def _initialize_transfer(self) -> bool: loop = asyncio.get_running_loop() self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpTransferProtocol(self), - local_addr=('0.0.0.0', 0), - remote_addr=self.client_addr + lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr ) - local_addr = self.transport.get_extra_info('sockname') + local_addr = self.transport.get_extra_info("sockname") self.logger.debug(f"Transfer bound to local {local_addr}") # Only send OACK if we have non-default options to negotiate if self.negotiated_options and ( - self.negotiated_options['blksize'] != 512 or - self.negotiated_options['timeout'] != self.server.timeout + self.negotiated_options["blksize"] != 512 or self.negotiated_options["timeout"] != self.server.timeout ): oack_packet = self._create_oack_packet() if not await self._send_with_retries(oack_packet, is_oack=True): @@ -411,7 +415,7 @@ async def _initialize_transfer(self) -> bool: return True async def _perform_transfer(self): - async with aiofiles.open(self.filepath, 'rb') as f: + async with aiofiles.open(self.filepath, "rb") as f: while True: if self.server.shutdown_event.is_set(): self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") @@ -428,7 +432,7 @@ async def _handle_data_block(self, data: bytes) -> bool: """ if not data and self.block_num == 1: # Empty file case - packet = self._create_data_packet(b'') + packet = self._create_data_packet(b"") await self._send_with_retries(packet) return False elif data: @@ -450,7 +454,7 @@ async def _handle_data_block(self, data: bytes) -> bool: return True else: # EOF reached - packet = self._create_data_packet(b'') + packet = self._create_data_packet(b"") success = await self._send_with_retries(packet) if not success: self.logger.error(f"Failed to send final block {self.block_num}") @@ -459,25 +463,21 @@ async def _handle_data_block(self, data: bytes) -> bool: return False def _create_oack_packet(self) -> bytes: - packet = Opcode.OACK.to_bytes(2, 'big') + packet = Opcode.OACK.to_bytes(2, "big") for opt_name, opt_value in self.negotiated_options.items(): - packet += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + packet += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8") return packet def _create_data_packet(self, data: bytes) -> bytes: - return ( - Opcode.DATA.to_bytes(2, 'big') + - self.block_num.to_bytes(2, 'big') + - data - ) + return Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + data def _send_packet(self, packet: bytes): self.transport.sendto(packet) - if packet[0:2] == Opcode.DATA.to_bytes(2, 'big'): - block = int.from_bytes(packet[2:4], 'big') + if packet[0:2] == Opcode.DATA.to_bytes(2, "big"): + block = int.from_bytes(packet[2:4], "big") data_length = len(packet) - 4 self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}") - elif packet[0:2] == Opcode.OACK.to_bytes(2, 'big'): + elif packet[0:2] == Opcode.OACK.to_bytes(2, "big"): self.logger.debug(f"Sent OACK to {self.client_addr}") async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool: @@ -488,8 +488,7 @@ async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool try: self._send_packet(packet) self.logger.debug( - f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, " - f"waiting for ACK (Attempt {attempt})" + f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})" ) self.ack_received.clear() await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout) @@ -524,6 +523,7 @@ def handle_ack(self, block_num: int): else: self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}") + class TftpTransferProtocol(asyncio.DatagramProtocol): """ Protocol for handling ACKs during a TFTP transfer. @@ -535,7 +535,7 @@ def __init__(self, transfer: TftpReadTransfer): def connection_made(self, transport: asyncio.DatagramTransport): self.transfer.transport = transport - local_addr = transport.get_extra_info('sockname') + local_addr = transport.get_extra_info("sockname") self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}") def datagram_received(self, data: bytes, addr: Tuple[str, int]): @@ -549,21 +549,20 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return try: - opcode = Opcode(int.from_bytes(data[0:2], 'big')) + opcode = Opcode(int.from_bytes(data[0:2], "big")) except ValueError: self.logger.error(f"Unknown opcode from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") return if opcode == Opcode.ACK: - block_num = int.from_bytes(data[2:4], 'big') + block_num = int.from_bytes(data[2:4], "big") self.logger.debug(f"Received ACK for block {block_num} from {addr}") self.transfer.handle_ack(block_num) else: self.logger.warning(f"Unexpected opcode {opcode} from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unexpected opcode") - def error_received(self, exc): self.logger.error(f"Error received: {exc}") @@ -572,9 +571,7 @@ def connection_lost(self, exc): def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, 'big') + - error_code.to_bytes(2, 'big') + - message.encode('utf-8') + b'\x00' + Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" ) if self.transfer.transport: self.transfer.transport.sendto(error_packet) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 5242f0c9..679e55cc 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -13,11 +13,7 @@ async def tftp_server(): test_file_path = Path(temp_dir) / "test.txt" test_file_path.write_text("Hello, TFTP!") - server = TftpServer( - host="127.0.0.1", - port=0, - root_dir=temp_dir - ) + server = TftpServer(host="127.0.0.1", port=0, root_dir=temp_dir) server_task = asyncio.create_task(server.start()) for _ in range(10): @@ -42,14 +38,13 @@ async def tftp_server(): except asyncio.CancelledError: pass + async def create_test_client(server_port): loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - asyncio.DatagramProtocol, - remote_addr=('127.0.0.1', 0) - ) + transport, protocol = await loop.create_datagram_endpoint(asyncio.DatagramProtocol, remote_addr=("127.0.0.1", 0)) return transport, protocol + @pytest.mark.asyncio async def test_server_startup_and_shutdown(tftp_server): """Test that server starts up and shuts down cleanly.""" @@ -64,6 +59,7 @@ async def test_server_startup_and_shutdown(tftp_server): assert True + @pytest.mark.asyncio async def test_read_request_for_existing_file(tftp_server): """Test reading an existing file from the server.""" @@ -76,9 +72,9 @@ async def test_read_request_for_existing_file(tftp_server): transport, _ = await create_test_client(server.port) rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + # filename - b'octet\x00' # mode + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" # filename + + b"octet\x00" # mode ) transport.sendto(rrq_packet) @@ -91,6 +87,7 @@ async def test_read_request_for_existing_file(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_read_request_for_nonexistent_file(tftp_server): """Test reading a non-existent file returns appropriate error.""" @@ -101,11 +98,7 @@ async def test_read_request_for_nonexistent_file(tftp_server): try: transport, protocol = await create_test_client(server.port) - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'nonexistent.txt\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"nonexistent.txt\x00" + b"octet\x00" transport.sendto(rrq_packet) assert server.transport is not None @@ -115,20 +108,16 @@ async def test_read_request_for_nonexistent_file(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_write_request_rejection(tftp_server): """Test that write requests are properly rejected (server is read-only).""" server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - try: transport, _ = await create_test_client(server.port) - wrq_packet = ( - Opcode.WRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' - ) + wrq_packet = Opcode.WRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00" transport.sendto(wrq_packet) @@ -139,6 +128,7 @@ async def test_write_request_rejection(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_invalid_packet_handling(tftp_server): server, temp_dir, server_port = tftp_server @@ -147,7 +137,7 @@ async def test_invalid_packet_handling(tftp_server): try: transport, _ = await create_test_client(server.port) - transport.sendto(b'\x00\x01') + transport.sendto(b"\x00\x01") assert server.transport is not None @@ -156,6 +146,7 @@ async def test_invalid_packet_handling(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_path_traversal_prevention(tftp_server): """Test that path traversal attempts are blocked.""" @@ -167,11 +158,7 @@ async def test_path_traversal_prevention(tftp_server): try: transport, _ = await create_test_client(server.port) - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'../../../etc/passwd\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"../../../etc/passwd\x00" + b"octet\x00" transport.sendto(rrq_packet) @@ -182,6 +169,7 @@ async def test_path_traversal_prevention(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_options_negotiation(tftp_server): """Test that options (blksize, timeout) are properly negotiated.""" @@ -194,13 +182,13 @@ async def test_options_negotiation(tftp_server): # RRQ with options rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' + - b'blksize\x00' + - b'1024\x00' + - b'timeout\x00' + - b'3\x00' + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" + + b"octet\x00" + + b"blksize\x00" + + b"1024\x00" + + b"timeout\x00" + + b"3\x00" ) transport.sendto(rrq_packet) @@ -212,6 +200,7 @@ async def test_options_negotiation(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_retry_mechanism(tftp_server): server, _, server_port = tftp_server @@ -234,29 +223,21 @@ def datagram_received(self, data, addr): try: loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - lambda: TestProtocol(), - local_addr=('127.0.0.1', 0) - ) + transport, protocol = await loop.create_datagram_endpoint(lambda: TestProtocol(), local_addr=("127.0.0.1", 0)) assert transport is not None, "Failed to create transport" - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00" - transport.sendto(rrq_packet, ('127.0.0.1', server_port)) + transport.sendto(rrq_packet, ("127.0.0.1", server_port)) await asyncio.sleep(server.timeout * 2) - data_packets = [p for p in protocol.received_packets - if p[0:2] == Opcode.DATA.to_bytes(2, 'big')] + data_packets = [p for p in protocol.received_packets if p[0:2] == Opcode.DATA.to_bytes(2, "big")] assert len(data_packets) > 1, "Server should have retried sending DATA packet" - block_numbers = {int.from_bytes(p[2:4], 'big') for p in data_packets} + block_numbers = {int.from_bytes(p[2:4], "big") for p in data_packets} assert len(block_numbers) == 1, "All retried packets should be for the same block" assert 1 in block_numbers, "First block number should be 1" @@ -278,13 +259,13 @@ async def test_invalid_options_handling(tftp_server): transport, _ = await create_test_client(server.port) rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' + - b'blksize\x00' + - b'invalid\x00' + - b'timeout\x00' + - b'999999\x00' + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" + + b"octet\x00" + + b"blksize\x00" + + b"invalid\x00" + + b"timeout\x00" + + b"999999\x00" ) transport.sendto(rrq_packet) diff --git a/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py b/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py index 8d16820b..875dda0e 100644 --- a/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py +++ b/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py @@ -11,25 +11,11 @@ VID = 0x04D8 PID = 0xF2F7 -PORT_UP_COMMANDS = { - '1': 0x11, - '2': 0x12, - '3': 0x13, - 'all': 0x1A -} - -PORT_DOWN_COMMANDS = { - '1': 0x01, - '2': 0x02, - '3': 0x03, - 'all': 0x0A -} - -PORT_STATUS_COMMANDS = { - '1': 0x21, - '2': 0x22, - '3': 0x23 -} +PORT_UP_COMMANDS = {"1": 0x11, "2": 0x12, "3": 0x13, "all": 0x1A} + +PORT_DOWN_COMMANDS = {"1": 0x01, "2": 0x02, "3": 0x03, "all": 0x0A} + +PORT_STATUS_COMMANDS = {"1": 0x21, "2": 0x22, "3": 0x23} VALID_DEFAULTS = ["on", "off", "keep"] @@ -37,9 +23,11 @@ _USB_DEVS = {} _USB_DEVS_LOCK = threading.Lock() # Lock for synchronizing access, we don't do multithread, but just in case.. + @dataclass(kw_only=True) class Ykush(PowerInterface, Driver): - """ driver for Yepkit Ykush USB Hub with Power control """ + """driver for Yepkit Ykush USB Hub with Power control""" + serial: str | None = field(default=None) default: str = "off" port: str = "all" @@ -52,12 +40,10 @@ def __post_init__(self): keys = PORT_UP_COMMANDS.keys() if self.port not in keys: - raise ValueError( - f"The ykush driver port must be any of the following values: {keys}") + raise ValueError(f"The ykush driver port must be any of the following values: {keys}") if self.default not in VALID_DEFAULTS: - raise ValueError( - f"The ykush driver default must be any of the following values: {VALID_DEFAULTS}") + raise ValueError(f"The ykush driver default must be any of the following values: {VALID_DEFAULTS}") with _USB_DEVS_LOCK: # another instance already claimed this device? @@ -75,8 +61,7 @@ def __post_init__(self): if serial == self.serial or self.serial is None: _USB_DEVS[serial] = dev if self.serial is None: - self.logger.warning( - f"No serial number provided for ykush, using the first one found: {serial}") + self.logger.warning(f"No serial number provided for ykush, using the first one found: {serial}") self.serial = serial self.dev = dev return @@ -86,7 +71,7 @@ def __post_init__(self): def _send_cmd(self, cmd, report_size=64): out_ep, in_ep = self._get_endpoints(self.dev) out_buf = [0x00] * report_size - out_buf[0] = cmd # YKUSH command + out_buf[0] = cmd # YKUSH command # Write to the OUT endpoint out_ep.write(out_buf) @@ -103,15 +88,11 @@ def _get_endpoints(self, dev): interface = cfg[(0, 0)] out_endpoint = usb.util.find_descriptor( - interface, - custom_match=lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT + interface, custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT ) in_endpoint = usb.util.find_descriptor( - interface, - custom_match=lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN + interface, custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN ) if not out_endpoint or not in_endpoint: @@ -127,18 +108,16 @@ def reset(self): self.off() @export - def on(self): + def on(self) -> None: self.logger.info(f"Power ON for Ykush {self.serial} on port {self.port}") cmd = PORT_UP_COMMANDS.get(self.port) _ = self._send_cmd(cmd) - return @export - def off(self): + def off(self) -> None: self.logger.info(f"Power OFF for Ykush {self.serial} on port {self.port}") cmd = PORT_DOWN_COMMANDS.get(self.port) _ = self._send_cmd(cmd) - return @export def read(self) -> AsyncGenerator[PowerReading, None]: diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index ac34fd45..3d7b1f35 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -29,6 +29,11 @@ def instantiate(self) -> Driver: return driver_class(children=children, **self.config) + @classmethod + def from_path(cls, path: str) -> ExporterConfigV1Alpha1DriverInstance: + with open(path) as f: + return cls.model_validate(yaml.safe_load(f)) + class ExporterConfigV1Alpha1(BaseModel): BASE_PATH: ClassVar[Path] = Path("/etc/jumpstarter/exporters") diff --git a/packages/jumpstarter/jumpstarter/listener_test.py b/packages/jumpstarter/jumpstarter/listener_test.py index d2025489..81bdf813 100644 --- a/packages/jumpstarter/jumpstarter/listener_test.py +++ b/packages/jumpstarter/jumpstarter/listener_test.py @@ -49,7 +49,7 @@ async def handle_async(stream): monkeypatch.setattr(lease, "handle_async", handle_async) async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") tg.cancel_scope.cancel() @@ -97,12 +97,12 @@ async def test_controller(mock_controller): unsafe=True, ) as lease: async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") # test concurrent connections async with lease.connect_async() as client2: - assert await client2.call_async("on") == "ok" + await client2.call_async("on") async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") tg.cancel_scope.cancel()