Skip to content

Commit

Permalink
Merge branch 'dev' into node_description
Browse files Browse the repository at this point in the history
  • Loading branch information
jcardonnet authored Jun 12, 2024
2 parents 2a2b296 + 8f81968 commit 7c758a0
Show file tree
Hide file tree
Showing 35 changed files with 149 additions and 150 deletions.
4 changes: 2 additions & 2 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def debox_signed_syftapicall_response(

def downgrade_signature(signature: Signature, object_versions: dict) -> Signature:
migrated_parameters = []
for _, parameter in signature.parameters.items():
for parameter in signature.parameters.values():
annotation = unwrap_and_migrate_annotation(
parameter.annotation, object_versions
)
Expand Down Expand Up @@ -1114,7 +1114,7 @@ def build_endpoint_tree(
endpoints: dict[str, LibEndpoint], communication_protocol: PROTOCOL_TYPE
) -> APIModule:
api_module = APIModule(path="", refresh_callback=self.refresh_api_callback)
for _, v in endpoints.items():
for v in endpoints.values():
signature = v.signature
if not v.has_self:
signature = signature_remove_self(signature)
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def get_nested_codes(code: UserCode) -> list[UserCode]:
if code.nested_codes is None:
return result

for _, (linked_code_obj, _) in code.nested_codes.items():
for linked_code_obj, _ in code.nested_codes.values():
nested_code = linked_code_obj.resolve
nested_code = deepcopy(nested_code)
nested_code.node_uid = code.node_uid
Expand Down
5 changes: 1 addition & 4 deletions packages/syft/src/syft/client/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,7 @@ def check_enclave(enclave: dict) -> dict[Any, Any] | None:
executor.map(lambda enclave: check_enclave(enclave), enclaves)
)

online_enclaves = []
for each in _online_enclaves:
if each is not None:
online_enclaves.append(each)
online_enclaves = [each for each in _online_enclaves if each is not None]
return online_enclaves

def _repr_html_(self) -> str:
Expand Down
23 changes: 10 additions & 13 deletions packages/syft/src/syft/custom_worker/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,11 @@ def resolve_pod(client: kr8s.Api, pod: str | Pod) -> Pod | None:

@staticmethod
def get_logs(pods: list[Pod]) -> str:
"""Combine and return logs for all the pods as string"""
logs = []
for pod in pods:
logs.append(f"----------Logs for pod={pod.metadata.name}----------")
for log in pod.logs():
logs.append(log)

return "\n".join(logs)
"""Combine and return logs for all the pods as a single string."""
return "\n".join(
f"----------Logs for pod={pod.metadata.name}----------\n{''.join(pod.logs())}"
for pod in pods
)

@staticmethod
def get_pod_status(pod: Pod) -> PodStatus | None:
Expand All @@ -150,11 +147,11 @@ def get_pod_env(pod: Pod) -> list[dict] | None:
@staticmethod
def get_container_exit_code(pods: list[Pod]) -> list[int]:
"""Return the exit codes of all the containers in the given pods."""
exit_codes = []
for pod in pods:
for container_status in pod.status.containerStatuses:
exit_codes.append(container_status.state.terminated.exitCode)
return exit_codes
return [
container_status.state.terminated.exitCode
for pod in pods
for container_status in pod.status.containerStatuses
]

