Skip to content

Commit

Permalink
Add util functions to exchange peer routes and check reachability
Browse files Browse the repository at this point in the history
  • Loading branch information
itstauq committed Aug 3, 2024
1 parent e9596c5 commit 535c3cb
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 98 deletions.
4 changes: 2 additions & 2 deletions notebooks/experimental/enclaves/V2/01-connect-domains.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
"outputs": [],
"source": [
"# syft absolute\n",
"from syft.service.project.project import check_route_reachability"
"from syft.service.network.utils import check_route_reachability"
]
},
{
Expand Down Expand Up @@ -263,7 +263,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"from syft.service.code.user_code import UserCodeStatus\n",
"from syft.service.network.routes import HTTPServerRoute\n",
"from syft.service.network.server_peer import ServerPeer\n",
"from syft.service.network.utils import check_route_reachability\n",
"from syft.service.project.project import ProjectCode\n",
"from syft.service.project.project import check_route_reachability\n",
"from syft.service.response import SyftSuccess\n",
"from syft.types.uid import UID\n",
"\n",
Expand Down Expand Up @@ -692,7 +692,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
"from syft.abstract_server import ServerType\n",
"from syft.service.code.user_code import UserCodeStatus\n",
"from syft.service.network.routes import HTTPServerRoute\n",
"from syft.service.network.server_peer import ServerPeer\n",
"from syft.service.network.utils import check_route_reachability\n",
"from syft.service.network.utils import exchange_routes\n",
"from syft.service.project.project import ProjectCode\n",
"from syft.service.project.project import check_route_reachability\n",
"from syft.service.response import SyftSuccess\n",
"from syft.types.uid import UID"
]
Expand Down Expand Up @@ -530,68 +530,14 @@
"## Create Association Requests"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29",
"metadata": {},
"outputs": [],
"source": [
"canada_server_peer = ServerPeer.from_client(ds_canada_client)\n",
"canada_server_peer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30",
"metadata": {},
"outputs": [],
"source": [
"italy_server_peer = ServerPeer.from_client(ds_italy_client)\n",
"italy_server_peer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31",
"metadata": {},
"outputs": [],
"source": [
"canada_conn_req = ds_canada_client.api.services.network.add_peer(italy_server_peer)\n",
"canada_conn_req"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32",
"metadata": {},
"outputs": [],
"source": [
"italy_conn_req = ds_italy_client.api.services.network.add_peer(canada_server_peer)\n",
"italy_conn_req"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33",
"metadata": {},
"outputs": [],
"source": [
"do_canada_client.requests[-1].approve()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34",
"metadata": {},
"outputs": [],
"source": [
"do_italy_client.requests[-1].approve()"
"exchange_routes(clients=[do_canada_client, do_italy_client], auto_approve=True)"
]
},
{
Expand Down Expand Up @@ -959,7 +905,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7692,7 +7692,7 @@
"outputs": [],
"source": [
"# syft absolute\n",
"from syft.service.project.project import check_route_reachability"
"from syft.service.network.utils import check_route_reachability"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/scenarios/enclave/04-data-scientist-join.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@
],
"source": [
"# syft absolute\n",
"from syft.service.project.project import check_route_reachability\n",
"from syft.service.network.utils import check_route_reachability\n",
"\n",
"check_route_reachability([model_owner_ds_client, model_auditor_ds_client])"
]
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from .service.model.model import CreateModelAsset as ModelAsset
from .service.model.model import SyftModelClass
from .service.model.model import syft_model
from .service.network.utils import check_route_reachability # noqa: F401
from .service.network.utils import exchange_routes # noqa: F401
from .service.notification.notifications import NotificationStatus
from .service.policy.policy import CreatePolicyRuleConstant as Constant
from .service.policy.policy import CustomInputPolicy
Expand Down
88 changes: 88 additions & 0 deletions packages/syft/src/syft/service/network/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# stdlib
from enum import Enum
import itertools
import logging
import threading
import time
from typing import cast

# relative
from ...client.client import SyftClient
from ...serde.serializable import serializable
from ...types.datetime import DateTime
from ..context import AuthedServiceContext
from ..request.request import Request
from ..response import SyftError
from ..response import SyftSuccess
from ..user.user_roles import ServiceRole
from .network_service import NetworkService
from .network_service import ServerPeerAssociationStatus
from .server_peer import ServerPeer
Expand Down Expand Up @@ -130,3 +136,85 @@ def stop(self) -> None:
self.thread = None
self.started_time = None
logger.info("Peer health check task stopped.")


def exchange_routes(
clients: list[SyftClient], auto_approve: bool = False
) -> SyftSuccess | SyftError:
"""Exchange routes between a list of clients."""
if auto_approve:
# Check that all clients are admin clients
for client in clients:
if not client.user_role == ServiceRole.ADMIN:
return SyftError(
message=f"Client {client} is not an admin client. "
"Only admin clients can auto-approve connection requests."
)

for client1, client2 in itertools.combinations(clients, 2):
peer1 = ServerPeer.from_client(client1)
peer2 = ServerPeer.from_client(client2)

client1_connection_request = client1.api.services.network.add_peer(peer2)
if isinstance(client1_connection_request, SyftError):
return SyftError(
message=f"Failed to add peer {peer2} to {client1}: {client1_connection_request}"
)

client2_connection_request = client2.api.services.network.add_peer(peer1)
if isinstance(client2_connection_request, SyftError):
return SyftError(
message=f"Failed to add peer {peer1} to {client2}: {client2_connection_request}"
)

if auto_approve:
if isinstance(client1_connection_request, Request):
res1 = client1_connection_request.approve()
if isinstance(res1, SyftError):
return SyftError(
message=f"Failed to approve connection request between {client1} and {client2}: {res1}"
)
if isinstance(client2_connection_request, Request):
res2 = client2_connection_request.approve()
if isinstance(res2, SyftError):
return SyftError(
message=f"Failed to approve connection request between {client2} and {client1}: {res2}"
)
logger.info(f"Exchanged routes between {client1} and {client2}")
else:
logger.info(f"Connection requests sent between {client1} and {client2}.")

return SyftSuccess(message="Routes exchanged successfully.")


class NetworkTopology(Enum):
STAR = "STAR"
MESH = "MESH"
HYBRID = "HYBRID"


def check_route_reachability(
clients: list[SyftClient], topology: NetworkTopology = NetworkTopology.MESH
) -> SyftSuccess | SyftError:
if topology == NetworkTopology.STAR:
return SyftError(message="STAR topology is not supported yet")
elif topology == NetworkTopology.MESH:
return check_mesh_topology(clients)
else:
return SyftError(message=f"Invalid topology: {topology}")


def check_mesh_topology(clients: list[SyftClient]) -> SyftSuccess | SyftError:
for client in clients:
for other_client in clients:
if client == other_client:
continue
result = client.api.services.network.ping_peer(
verify_key=other_client.root_verify_key
)
if isinstance(result, SyftError):
return SyftError(
message=f"{client.name}-<{client.id}> - cannot reach"
+ f"{other_client.name}-<{other_client.id} - {result.message}"
)
return SyftSuccess(message="All clients are reachable")
35 changes: 1 addition & 34 deletions packages/syft/src/syft/service/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Callable
from collections.abc import Iterable
import copy
from enum import Enum
import hashlib
import textwrap
import time
Expand Down Expand Up @@ -56,6 +55,7 @@
from ..network.network_service import ServerPeer
from ..network.routes import ServerRoute
from ..network.routes import connection_to_route
from ..network.utils import check_route_reachability
from ..request.request import Request
from ..request.request import RequestStatus
from ..response import SyftError
Expand Down Expand Up @@ -1707,36 +1707,3 @@ def create_project_event_hash(project_event: ProjectEvent) -> tuple[bytes, str]:
hash_object(project_event.creator_verify_key)[1],
]
)


class NetworkTopology(Enum):
STAR = "STAR"
MESH = "MESH"
HYBRID = "HYBRID"


def check_route_reachability(
clients: list[SyftClient], topology: NetworkTopology = NetworkTopology.MESH
) -> SyftSuccess | SyftError:
if topology == NetworkTopology.STAR:
return SyftError("STAR topology is not supported yet")
elif topology == NetworkTopology.MESH:
return check_mesh_topology(clients)
else:
return SyftError(message=f"Invalid topology: {topology}")


def check_mesh_topology(clients: list[SyftClient]) -> SyftSuccess | SyftError:
for client in clients:
for other_client in clients:
if client == other_client:
continue
result = client.api.services.network.ping_peer(
verify_key=other_client.root_verify_key
)
if isinstance(result, SyftError):
return SyftError(
message=f"{client.name}-<{client.id}> - cannot reach"
+ f"{other_client.name}-<{other_client.id} - {result.message}"
)
return SyftSuccess(message="All clients are reachable")

0 comments on commit 535c3cb

Please sign in to comment.