diff --git a/notebooks/experimental/enclaves/V2/01-connect-domains.ipynb b/notebooks/experimental/enclaves/V2/01-connect-domains.ipynb index 386694c974c..656c2abc900 100644 --- a/notebooks/experimental/enclaves/V2/01-connect-domains.ipynb +++ b/notebooks/experimental/enclaves/V2/01-connect-domains.ipynb @@ -207,7 +207,7 @@ "outputs": [], "source": [ "# syft absolute\n", - "from syft.service.project.project import check_route_reachability" + "from syft.service.network.utils import check_route_reachability" ] }, { @@ -263,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/notebooks/experimental/enclaves/V2/V2-Enclave-Single-Notebook.ipynb b/notebooks/experimental/enclaves/V2/V2-Enclave-Single-Notebook.ipynb index dcdda3a5e4c..1244ecdcb16 100644 --- a/notebooks/experimental/enclaves/V2/V2-Enclave-Single-Notebook.ipynb +++ b/notebooks/experimental/enclaves/V2/V2-Enclave-Single-Notebook.ipynb @@ -16,8 +16,8 @@ "from syft.service.code.user_code import UserCodeStatus\n", "from syft.service.network.routes import HTTPServerRoute\n", "from syft.service.network.server_peer import ServerPeer\n", + "from syft.service.network.utils import check_route_reachability\n", "from syft.service.project.project import ProjectCode\n", - "from syft.service.project.project import check_route_reachability\n", "from syft.service.response import SyftSuccess\n", "from syft.types.uid import UID\n", "\n", @@ -692,7 +692,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb b/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb index 3a3431e2917..471550d7a54 100644 --- a/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb +++ b/notebooks/experimental/enclaves/V3/V3-Enclave-Model-HostingSingle-Notebook.ipynb @@ -24,9 +24,9 @@ "from syft.abstract_server import ServerType\n", "from syft.service.code.user_code import UserCodeStatus\n", "from syft.service.network.routes import HTTPServerRoute\n", - "from syft.service.network.server_peer import ServerPeer\n", + "from syft.service.network.utils import check_route_reachability\n", + "from syft.service.network.utils import exchange_routes\n", "from syft.service.project.project import ProjectCode\n", - "from syft.service.project.project import check_route_reachability\n", "from syft.service.response import SyftSuccess\n", "from syft.types.uid import UID" ] @@ -530,60 +530,6 @@ "## Create Association Requests" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "29", - "metadata": {}, - "outputs": [], - "source": [ - "canada_server_peer = ServerPeer.from_client(ds_canada_client)\n", - "canada_server_peer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": {}, - "outputs": [], - "source": [ - "italy_server_peer = ServerPeer.from_client(ds_italy_client)\n", - "italy_server_peer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31", - "metadata": {}, - "outputs": [], - "source": [ - "canada_conn_req = ds_canada_client.api.services.network.add_peer(italy_server_peer)\n", - "canada_conn_req" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [ - "italy_conn_req = ds_italy_client.api.services.network.add_peer(canada_server_peer)\n", - "italy_conn_req" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [ - "do_canada_client.requests[-1].approve()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -591,7 +537,7 @@ "metadata": {}, "outputs": [], "source": [ - "do_italy_client.requests[-1].approve()" + "exchange_routes(clients=[do_canada_client, do_italy_client], auto_approve=True)" ] }, { @@ -959,7 +905,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/notebooks/scenarios/enclave/03-secondary-datasite-setup.ipynb b/notebooks/scenarios/enclave/03-secondary-datasite-setup.ipynb index 1e93bd2eb16..0292f1d622a 100644 --- a/notebooks/scenarios/enclave/03-secondary-datasite-setup.ipynb +++ b/notebooks/scenarios/enclave/03-secondary-datasite-setup.ipynb @@ -7692,7 +7692,7 @@ "outputs": [], "source": [ "# syft absolute\n", - "from syft.service.project.project import check_route_reachability" + "from syft.service.network.utils import check_route_reachability" ] }, { diff --git a/notebooks/scenarios/enclave/04-data-scientist-join.ipynb b/notebooks/scenarios/enclave/04-data-scientist-join.ipynb index 2b080c82957..9a3abe94e8e 100644 --- a/notebooks/scenarios/enclave/04-data-scientist-join.ipynb +++ b/notebooks/scenarios/enclave/04-data-scientist-join.ipynb @@ -351,7 +351,7 @@ ], "source": [ "# syft absolute\n", - "from syft.service.project.project import check_route_reachability\n", + "from syft.service.network.utils import check_route_reachability\n", "\n", "check_route_reachability([model_owner_ds_client, model_auditor_ds_client])" ] diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 6392b912136..eb558ff122d 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -62,6 +62,8 @@ from .service.model.model import CreateModelAsset as ModelAsset from .service.model.model import SyftModelClass from .service.model.model import syft_model +from .service.network.utils import check_route_reachability # noqa: F401 +from .service.network.utils import exchange_routes # noqa: F401 from .service.notification.notifications import NotificationStatus from .service.policy.policy import CreatePolicyRuleConstant as Constant from .service.policy.policy import CustomInputPolicy diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index c5b9e0c084e..3a53cf32284 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -1,14 +1,20 @@ # stdlib +from enum import Enum +import itertools import logging import threading import time from typing import cast # relative +from ...client.client import SyftClient from ...serde.serializable import serializable from ...types.datetime import DateTime from ..context import AuthedServiceContext +from ..request.request import Request from ..response import SyftError +from ..response import SyftSuccess +from ..user.user_roles import ServiceRole from .network_service import NetworkService from .network_service import ServerPeerAssociationStatus from .server_peer import ServerPeer @@ -130,3 +136,85 @@ def stop(self) -> None: self.thread = None self.started_time = None logger.info("Peer health check task stopped.") + + +def exchange_routes( + clients: list[SyftClient], auto_approve: bool = False +) -> SyftSuccess | SyftError: + """Exchange routes between a list of clients.""" + if auto_approve: + # Check that all clients are admin clients + for client in clients: + if not client.user_role == ServiceRole.ADMIN: + return SyftError( + message=f"Client {client} is not an admin client. " + "Only admin clients can auto-approve connection requests." + ) + + for client1, client2 in itertools.combinations(clients, 2): + peer1 = ServerPeer.from_client(client1) + peer2 = ServerPeer.from_client(client2) + + client1_connection_request = client1.api.services.network.add_peer(peer2) + if isinstance(client1_connection_request, SyftError): + return SyftError( + message=f"Failed to add peer {peer2} to {client1}: {client1_connection_request}" + ) + + client2_connection_request = client2.api.services.network.add_peer(peer1) + if isinstance(client2_connection_request, SyftError): + return SyftError( + message=f"Failed to add peer {peer1} to {client2}: {client2_connection_request}" + ) + + if auto_approve: + if isinstance(client1_connection_request, Request): + res1 = client1_connection_request.approve() + if isinstance(res1, SyftError): + return SyftError( + message=f"Failed to approve connection request between {client1} and {client2}: {res1}" + ) + if isinstance(client2_connection_request, Request): + res2 = client2_connection_request.approve() + if isinstance(res2, SyftError): + return SyftError( + message=f"Failed to approve connection request between {client2} and {client1}: {res2}" + ) + logger.info(f"Exchanged routes between {client1} and {client2}") + else: + logger.info(f"Connection requests sent between {client1} and {client2}.") + + return SyftSuccess(message="Routes exchanged successfully.") + + +class NetworkTopology(Enum): + STAR = "STAR" + MESH = "MESH" + HYBRID = "HYBRID" + + +def check_route_reachability( + clients: list[SyftClient], topology: NetworkTopology = NetworkTopology.MESH +) -> SyftSuccess | SyftError: + if topology == NetworkTopology.STAR: + return SyftError(message="STAR topology is not supported yet") + elif topology == NetworkTopology.MESH: + return check_mesh_topology(clients) + else: + return SyftError(message=f"Invalid topology: {topology}") + + +def check_mesh_topology(clients: list[SyftClient]) -> SyftSuccess | SyftError: + for client in clients: + for other_client in clients: + if client == other_client: + continue + result = client.api.services.network.ping_peer( + verify_key=other_client.root_verify_key + ) + if isinstance(result, SyftError): + return SyftError( + message=f"{client.name}-<{client.id}> - cannot reach" + + f"{other_client.name}-<{other_client.id} - {result.message}" + ) + return SyftSuccess(message="All clients are reachable") diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index a52647b19f5..41ed369825c 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -5,7 +5,6 @@ from collections.abc import Callable from collections.abc import Iterable import copy -from enum import Enum import hashlib import textwrap import time @@ -56,6 +55,7 @@ from ..network.network_service import ServerPeer from ..network.routes import ServerRoute from ..network.routes import connection_to_route +from ..network.utils import check_route_reachability from ..request.request import Request from ..request.request import RequestStatus from ..response import SyftError @@ -1707,36 +1707,3 @@ def create_project_event_hash(project_event: ProjectEvent) -> tuple[bytes, str]: hash_object(project_event.creator_verify_key)[1], ] ) - - -class NetworkTopology(Enum): - STAR = "STAR" - MESH = "MESH" - HYBRID = "HYBRID" - - -def check_route_reachability( - clients: list[SyftClient], topology: NetworkTopology = NetworkTopology.MESH -) -> SyftSuccess | SyftError: - if topology == NetworkTopology.STAR: - return SyftError("STAR topology is not supported yet") - elif topology == NetworkTopology.MESH: - return check_mesh_topology(clients) - else: - return SyftError(message=f"Invalid topology: {topology}") - - -def check_mesh_topology(clients: list[SyftClient]) -> SyftSuccess | SyftError: - for client in clients: - for other_client in clients: - if client == other_client: - continue - result = client.api.services.network.ping_peer( - verify_key=other_client.root_verify_key - ) - if isinstance(result, SyftError): - return SyftError( - message=f"{client.name}-<{client.id}> - cannot reach" - + f"{other_client.name}-<{other_client.id} - {result.message}" - ) - return SyftSuccess(message="All clients are reachable")