diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0a0c40..0f957f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,7 +45,7 @@ jobs: uses: ./.github/workflows/protobuf - name: Install Hatch - run: pip install hatch + run: pip install hatch==1.7.0 - name: Start the cluster run: ./scripts/quick_start.sh diff --git a/client/client.py b/client/client.py index ff8b662..2437996 100644 --- a/client/client.py +++ b/client/client.py @@ -7,6 +7,7 @@ from client.lease import LeaseClient, LeaseIdGenerator from client.watch import WatchClient from client.auth import AuthClient +from client.maintenance import MaintenanceClient class Client: @@ -18,18 +19,23 @@ class Client: lease_client: Lease client watch_client: Watch client auth_client: Auth client + maintenance_client: Maintenance client """ kv_client: KvClient lease_client: LeaseClient watch_client: WatchClient auth_client: AuthClient + maintenance_client: MaintenanceClient - def __init__(self, kv: KvClient, lease: LeaseClient, watch: WatchClient, auth: AuthClient) -> None: + def __init__( + self, kv: KvClient, lease: LeaseClient, watch: WatchClient, auth: AuthClient, maintenance: MaintenanceClient + ) -> None: self.kv_client = kv self.lease_client = lease self.watch_client = watch self.auth_client = auth + self.maintenance_client = maintenance @classmethod async def connect(cls, addrs: list[str]) -> Client: @@ -46,5 +52,6 @@ async def connect(cls, addrs: list[str]) -> Client: lease_client = LeaseClient("client", protocol_client, channel, "", id_gen) watch_client = WatchClient(channel) auth_client = AuthClient("client", protocol_client, channel, "") + maintenance_client = MaintenanceClient(channel) - return cls(kv_client, lease_client, watch_client, auth_client) + return cls(kv_client, lease_client, watch_client, auth_client, maintenance_client) diff --git a/client/maintenance.py b/client/maintenance.py new file mode 100644 index 0000000..c4fd180 --- /dev/null +++ b/client/maintenance.py @@ -0,0 +1,25 @@ +"""Maintenance Client""" + +from typing import AsyncIterable +from grpc import Channel +from api.xline.rpc_pb2_grpc import MaintenanceStub +from api.xline.rpc_pb2 import SnapshotRequest, SnapshotResponse + + +class MaintenanceClient: + """ + Client for Maintenance operations. + + Attributes: + maintenance_client: The client running the Maintenance protocol, communicate with all servers. + """ + + maintenance_client: MaintenanceStub + + def __init__(self, channel: Channel) -> None: + self.maintenance_client = MaintenanceStub(channel) + + async def snapshot(self) -> AsyncIterable[SnapshotResponse]: + """Gets a snapshot over a stream""" + res = self.maintenance_client.Snapshot(SnapshotRequest()) + return res diff --git a/pyproject.toml b/pyproject.toml index 60aef56..36df693 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,8 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "grpcio", - "grpcio-tools", + "grpcio==1.58.0", + "grpcio-tools==1.58.0", "pytest-asyncio", "passlib", ] @@ -167,3 +167,6 @@ exclude_lines = [ "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + +[tool.hatch.build.targets.wheel] +packages = ["client"] diff --git a/tests/maintenance_test.py b/tests/maintenance_test.py new file mode 100644 index 0000000..65eb570 --- /dev/null +++ b/tests/maintenance_test.py @@ -0,0 +1,18 @@ +"""Test for the maintenance client""" + +import pytest +from client import client + + +@pytest.mark.asyncio +async def test_snapshot_should_get_valid_data(): + """ + Snapshot should get valid data + """ + curp_members = ["172.20.0.3:2379", "172.20.0.4:2379", "172.20.0.5:2379"] + cli = await client.Client.connect(curp_members) + maintenance_client = cli.maintenance_client + + res = await maintenance_client.snapshot() + async for snapshot in res: + assert snapshot.blob != b""