@staticmethod
def get_container_exit_message(pods: list[Pod]) -> str | None:
Expand Down
6 changes: 2 additions & 4 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def add_consumer_for_service(
consumer.run()

def remove_consumer_with_id(self, syft_worker_id: UID) -> None:
for _, consumers in self.queue_manager.consumers.items():
for consumers in self.queue_manager.consumers.values():
# Grab the list of consumers for the given queue
consumer_to_pop = None
for consumer_idx, consumer in enumerate(consumers):
Expand Down Expand Up @@ -833,9 +833,7 @@ def get_guest_client(self, verbose: bool = True) -> SyftClient:
def __repr__(self) -> str:
service_string = ""
if not self.is_subprocess:
services = []
for service in self.services:
services.append(service.__name__)
services = [service.__name__ for service in self.services]
service_string = ", ".join(sorted(services))
service_string = f"\n\nServices:\n{service_string}"
return f"{type(self).__name__}: {self.name} - {self.id} - {self.node_type}{service_string}"
Expand Down
24 changes: 12 additions & 12 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,14 @@ def find_python_processes_on_port(port: int) -> list[int]:

python_pids = []
for pid in pids:
try:
if system == "Windows":
command = (
f"wmic process where (ProcessId='{pid}') get ProcessId,CommandLine"
)
else:
command = f"ps -p {pid} -o pid,command"
if system == "Windows":
command = (
f"wmic process where (ProcessId='{pid}') get ProcessId,CommandLine"
)
else:
command = f"ps -p {pid} -o pid,command"

try:
process = subprocess.Popen( # nosec
command,
shell=True,
Expand All @@ -301,13 +301,13 @@ def find_python_processes_on_port(port: int) -> list[int]:
text=True,
)
output, _ = process.communicate()
lines = output.strip().split("\n")

if len(lines) > 1 and "python" in lines[1].lower():
python_pids.append(pid)

except Exception as e:
print(f"Error checking process {pid}: {e}")
continue

lines = output.strip().split("\n")
if len(lines) > 1 and "python" in lines[1].lower():
python_pids.append(pid)

return python_pids

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def calculate_supported_protocols(self) -> dict:
# we assume its supported until we prove otherwise
protocol_supported[v] = True
# iterate through each object
for canonical_name, _ in version_data["object_versions"].items():
for canonical_name in version_data["object_versions"].keys():
if canonical_name not in self.state:
protocol_supported[v] = False
break
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/serde/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def rs_proto2object(proto: _DynamicStructBuilder) -> Any:
# relative
from ..node.node import CODE_RELOADER

for _, load_user_code in CODE_RELOADER.items():
for load_user_code in CODE_RELOADER.values():
load_user_code()
try:
class_type = getattr(sys.modules[".".join(module_parts)], klass)
Expand Down
7 changes: 4 additions & 3 deletions packages/syft/src/syft/serde/recursive_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection:
from .deserialize import _deserialize

MAX_TRAVERSAL_LIMIT = 2**64 - 1
values = []

with iterable_schema.from_bytes(
blob, traversal_limit_in_words=MAX_TRAVERSAL_LIMIT
) as msg:
for element in msg.values:
values.append(_deserialize(combine_bytes(element), from_bytes=True))
values = [
_deserialize(combine_bytes(element), from_bytes=True)
for element in msg.values
]

return iterable_type(values)

Expand Down
11 changes: 3 additions & 8 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,14 +1090,9 @@ def syft_make_action(
if kwargs is None:
kwargs = {}

arg_ids = []
kwarg_ids = {}

for obj in args:
arg_ids.append(self._syft_prepare_obj_uid(obj))
arg_ids = [self._syft_prepare_obj_uid(obj) for obj in args]

for k, obj in kwargs.items():
kwarg_ids[k] = self._syft_prepare_obj_uid(obj)
kwarg_ids = {k: self._syft_prepare_obj_uid(obj) for k, obj in kwargs.items()}

action = Action(
path=path,
Expand Down Expand Up @@ -2172,7 +2167,7 @@ def has_action_data_empty(args: Any, kwargs: Any) -> bool:
if is_action_data_empty(a):
return True

for _, a in kwargs.items():
for a in kwargs.values():
if is_action_data_empty(a):
return True
return False
9 changes: 4 additions & 5 deletions packages/syft/src/syft/service/api/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,10 @@ def api_endpoints(
return SyftError(message=result.err())

all_api_endpoints = result.ok()
api_endpoint_view = []
for api_endpoint in all_api_endpoints:
api_endpoint_view.append(
api_endpoint.to(TwinAPIEndpointView, context=context)
)
api_endpoint_view = [
api_endpoint.to(TwinAPIEndpointView, context=context)
for api_endpoint in all_api_endpoints
]

return api_endpoint_view

Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _get_input_policy(self) -> InputPolicy | None:
):
# TODO: Tech Debt here
node_view_workaround = False
for k, _ in self.input_policy_init_kwargs.items():
for k in self.input_policy_init_kwargs.keys():
if isinstance(k, NodeIdentity):
node_view_workaround = True

Expand Down Expand Up @@ -727,7 +727,7 @@ def _inner_repr(self, level: int = 0) -> str:
[f"{' '*level}{substring}" for substring in md.split("\n")[:-1]]
)
if self.nested_codes is not None:
for _, (obj, _) in self.nested_codes.items():
for obj, _ in self.nested_codes.values():
code = obj.resolve
md += "\n"
md += code._inner_repr(level=level + 1)
Expand Down Expand Up @@ -876,7 +876,7 @@ def _ephemeral_node_call(
# And need only ActionObjects
# Also, this works only on the assumption that all inputs
# are ActionObjects, which might change in the future
for _, id in obj_dict.items():
for id in obj_dict.values():
mock_obj = api.services.action.get_mock(id)
if isinstance(mock_obj, SyftError):
data_obj = api.services.action.get(id)
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/code_history/code_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __getitem__(self, key: str | int) -> CodeHistoriesDict | SyftError:
return api.services.code_history.get_history_for_user(key)

def _repr_html_(self) -> str:
rows = []
for user, funcs in self.user_dict.items():
rows += [{"user": user, "UserCodes": funcs}]
rows = [
{"user": user, "UserCodes": funcs} for user, funcs in self.user_dict.items()
]
return create_table_template(rows, "UserCodeHistory", icon=None)
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def get_code(uid: UID) -> UserCode | SyftError:
code_versions_dict = {}

for code_history in code_histories:
user_code_list = []
for uid in code_history.user_code_history:
user_code_list.append(get_code(uid))
user_code_list = [
get_code(uid) for uid in code_history.user_code_history
]
code_versions = CodeHistoryView(
user_code_history=user_code_list,
service_func_name=code_history.service_func_name,
Expand Down
5 changes: 1 addition & 4 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,7 @@ def _repr_html_(self) -> Any:
"""

def action_ids(self) -> list[UID]:
data = []
for asset in self.asset_list:
if asset.action_id:
data.append(asset.action_id)
data = [asset.action_id for asset in self.asset_list if asset.action_id]
return data

@property
Expand Down
16 changes: 7 additions & 9 deletions packages/syft/src/syft/service/dataset/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,14 @@ def get_assets_by_action_id(
) -> list[Asset] | SyftError:
"""Get Assets by an Action ID"""
datasets = self.get_by_action_id(context=context, uid=uid)
assets = []
if isinstance(datasets, list):
for dataset in datasets:
for asset in dataset.asset_list:
if asset.action_id == uid:
assets.append(asset)
return assets
elif isinstance(datasets, SyftError):
if isinstance(datasets, SyftError):
return datasets
return []
return [
asset
for dataset in datasets
for asset in dataset.asset_list
if asset.action_id == uid
]

@service_method(
path="dataset.delete_by_uid",
Expand Down
24 changes: 16 additions & 8 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import Enum
import random
from string import Template
from time import sleep
from typing import Any

# third party
Expand Down Expand Up @@ -644,23 +645,30 @@ def _repr_html_(self) -> str:
def wait(
self, job_only: bool = False, timeout: int | None = None
) -> Any | SyftNotReady:
# stdlib
from time import sleep
self.fetch()
if self.resolved:
return self.resolve

api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
)
if self.resolved:
return self.resolve

if not job_only and self.result is not None:
self.result.wait(timeout)

if api is None:
raise ValueError(
f"Can't access Syft API. You must login to {self.syft_node_location}"
)

workers = api.services.worker.get_all()
if not isinstance(workers, SyftError) and len(workers) == 0:
return SyftError(
message="This node has no workers. "
"You need to start a worker to run jobs "
"by setting n_consumers > 0."
)

if not job_only and self.result is not None:
self.result.wait(timeout)

print_warning = True
counter = 0
while True:
Expand Down
17 changes: 9 additions & 8 deletions packages/syft/src/syft/service/network/network_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,14 +900,15 @@ def _get_association_requests_by_peer_id(
RequestService.get_all
)
all_requests: list[Request] = request_get_all_method(context)
association_requests: list[Request] = []
for request in all_requests:
for change in request.changes:
if (
isinstance(change, AssociationRequestChange)
and change.remote_peer.id == peer_id
):
association_requests.append(request)
association_requests: list[Request] = [
request
for request in all_requests
if any(
isinstance(change, AssociationRequestChange)
and change.remote_peer.id == peer_id
for change in request.changes
)
]

return sorted(
association_requests, key=lambda request: request.request_time.utc_timestamp
Expand Down
Loading

0 comments on commit 7c758a0

Please sign in to comment.