Skip to content

Commit

Permalink
Merge pull request #8472 from OpenMined/resolve-cyclic-error
Browse files Browse the repository at this point in the history
remove syft imports from Orchestra
  • Loading branch information
rasswanth-s authored Feb 12, 2024
2 parents 79084ac + f1e0f4f commit ed20e8e
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 28 deletions.
85 changes: 81 additions & 4 deletions notebooks/api/0.8/11-container-images-k8s.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@
"metadata": {},
"outputs": [],
"source": [
"default_worker_pool = domain_client.api.services.worker_pool.scale(\n",
"default_pool_scale_res = domain_client.api.services.worker_pool.scale(\n",
" number=1, pool_name=\"default-pool\"\n",
")\n",
"assert not isinstance(result, sy.SyftError), str(result)\n",
"default_worker_pool"
"assert not isinstance(default_pool_scale_res, sy.SyftError), str(default_pool_scale_res)\n",
"default_pool_scale_res"
]
},
{
Expand Down Expand Up @@ -850,6 +850,31 @@
"assert result_matches.all()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b5f63c2-028a-4b48-a5f9-392ac89440ed",
"metadata": {},
"outputs": [],
"source": [
"# Scale Down the workers\n",
"custom_pool_scale_res = domain_client.api.services.worker_pool.scale(\n",
" number=1, pool_name=worker_pool_name\n",
")\n",
"assert not isinstance(custom_pool_scale_res, sy.SyftError), str(custom_pool_scale_res)\n",
"custom_pool_scale_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3325165b-525f-4ffd-add5-e0c93d235723",
"metadata": {},
"outputs": [],
"source": [
"assert len(domain_client.worker_pools[worker_pool_name].worker_list) == 1"
]
},
{
"cell_type": "markdown",
"id": "f20a29df-2e63-484f-8b67-d6a397722e66",
Expand Down Expand Up @@ -1080,6 +1105,31 @@
"assert len(worker_pool_list) == 3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18ddb1e7-8d8b-480c-b6a4-e4c79d27bcf1",
"metadata": {},
"outputs": [],
"source": [
"# Scale Down the workers\n",
"opendp_pool_scale_res = domain_client.api.services.worker_pool.scale(\n",
" number=1, pool_name=pool_name_opendp\n",
")\n",
"assert not isinstance(opendp_pool_scale_res, sy.SyftError), str(opendp_pool_scale_res)\n",
"opendp_pool_scale_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83b3ec7b-3fbe-429d-bd1e-5e9afa223c3c",
"metadata": {},
"outputs": [],
"source": [
"assert len(domain_client.worker_pools[pool_name_opendp].worker_list) == 1"
]
},
{
"cell_type": "markdown",
"id": "6e671e1e",
Expand Down Expand Up @@ -1238,6 +1288,33 @@
"assert domain_client.worker_pools[pool_name_recordlinkage]\n",
"assert len(domain_client.worker_pools[pool_name_recordlinkage].worker_list) == 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0cec28e8-784e-4a8d-91f9-f2481a967008",
"metadata": {},
"outputs": [],
"source": [
"# Scale down the workers\n",
"recordlinkage_pool_scale_res = domain_client.api.services.worker_pool.scale(\n",
" number=1, pool_name=pool_name_recordlinkage\n",
")\n",
"assert not isinstance(recordlinkage_pool_scale_res, sy.SyftError), str(\n",
" recordlinkage_pool_scale_res\n",
")\n",
"recordlinkage_pool_scale_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a43cf8cf-b8ca-4df4-aec9-6651d0a2fcda",
"metadata": {},
"outputs": [],
"source": [
"assert len(domain_client.worker_pools[pool_name_recordlinkage].worker_list) == 1"
]
}
],
"metadata": {
Expand All @@ -1256,7 +1333,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
24 changes: 6 additions & 18 deletions packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,13 @@

# relative
from .cli import str_to_bool
from .dummynum import DummyNum
from .grammar import find_available_port
from .names import random_name
from .util import ImportFromSyft
from .util import NodeSideType
from .util import NodeType
from .util import shell

try:
# syft absolute
from syft.abstract_node import NodeSideType
from syft.abstract_node import NodeType
from syft.protocol.data_protocol import stage_protocol_changes
from syft.service.response import SyftError
except Exception: # nosec
NodeSideType = DummyNum
NodeType = DummyNum

def stage_protocol_changes(*args: Any, **kwargs: Any) -> None:
pass

SyftError = DummyNum
# print("Please install syft with `pip install syft`")

DEFAULT_PORT = 8080
DEFAULT_URL = "http://localhost"
# Gevent used instead of threading module ,as we monkey patch gevent in syft
Expand Down Expand Up @@ -203,6 +189,7 @@ def register(
institution: Optional[str] = None,
website: Optional[str] = None,
) -> Any:
SyftError = ImportFromSyft.import_syft_error()
if not email:
email = input("Email: ")
if not password:
Expand Down Expand Up @@ -248,6 +235,7 @@ def deploy_to_python(
create_producer: bool = False,
queue_port: Optional[int] = None,
) -> Optional[NodeHandle]:
stage_protocol_changes = ImportFromSyft.import_stage_protocol_changes()
sy = get_syft_client()
if sy is None:
return sy
Expand All @@ -271,7 +259,7 @@ def deploy_to_python(
"processes": processes,
"dev_mode": dev_mode,
"tail": tail,
"node_type": node_type_enum,
"node_type": str(node_type_enum),
"node_side_type": node_side_type,
"enable_warnings": enable_warnings,
# new kwargs
Expand Down
49 changes: 49 additions & 0 deletions packages/hagrid/hagrid/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,61 @@
# stdlib
from enum import Enum
import os
import subprocess # nosec
import sys
from typing import Any
from typing import Callable
from typing import Tuple
from typing import Union
from urllib.parse import urlparse

# relative
from .dummynum import DummyNum


class NodeSideType(str, Enum):
LOW_SIDE = "low"
HIGH_SIDE = "high"

def __str__(self) -> str:
# Use values when transforming NodeType to str
return self.value


class NodeType(str, Enum):
DOMAIN = "domain"
NETWORK = "network"
ENCLAVE = "enclave"
GATEWAY = "gateway"

def __str__(self) -> str:
# Use values when transforming NodeType to str
return self.value


class ImportFromSyft:
@staticmethod
def import_syft_error() -> Callable:
try:
# syft absolute
from syft.service.response import SyftError
except Exception:
SyftError = DummyNum

return SyftError

@staticmethod
def import_stage_protocol_changes() -> Callable:
try:
# syft absolute
from syft.protocol.data_protocol import stage_protocol_changes
except Exception:

def stage_protocol_changes(*args: Any, **kwargs: Any) -> None:
pass

return stage_protocol_changes


def from_url(url: str) -> Tuple[str, str, int, str, Union[Any, str]]:
try:
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# relative
from . import gevent_patch # noqa: F401
from .abstract_node import NodeSideType # noqa: F401
from .abstract_node import NodeType # noqa: F401
from .client.client import connect # noqa: F401
from .client.client import login # noqa: F401
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/abstract_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@serializable()
class NodeType(Enum):
class NodeType(str, Enum):
DOMAIN = "domain"
NETWORK = "network"
ENCLAVE = "enclave"
Expand All @@ -26,6 +26,9 @@ class NodeSideType(str, Enum):
LOW_SIDE = "low"
HIGH_SIDE = "high"

def __str__(self) -> str:
return self.value


class AbstractNode:
id: Optional[UID]
Expand Down
5 changes: 5 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,8 @@ def create_queue_config(
if queue_config:
queue_config_ = queue_config
elif queue_port is not None or n_consumers > 0 or create_producer:
if not create_producer and queue_port is None:
print("No queue port defined to bind consumers.")
queue_config_ = ZMQQueueConfig(
client_config=ZMQClientConfig(
create_producer=create_producer,
Expand Down Expand Up @@ -631,6 +633,9 @@ def named(
client_config=blob_client_config
)

node_type = NodeType(node_type)
node_side_type = NodeSideType(node_side_type)

return cls(
name=name,
id=uid,
Expand Down
16 changes: 11 additions & 5 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,17 @@ commands =
exit $return; \
fi'

#Integration + Gateway Connection Tests
# Gateway tests are not run in kuberetes, as currently,it does not have a way to configure
# high/low side warning flag.
bash -c " source ./scripts/get_k8s_secret_ci.sh; \
pytest tests/integration/network -k 'not test_domain_gateway_user_code' -p no:randomly -vvvv"

# Shutting down the gateway cluster to free up space, as the
# below code does not require gateway cluster
bash -c "k3d cluster delete testgateway1 || true"
bash -c "docker volume rm k3d-testgateway1-images --force || true"

; ; container workload
; bash -c 'if [[ "$PYTEST_MODULES" == *"container_workload"* ]]; then \
; echo "Starting Container Workload test"; date; \
Expand All @@ -786,11 +797,6 @@ commands =
bash -c " source ./scripts/get_k8s_secret_ci.sh; \
pytest --nbmake notebooks/api/0.8 -p no:randomly -k 'not 10-container-images.ipynb' -vvvv --nbmake-timeout=1000"

#Integration + Gateway Connection Tests
# Gateway tests are not run in kuberetes, as currently,it does not have a way to configure
# high/low side warning flag.
bash -c " source ./scripts/get_k8s_secret_ci.sh; \
pytest tests/integration/network -k 'not test_domain_gateway_user_code' -p no:randomly -vvvv"

# deleting clusters created
bash -c "k3d cluster delete testgateway1 || true"
Expand Down

0 comments on commit ed20e8e

Please sign in to comment.