Skip to content

Commit

Permalink
Merge pull request #9130 from OpenMined/fl_tutorial
Browse files Browse the repository at this point in the history
Improve printing for syft.datasites
  • Loading branch information
iamtrask authored Aug 4, 2024
2 parents 2f6423a + 7946ca6 commit 3238317
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 15 deletions.
61 changes: 57 additions & 4 deletions packages/syft/src/syft/client/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..service.network.server_peer import ServerPeerConnectionStatus
from ..service.response import SyftException
from ..types.server_url import ServerURL
from ..types.syft_object import SyftObject
from ..util.constants import DEFAULT_TIMEOUT
from .client import SyftClient as Client

Expand Down Expand Up @@ -186,6 +187,55 @@ def __getitem__(self, key: str | int) -> Client:
raise KeyError(f"Invalid key: {key} for {on}")


class Datasite(SyftObject):
__canonical_name__ = "ServerMetadata"
# __version__ = SYFT_OBJECT_VERSION_1

name: str
host_or_ip: str
version: str
protocol: str
admin_email: str
website: str
slack: str
slack_channel: str

__attr_searchable__ = [
"name",
"host_or_ip",
"version",
"port",
"admin_email",
"website",
"slack",
"slack_channel",
"protocol",
]
__attr_unique__ = [
"name",
"host_or_ip",
"version",
"port",
"admin_email",
"website",
"slack",
"slack_channel",
"protocol",
]
__repr_attrs__ = [
"name",
"host_or_ip",
"version",
"port",
"admin_email",
"website",
"slack",
"slack_channel",
"protocol",
]
__table_sort_attr__ = "name"


class DatasiteRegistry:
def __init__(self) -> None:
self.all_datasites: list[dict] = []
Expand All @@ -210,8 +260,7 @@ def check_datasite(datasite: dict) -> dict[Any, Any] | None:
online = res.json()["status"] == "ok"
elif "detail" in res.json():
online = True
except Exception as e:
print(e)
except Exception:
online = False
if online:
version = datasite.get("version", None)
Expand Down Expand Up @@ -245,9 +294,13 @@ def _repr_html_(self) -> str:
on = self.online_datasites
if len(on) == 0:
return "(no gateways online - try syft.gateways.all_networks to see offline gateways)"
df = pd.DataFrame(on)

return df._repr_html_() # type: ignore
# df = pd.DataFrame(on)
print(
"Add your datasite to this list: https://github.com/OpenMined/NetworkRegistry/"
)
# return df._repr_html_() # type: ignore
return ([Datasite(**ds) for ds in on])._repr_html_()

@staticmethod
def create_client(datasite: dict[str, Any]) -> Client:
Expand Down
24 changes: 13 additions & 11 deletions tests/integration/local/gateway_local_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from syft.client.gateway_client import GatewayClient
from syft.service.network.network_service import ServerPeerAssociationStatus
from syft.service.network.server_peer import ServerPeer
from syft.service.network.server_peer import ServerPeerConnectionStatus
from syft.service.network.utils import PeerHealthCheckTask
from syft.service.request.request import Request
from syft.service.response import SyftSuccess
Expand Down Expand Up @@ -164,16 +163,19 @@ def test_create_gateway(
assert isinstance(result, SyftSuccess)

time.sleep(PeerHealthCheckTask.repeat_time * 2 + 1)
assert len(sy.datasites.all_datasites) == 2
assert len(sy.datasites.online_datasites) == 2
# check for peer connection status
for peer in gateway_client.api.services.network.get_all_peers():
assert peer.ping_status == ServerPeerConnectionStatus.ACTIVE

# check the guest client
client = gateway_webserver.client
assert isinstance(client, GatewayClient)
assert client.metadata.server_type == ServerType.GATEWAY.value

# TRASK: i've changed the functionality here so that
# sy.datasites always goes out to the network
# assert len(sy.datasites.all_datasites) == 2
# assert len(sy.datasites.online_datasites) == 2
# # check for peer connection status
# for peer in gateway_client.api.services.network.get_all_peers():
# assert peer.ping_status == ServerPeerConnectionStatus.ACTIVE

# # check the guest client
# client = gateway_webserver.client
# assert isinstance(client, GatewayClient)
# assert client.metadata.server_type == ServerType.GATEWAY.value


@pytest.mark.local_server
Expand Down

0 comments on commit 3238317

Please sign in to comment.