Merge branch 'dev' into aziz/skip_deletions
iamtrask authored Aug 4, 2024
2 parents a1de49f + 3238317 commit ace94b8
Showing 9 changed files with 240 additions and 97 deletions.
(notebook file, name not shown)
@@ -386,7 +386,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.12.2"
+  "version": "3.10.9"
  },
 "toc": {
  "base_numbering": 1,
(notebook file, name not shown)
@@ -548,7 +548,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.12.2"
+  "version": "3.10.9"
  },
 "toc": {
  "base_numbering": 1,
(notebook file, name not shown)
@@ -296,7 +296,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.13"
+  "version": "3.10.9"
  },
 "toc": {
  "base_numbering": 1,
(notebook file, name not shown)
@@ -282,7 +282,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.12.2"
+  "version": "3.10.9"
  },
 "toc": {
  "base_numbering": 1,
9 changes: 5 additions & 4 deletions packages/syft/src/syft/__init__.py
@@ -19,8 +19,9 @@
from .client.registry import DatasiteRegistry
from .client.registry import EnclaveRegistry
from .client.registry import NetworkRegistry
-from .client.search import Search
-from .client.search import SearchResults
+
+# from .client.search import Search
+# from .client.search import SearchResults
from .client.syncing import compare_clients
from .client.syncing import compare_states
from .client.syncing import sync
@@ -147,5 +148,5 @@ def hello_baby() -> None:
    print("Welcome to the world. \u2764\ufe0f")


-def search(name: str) -> SearchResults:
-    return Search(_datasites()).search(name=name)
+# def search(name: str) -> SearchResults:
+#     return Search(_datasites()).search(name=name)
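With the Search imports and the module-level search() helper commented out, sy.search(...) is gone from the public API; discovery now goes through sy.datasites, backed by the new DatasiteRegistry below. A minimal sketch of the surviving path, where the datasite name "example-datasite" is hypothetical:

    import syft as sy

    # list datasites that respond to a health check
    print(sy.datasites.online_datasites)

    # index by name (or integer position) to get a guest client
    client = sy.datasites["example-datasite"]
    print(client.datasets)
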
146 changes: 146 additions & 0 deletions packages/syft/src/syft/client/registry.py
@@ -18,6 +18,7 @@
from ..service.network.server_peer import ServerPeerConnectionStatus
from ..service.response import SyftException
from ..types.server_url import ServerURL
+from ..types.syft_object import SyftObject
from ..util.constants import DEFAULT_TIMEOUT
from .client import SyftClient as Client

@@ -28,6 +29,10 @@

NETWORK_REGISTRY_REPO = "https://github.com/OpenMined/NetworkRegistry"

+DATASITE_REGISTRY_URL = (
+    "https://raw.githubusercontent.com/OpenMined/NetworkRegistry/main/datasites.json"
+)
+

def _get_all_networks(network_json: dict, version: str) -> list[dict]:
    return network_json.get(version, {}).get("gateways", [])
@@ -182,7 +187,148 @@ def __getitem__(self, key: str | int) -> Client:
        raise KeyError(f"Invalid key: {key} for {on}")


class Datasite(SyftObject):
    __canonical_name__ = "ServerMetadata"
    # __version__ = SYFT_OBJECT_VERSION_1

    name: str
    host_or_ip: str
    port: int  # listed in the attr groups below, so the field itself must exist
    version: str
    protocol: str
    admin_email: str
    website: str
    slack: str
    slack_channel: str

    __attr_searchable__ = [
        "name",
        "host_or_ip",
        "version",
        "port",
        "admin_email",
        "website",
        "slack",
        "slack_channel",
        "protocol",
    ]
    __attr_unique__ = [
        "name",
        "host_or_ip",
        "version",
        "port",
        "admin_email",
        "website",
        "slack",
        "slack_channel",
        "protocol",
    ]
    __repr_attrs__ = [
        "name",
        "host_or_ip",
        "version",
        "port",
        "admin_email",
        "website",
        "slack",
        "slack_channel",
        "protocol",
    ]
    __table_sort_attr__ = "name"


class DatasiteRegistry:
    def __init__(self) -> None:
        self.all_datasites: list[dict] = []
        try:
            response = requests.get(DATASITE_REGISTRY_URL)  # nosec
            datasites_json = response.json()
            self.all_datasites = datasites_json["datasites"]
        except Exception as e:
            logger.warning(
                f"Failed to fetch the Datasite Registry from {DATASITE_REGISTRY_URL}. {e}"
            )

    @property
    def online_datasites(self) -> list[dict]:
        datasites = self.all_datasites

        def check_datasite(datasite: dict) -> dict[Any, Any] | None:
            url = "http://" + datasite["host_or_ip"] + ":" + str(datasite["port"]) + "/"
            # Default to offline so `online` is bound even when the response
            # carries neither a "status" nor a "detail" key.
            online = False
            try:
                body = requests.get(url, timeout=DEFAULT_TIMEOUT).json()  # nosec
                if "status" in body:
                    online = body["status"] == "ok"
                elif "detail" in body:
                    online = True
            except Exception:
                online = False
            if online:
                version = datasite.get("version", None)
                # If the registry entry does not pin a syft version (or pins
                # "unknown"), ask the server's metadata endpoint for one.
                if not version or version == "unknown":
                    try:
                        version_url = url + "api/v2/metadata"
                        res = requests.get(version_url, timeout=DEFAULT_TIMEOUT)  # nosec
                        if res.status_code == 200:
                            datasite["version"] = res.json()["syft_version"]
                        else:
                            datasite["version"] = "unknown"
                    except Exception:
                        datasite["version"] = "unknown"
                return datasite
            return None

        # A with statement ensures threads are cleaned up promptly.
        with futures.ThreadPoolExecutor(max_workers=20) as executor:
            _online_datasites = list(executor.map(check_datasite, datasites))

        return [each for each in _online_datasites if each is not None]

    def _repr_html_(self) -> str:
        on = self.online_datasites
        if len(on) == 0:
            return (
                "(no datasites online - try syft.datasites.all_datasites"
                " to see offline datasites)"
            )

        print(
            "Add your datasite to this list: https://github.com/OpenMined/NetworkRegistry/"
        )
        # A plain list has no _repr_html_, so render the table with pandas
        # rather than ([Datasite(**ds) for ds in on])._repr_html_(), which
        # would raise AttributeError.
        df = pd.DataFrame(on)
        return df._repr_html_()  # type: ignore

    @staticmethod
    def create_client(datasite: dict[str, Any]) -> Client:
        # relative
        from .client import connect

        try:
            port = int(datasite["port"])
            protocol = datasite["protocol"]
            host_or_ip = datasite["host_or_ip"]
            server_url = ServerURL(port=port, protocol=protocol, host_or_ip=host_or_ip)
            client = connect(url=str(server_url))
            return client.guest()
        except Exception as e:
            raise SyftException(f"Failed to login with: {datasite}. {e}")

    def __getitem__(self, key: str | int) -> Client:
        if isinstance(key, int):
            return self.create_client(datasite=self.online_datasites[key])
        else:
            on = self.online_datasites
            for datasite in on:
                if datasite["name"] == key:
                    return self.create_client(datasite=datasite)
        raise KeyError(f"Invalid key: {key} for {on}")


class NetworksOfDatasitesRegistry:
    def __init__(self) -> None:
        self.all_networks: list[dict] = []
        self.all_datasites: dict[str, ServerPeer] = {}
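The entries in datasites.json are plain dicts whose keys mirror the Datasite fields above; check_datasite additionally relies on "host_or_ip" and "port" to build the health-check URL. A sketch of one entry and of typical registry usage, where every concrete value is invented for illustration:

    from syft.client.registry import DatasiteRegistry

    # Assumed shape of one datasites.json entry (all values made up):
    entry = {
        "name": "example-datasite",
        "host_or_ip": "203.0.113.10",
        "port": 80,
        "version": "unknown",  # backfilled from api/v2/metadata when reachable
        "protocol": "http",
        "admin_email": "admin@example.org",
        "website": "https://example.org",
        "slack": "",
        "slack_channel": "",
    }

    registry = DatasiteRegistry()         # fetches DATASITE_REGISTRY_URL once
    live = registry.online_datasites      # health-checks all entries in parallel
    guest = registry["example-datasite"]  # or registry[0]; returns a guest client
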
104 changes: 49 additions & 55 deletions packages/syft/src/syft/client/search.py
@@ -1,17 +1,11 @@
# stdlib
-from concurrent.futures import ThreadPoolExecutor

# third party
-from IPython.display import display

# relative
from ..service.dataset.dataset import Dataset
-from ..service.metadata.server_metadata import ServerMetadataJSON
-from ..service.network.network_service import ServerPeer
-from ..service.response import SyftWarning
from ..types.uid import UID
from .client import SyftClient
-from .registry import DatasiteRegistry


class SearchResults:
@@ -57,52 +51,52 @@ def __len__(self) -> int:
        return len(self._datasets)


-class Search:
-    def __init__(self, datasites: DatasiteRegistry) -> None:
-        self.datasites: list[tuple[ServerPeer, ServerMetadataJSON | None]] = (
-            datasites.online_datasites
-        )
-
-    @staticmethod
-    def __search_one_server(
-        peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str
-    ) -> tuple[SyftClient | None, list[Dataset]]:
-        try:
-            peer, server_metadata = peer_tuple
-            client = peer.guest_client
-            results = client.api.services.dataset.search(name=name)
-            return (client, results)
-        except Exception as e:  # noqa
-            warning = SyftWarning(
-                message=f"Got exception {e} at server {server_metadata.name}"
-            )
-            display(warning)
-            return (None, [])
-
-    def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]:
-        with ThreadPoolExecutor(max_workers=20) as executor:
-            # results: list[tuple[SyftClient | None, list[Dataset]]] = [
-            #     self.__search_one_server(peer_tuple, name) for peer_tuple in self.datasites
-            # ]
-            results: list[tuple[SyftClient | None, list[Dataset]]] = list(
-                executor.map(
-                    lambda peer_tuple: self.__search_one_server(peer_tuple, name),
-                    self.datasites,
-                )
-            )
-        # filter out SyftError
-        filtered = [(client, result) for client, result in results if client and result]
-
-        return filtered
-
-    def search(self, name: str) -> SearchResults:
-        """
-        Searches for a specific dataset by name.
-
-        Args:
-            name (str): The name of the dataset to search for.
-
-        Returns:
-            SearchResults: An object containing the search results.
-        """
-        return SearchResults(self.__search(name))
+# class Search:
+#     def __init__(self, datasites: DatasiteRegistry) -> None:
+#         self.datasites: list[tuple[ServerPeer, ServerMetadataJSON | None]] = (
+#             datasites.online_datasites
+#         )
+
+#     @staticmethod
+#     def __search_one_server(
+#         peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str
+#     ) -> tuple[SyftClient | None, list[Dataset]]:
+#         try:
+#             peer, server_metadata = peer_tuple
+#             client = peer.guest_client
+#             results = client.api.services.dataset.search(name=name)
+#             return (client, results)
+#         except Exception as e:  # noqa
+#             warning = SyftWarning(
+#                 message=f"Got exception {e} at server {server_metadata.name}"
+#             )
+#             display(warning)
+#             return (None, [])
+
+#     def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]:
+#         with ThreadPoolExecutor(max_workers=20) as executor:
+#             # results: list[tuple[SyftClient | None, list[Dataset]]] = [
+#             #     self.__search_one_server(peer_tuple, name) for peer_tuple in self.datasites
+#             # ]
+#             results: list[tuple[SyftClient | None, list[Dataset]]] = list(
+#                 executor.map(
+#                     lambda peer_tuple: self.__search_one_server(peer_tuple, name),
+#                     self.datasites,
+#                 )
+#             )
+#         # filter out SyftError
+#         filtered = [(client, result) for client, result in results if client and result]
+
+#         return filtered
+
+#     def search(self, name: str) -> SearchResults:
+#         """
+#         Searches for a specific dataset by name.
+
+#         Args:
+#             name (str): The name of the dataset to search for.
+
+#         Returns:
+#             SearchResults: An object containing the search results.
+#         """
+#         return SearchResults(self.__search(name))
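With Search disabled, the fan-out it performed can still be reproduced by hand. A rough sketch, assuming the dataset.search service endpoint used by the removed class above is still served:

    from syft.client.registry import DatasiteRegistry

    def manual_search(name: str) -> list:
        """Approximates the removed Search.search, minus threading and warnings."""
        hits = []
        for entry in DatasiteRegistry().online_datasites:
            try:
                client = DatasiteRegistry.create_client(entry)
                results = client.api.services.dataset.search(name=name)
                if results:
                    hits.append((client, results))
            except Exception:
                continue  # skip unreachable or erroring datasites
        return hits
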
24 changes: 13 additions & 11 deletions tests/integration/local/gateway_local_test.py
@@ -15,7 +15,6 @@
from syft.client.gateway_client import GatewayClient
from syft.service.network.network_service import ServerPeerAssociationStatus
from syft.service.network.server_peer import ServerPeer
-from syft.service.network.server_peer import ServerPeerConnectionStatus
from syft.service.network.utils import PeerHealthCheckTask
from syft.service.request.request import Request
from syft.service.response import SyftSuccess
@@ -164,16 +163,19 @@ def test_create_gateway(
    assert isinstance(result, SyftSuccess)

    time.sleep(PeerHealthCheckTask.repeat_time * 2 + 1)
-    assert len(sy.datasites.all_datasites) == 2
-    assert len(sy.datasites.online_datasites) == 2
-    # check for peer connection status
-    for peer in gateway_client.api.services.network.get_all_peers():
-        assert peer.ping_status == ServerPeerConnectionStatus.ACTIVE
-
-    # check the guest client
-    client = gateway_webserver.client
-    assert isinstance(client, GatewayClient)
-    assert client.metadata.server_type == ServerType.GATEWAY.value
+
+    # TRASK: I've changed the functionality here so that
+    # sy.datasites always goes out to the network
+    # assert len(sy.datasites.all_datasites) == 2
+    # assert len(sy.datasites.online_datasites) == 2
+    # # check for peer connection status
+    # for peer in gateway_client.api.services.network.get_all_peers():
+    #     assert peer.ping_status == ServerPeerConnectionStatus.ACTIVE
+
+    # # check the guest client
+    # client = gateway_webserver.client
+    # assert isinstance(client, GatewayClient)
+    # assert client.metadata.server_type == ServerType.GATEWAY.value


@pytest.mark.local_server
@pytest.mark.local_server
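Because sy.datasites now always reaches out to the public registry, a local test that wants deterministic registry contents has to stub the HTTP call. A hypothetical pytest sketch; it assumes sy.datasites builds a fresh DatasiteRegistry on each access, and that syft.client.registry.requests.get is the right patch target:

    import syft as sy

    def test_datasites_empty_when_registry_unreachable(monkeypatch):
        def refuse(*args, **kwargs):  # simulate an unreachable registry
            raise ConnectionError("registry unreachable")

        monkeypatch.setattr("syft.client.registry.requests.get", refuse)
        # DatasiteRegistry.__init__ swallows the error and leaves the list empty
        assert sy.datasites.all_datasites == []
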
(1 more changed file not shown)
