
Commit

Merge pull request #9125 from OpenMined/fl_tutorial
Simplify syft.datasites registry to operate against a static file
iamtrask authored Aug 4, 2024
2 parents 887010a + 1e16e41 commit 2f6423a
Showing 8 changed files with 174 additions and 86 deletions.
Four notebook files (filenames not shown in this view) update their recorded kernel metadata to Python 3.10.9:

@@ -386,7 +386,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.10.9"
},
"toc": {
"base_numbering": 1,
@@ -548,7 +548,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.10.9"
},
"toc": {
"base_numbering": 1,
@@ -296,7 +296,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.9"
},
"toc": {
"base_numbering": 1,
@@ -282,7 +282,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.10.9"
},
"toc": {
"base_numbering": 1,
9 changes: 5 additions & 4 deletions packages/syft/src/syft/__init__.py
@@ -19,8 +19,9 @@
from .client.registry import DatasiteRegistry
from .client.registry import EnclaveRegistry
from .client.registry import NetworkRegistry
-from .client.search import Search
-from .client.search import SearchResults
+
+# from .client.search import Search
+# from .client.search import SearchResults
from .client.syncing import compare_clients
from .client.syncing import compare_states
from .client.syncing import sync
@@ -147,5 +148,5 @@ def hello_baby() -> None:
    print("Welcome to the world. \u2764\ufe0f")


-def search(name: str) -> SearchResults:
-    return Search(_datasites()).search(name=name)
+# def search(name: str) -> SearchResults:
+#     return Search(_datasites()).search(name=name)
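
With the Search helper gone from the package root, the module's public discovery surface is the datasites registry itself. A minimal sketch of what this simplified flow looks like after the change, assuming the static registry file is reachable (the output naturally depends on the live registry):

import syft as sy

# Entries parsed from the static datasites.json file.
print(sy.datasites.all_datasites)

# The subset of entries that currently answer the health check.
print(sy.datasites.online_datasites)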
93 changes: 93 additions & 0 deletions packages/syft/src/syft/client/registry.py
@@ -28,6 +28,10 @@

NETWORK_REGISTRY_REPO = "https://github.com/OpenMined/NetworkRegistry"

+DATASITE_REGISTRY_URL = (
+    "https://raw.githubusercontent.com/OpenMined/NetworkRegistry/main/datasites.json"
+)


def _get_all_networks(network_json: dict, version: str) -> list[dict]:
return network_json.get(version, {}).get("gateways", [])
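
The static file behind the simplified registry is not part of this diff. A plausible minimal shape, inferred only from the keys the new DatasiteRegistry class below reads; every value here is illustrative, not taken from the real file:

# Hypothetical contents of datasites.json (all values made up).
datasites_json = {
    "datasites": [
        {
            "name": "example-datasite",    # matched by DatasiteRegistry.__getitem__
            "host_or_ip": "203.0.113.10",  # joined with port into the health-check URL
            "port": 80,
            "protocol": "http",            # passed to ServerURL in create_client
            "version": "unknown",          # refreshed from /api/v2/metadata when unknown
        }
    ]
}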
@@ -183,6 +187,95 @@ def __getitem__(self, key: str | int) -> Client:


+class DatasiteRegistry:
+    def __init__(self) -> None:
+        self.all_datasites: list[dict] = []
+        try:
+            response = requests.get(DATASITE_REGISTRY_URL)  # nosec
+            datasites_json = response.json()
+            self.all_datasites = datasites_json["datasites"]
+        except Exception as e:
+            logger.warning(
+                f"Failed to get Datasite Registry, go checkout: {DATASITE_REGISTRY_URL}. {e}"
+            )
+
+    @property
+    def online_datasites(self) -> list[dict]:
+        datasites = self.all_datasites
+
+        def check_datasite(datasite: dict) -> dict[Any, Any] | None:
+            url = "http://" + datasite["host_or_ip"] + ":" + str(datasite["port"]) + "/"
+            try:
+                res = requests.get(url, timeout=DEFAULT_TIMEOUT)  # nosec
+                if "status" in res.json():
+                    online = res.json()["status"] == "ok"
+                elif "detail" in res.json():
+                    online = True
+            except Exception as e:
+                print(e)
+                online = False
+            if online:
+                version = datasite.get("version", None)
+                # Check if syft version was described in DatasiteRegistry
+                # If it's unknown, try to update it to an available version.
+                if not version or version == "unknown":
+                    # If not defined, try to ask in /syft/version endpoint (supported by 0.7.0)
+                    try:
+                        version_url = url + "api/v2/metadata"
+                        res = requests.get(version_url, timeout=DEFAULT_TIMEOUT)  # nosec
+                        if res.status_code == 200:
+                            datasite["version"] = res.json()["syft_version"]
+                        else:
+                            datasite["version"] = "unknown"
+                    except Exception:
+                        datasite["version"] = "unknown"
+                return datasite
+            return None
+
+        # We can use a with statement to ensure threads are cleaned up promptly
+        with futures.ThreadPoolExecutor(max_workers=20) as executor:
+            # map
+            _online_datasites = list(
+                executor.map(lambda datasite: check_datasite(datasite), datasites)
+            )
+
+        online_datasites = [each for each in _online_datasites if each is not None]
+        return online_datasites
+
+    def _repr_html_(self) -> str:
+        on = self.online_datasites
+        if len(on) == 0:
+            return "(no gateways online - try syft.gateways.all_networks to see offline gateways)"
+        df = pd.DataFrame(on)
+
+        return df._repr_html_()  # type: ignore
+
+    @staticmethod
+    def create_client(datasite: dict[str, Any]) -> Client:
+        # relative
+        from .client import connect
+
+        try:
+            port = int(datasite["port"])
+            protocol = datasite["protocol"]
+            host_or_ip = datasite["host_or_ip"]
+            server_url = ServerURL(port=port, protocol=protocol, host_or_ip=host_or_ip)
+            client = connect(url=str(server_url))
+            return client.guest()
+        except Exception as e:
+            raise SyftException(f"Failed to login with: {datasite}. {e}")
+
+    def __getitem__(self, key: str | int) -> Client:
+        if isinstance(key, int):
+            return self.create_client(datasite=self.online_datasites[key])
+        else:
+            on = self.online_datasites
+            for datasite in on:
+                if datasite["name"] == key:
+                    return self.create_client(datasite=datasite)
+            raise KeyError(f"Invalid key: {key} for {on}")

class NetworksOfDatasitesRegistry:
    def __init__(self) -> None:
        self.all_networks: list[dict] = []
        self.all_datasites: dict[str, ServerPeer] = {}
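
Taken together, the class above supports a simple lookup flow. A short usage sketch based only on the methods shown, assuming the registry URL is reachable and at least one listed datasite is online (the entry name is hypothetical):

from syft.client.registry import DatasiteRegistry

registry = DatasiteRegistry()          # fetches DATASITE_REGISTRY_URL once
print(len(registry.all_datasites))     # every entry in the static file
print(len(registry.online_datasites))  # entries that passed the health check

# __getitem__ connects as a guest via create_client; both forms work:
client = registry[0]                     # by position in online_datasites
# client = registry["example-datasite"]  # by the entry's "name" field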
104 changes: 49 additions & 55 deletions packages/syft/src/syft/client/search.py
@@ -1,17 +1,11 @@
# stdlib
-from concurrent.futures import ThreadPoolExecutor

# third party
-from IPython.display import display

# relative
from ..service.dataset.dataset import Dataset
-from ..service.metadata.server_metadata import ServerMetadataJSON
-from ..service.network.network_service import ServerPeer
-from ..service.response import SyftWarning
from ..types.uid import UID
from .client import SyftClient
-from .registry import DatasiteRegistry

class SearchResults:
@@ -57,52 +51,52 @@ def __len__(self) -> int:
        return len(self._datasets)


-class Search:
-    def __init__(self, datasites: DatasiteRegistry) -> None:
-        self.datasites: list[tuple[ServerPeer, ServerMetadataJSON | None]] = (
-            datasites.online_datasites
-        )
-
-    @staticmethod
-    def __search_one_server(
-        peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str
-    ) -> tuple[SyftClient | None, list[Dataset]]:
-        try:
-            peer, server_metadata = peer_tuple
-            client = peer.guest_client
-            results = client.api.services.dataset.search(name=name)
-            return (client, results)
-        except Exception as e:  # noqa
-            warning = SyftWarning(
-                message=f"Got exception {e} at server {server_metadata.name}"
-            )
-            display(warning)
-            return (None, [])
-
-    def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]:
-        with ThreadPoolExecutor(max_workers=20) as executor:
-            # results: list[tuple[SyftClient | None, list[Dataset]]] = [
-            #     self.__search_one_server(peer_tuple, name) for peer_tuple in self.datasites
-            # ]
-            results: list[tuple[SyftClient | None, list[Dataset]]] = list(
-                executor.map(
-                    lambda peer_tuple: self.__search_one_server(peer_tuple, name),
-                    self.datasites,
-                )
-            )
-        # filter out SyftError
-        filtered = [(client, result) for client, result in results if client and result]
-
-        return filtered
-
-    def search(self, name: str) -> SearchResults:
-        """
-        Searches for a specific dataset by name.
-        Args:
-            name (str): The name of the dataset to search for.
-        Returns:
-            SearchResults: An object containing the search results.
-        """
-        return SearchResults(self.__search(name))
+# class Search:
+#     def __init__(self, datasites: DatasiteRegistry) -> None:
+#         self.datasites: list[tuple[ServerPeer, ServerMetadataJSON | None]] = (
+#             datasites.online_datasites
+#         )
+
+#     @staticmethod
+#     def __search_one_server(
+#         peer_tuple: tuple[ServerPeer, ServerMetadataJSON], name: str
+#     ) -> tuple[SyftClient | None, list[Dataset]]:
+#         try:
+#             peer, server_metadata = peer_tuple
+#             client = peer.guest_client
+#             results = client.api.services.dataset.search(name=name)
+#             return (client, results)
+#         except Exception as e:  # noqa
+#             warning = SyftWarning(
+#                 message=f"Got exception {e} at server {server_metadata.name}"
+#             )
+#             display(warning)
+#             return (None, [])
+
+#     def __search(self, name: str) -> list[tuple[SyftClient, list[Dataset]]]:
+#         with ThreadPoolExecutor(max_workers=20) as executor:
+#             # results: list[tuple[SyftClient | None, list[Dataset]]] = [
+#             #     self.__search_one_server(peer_tuple, name) for peer_tuple in self.datasites
+#             # ]
+#             results: list[tuple[SyftClient | None, list[Dataset]]] = list(
+#                 executor.map(
+#                     lambda peer_tuple: self.__search_one_server(peer_tuple, name),
+#                     self.datasites,
+#                 )
+#             )
+#         # filter out SyftError
+#         filtered = [(client, result) for client, result in results if client and result]
+
+#         return filtered
+
+#     def search(self, name: str) -> SearchResults:
+#         """
+#         Searches for a specific dataset by name.
+
+#         Args:
+#             name (str): The name of the dataset to search for.
+
+#         Returns:
+#             SearchResults: An object containing the search results.
+#         """
+#         return SearchResults(self.__search(name))
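
With the multi-server Search class commented out, dataset search becomes a per-datasite call. A minimal sketch of the replacement workflow, reusing the same dataset.search service the old __search_one_server helper invoked (the datasite name and dataset name below are hypothetical):

import syft as sy

# Pick a single datasite from the simplified registry; guest login
# happens inside DatasiteRegistry.create_client.
client = sy.datasites["example-datasite"]

# Query that one server's dataset service directly, the same call the
# removed Search helper fanned out across every online datasite.
results = client.api.services.dataset.search(name="my-dataset")
for dataset in results:
    print(dataset.name)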
46 changes: 23 additions & 23 deletions tests/integration/network/gateway_test.py
@@ -16,8 +16,6 @@
from syft.client.datasite_client import DatasiteClient
from syft.client.gateway_client import GatewayClient
from syft.client.registry import NetworkRegistry
-from syft.client.search import SearchResults
-from syft.service.dataset.dataset import Dataset
from syft.service.network.association_request import AssociationRequestChange
from syft.service.network.network_service import ServerPeerAssociationStatus
from syft.service.network.routes import HTTPServerRoute
@@ -133,9 +131,11 @@ def test_datasite_connect_to_gateway(
    assert len(gateway_client.peers) == 1

    time.sleep(PeerHealthCheckTask.repeat_time * 2 + 1)
+
+    # this is the wrong test — sy.datasites checks the gateway registry
    # check that the datasite is online on the network
-    assert len(sy.datasites.all_datasites) == 1
-    assert len(sy.datasites.online_datasites) == 1
+    # assert len(sy.datasites.all_datasites) == 1
+    # assert len(sy.datasites.online_datasites) == 1

    proxy_datasite_client = gateway_client.peers[0]
    datasite_peer = datasite_client.peers[0]
@@ -215,25 +215,25 @@ def test_dataset_search(set_env_var, gateway_port: int, datasite_1_port: int) ->
    # we need to wait to make sure peers health check is done
    time.sleep(PeerHealthCheckTask.repeat_time * 2 + 1)
    # test if the dataset can be searched by the syft network
-    right_search = sy.search(dataset_name)
-    assert isinstance(right_search, SearchResults)
-    assert len(right_search) == 1
-    dataset = right_search[0]
-    assert isinstance(dataset, Dataset)
-    assert len(dataset.assets) == 1
-    assert isinstance(dataset.assets[0].mock, np.ndarray)
-    assert dataset.assets[0].data is None
-
-    # search a wrong dataset should return an empty list
-    wrong_search = sy.search(_random_hash())
-    assert len(wrong_search) == 0
-
-    # the datasite client delete the dataset
-    datasite_client.api.services.dataset.delete(uid=dataset.id)
-
-    # Remove existing peers
-    assert isinstance(_remove_existing_peers(datasite_client), SyftSuccess)
-    assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess)
+    # right_search = sy.search(dataset_name)
+    # assert isinstance(right_search, SearchResults)
+    # assert len(right_search) == 1
+    # dataset = right_search[0]
+    # assert isinstance(dataset, Dataset)
+    # assert len(dataset.assets) == 1
+    # assert isinstance(dataset.assets[0].mock, np.ndarray)
+    # assert dataset.assets[0].data is None
+
+    # # search a wrong dataset should return an empty list
+    # wrong_search = sy.search(_random_hash())
+    # assert len(wrong_search) == 0
+
+    # # the datasite client delete the dataset
+    # datasite_client.api.services.dataset.delete(uid=dataset.id)
+
+    # # Remove existing peers
+    # assert isinstance(_remove_existing_peers(datasite_client), SyftSuccess)
+    # assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess)


@pytest.mark.skip(reason="Possible bug")
@@ -352,8 +352,8 @@ def test_deleting_peers(set_env_var, datasite_1_port: int, gateway_port: int) ->
    # check that the online datasites and gateways are updated
    time.sleep(PeerHealthCheckTask.repeat_time * 2 + 1)
    assert len(sy.gateways.all_networks) == 1
-    assert len(sy.datasites.all_datasites) == 0
-    assert len(sy.datasites.online_datasites) == 0
+    # assert len(sy.datasites.all_datasites) == 0
+    # assert len(sy.datasites.online_datasites) == 0

    # reconnect the datasite to the gateway
    result = datasite_client.connect_to_gateway(gateway_client)
