diff --git a/tests/scenariosv2/flows/admin_bigquery_pool.py b/tests/scenariosv2/flows/admin_bigquery_pool.py index 6f42d6a72b2..11c75feb641 100644 --- a/tests/scenariosv2/flows/admin_bigquery_pool.py +++ b/tests/scenariosv2/flows/admin_bigquery_pool.py @@ -1,5 +1,9 @@ +# stdlib +import os + # syft absolute import syft as sy +from syft.orchestra import DeploymentType from syft.util.test_helpers.worker_helpers import ( build_and_launch_worker_pool_from_docker_str, ) @@ -33,8 +37,11 @@ def bq_create_pool( ) ctx.logger.info(f"{msg} - Creating") + + deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON) + build_and_launch_worker_pool_from_docker_str( - environment="remote", + environment=str(deployment_type), client=admin_client, worker_pool_name=worker_pool, worker_dockerfile=worker_dockerfile, diff --git a/tests/scenariosv2/flows/utils.py b/tests/scenariosv2/flows/utils.py index 8492ea1cec1..c96a346e3a8 100644 --- a/tests/scenariosv2/flows/utils.py +++ b/tests/scenariosv2/flows/utils.py @@ -1,10 +1,9 @@ # stdlib -import os from urllib.parse import urlparse # syft absolute import syft as sy -from syft.orchestra import DeploymentType +from syft.orchestra import ServerHandle # relative from ..sim.core import SimulatorContext @@ -15,18 +14,21 @@ def server_info(client: sy.DatasiteClient) -> str: return f"{client.name}(url={url}, side={client.metadata.server_side_type})" -def launch_server(ctx: SimulatorContext, server_url: str, server_name: str): - deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON) - ctx.logger.info(f"Deployment type: {deployment_type}") - if deployment_type == DeploymentType.PYTHON: - ctx.logger.info(f"Launching python server '{server_name}' at {server_url}") - parsed_url = urlparse(server_url) - port = parsed_url.port - sy.orchestra.launch( - name=server_name, - reset=True, - dev_mode=True, - port=port, - create_producer=True, - n_consumers=1, - ) +def launch_server( + ctx: SimulatorContext, + server_url: str, + server_name: str, + server_side_type: str | None = "high", +) -> ServerHandle | None: + ctx.logger.info(f"Launching python server '{server_name}' at {server_url}") + parsed_url = urlparse(server_url) + port = parsed_url.port + return sy.orchestra.launch( + name=server_name, + server_side_type=server_side_type, + reset=True, + dev_mode=True, + port=port, + create_producer=True, + n_consumers=1, + ) diff --git a/tests/scenariosv2/l0_test.py b/tests/scenariosv2/l0_test.py index e57c7618a6a..061cea97bfb 100644 --- a/tests/scenariosv2/l0_test.py +++ b/tests/scenariosv2/l0_test.py @@ -1,6 +1,7 @@ # stdlib import asyncio from enum import auto +import os import random # third party @@ -9,6 +10,7 @@ # syft absolute import syft as sy +from syft.orchestra import DeploymentType from syft.service.request.request import RequestStatus # relative @@ -368,8 +370,17 @@ async def sim_l0_scenario(ctx: SimulatorContext): for _ in range(NUM_USERS) ] + deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON) + ctx.logger.info(f"Deployment type: {deployment_type}") + server_url_high = "http://localhost:8080" - launch_server(ctx, server_url_high, "syft-high") + if deployment_type == DeploymentType.PYTHON: + server_high = launch_server( + ctx=ctx, + server_url=server_url_high, + server_name="syft-high", + server_side_type="high", + ) admin_auth_high = dict( # noqa: C408 url=server_url_high, email="info@openmined.org", @@ -377,7 +388,13 @@ async def sim_l0_scenario(ctx: SimulatorContext): ) server_url_low = "http://localhost:8081" - launch_server(ctx, server_url_low, "syft-low") + if deployment_type == DeploymentType.PYTHON: + server_low = launch_server( + ctx=ctx, + server_url=server_url_low, + server_name="syft-low", + server_side_type="low", + ) admin_auth_low = dict( # noqa: C408 url=server_url_low, email="info@openmined.org", @@ -395,6 +412,10 @@ async def sim_l0_scenario(ctx: SimulatorContext): *[user_low_side_flow(ctx, server_url_low, user) for user in users], ) + if deployment_type == DeploymentType.PYTHON: + server_high.land() + server_low.land() + @pytest.mark.asyncio async def test_l0_scenario(request): diff --git a/tests/scenariosv2/l2_test.py b/tests/scenariosv2/l2_test.py index 9198338e12d..20937eb30c3 100644 --- a/tests/scenariosv2/l2_test.py +++ b/tests/scenariosv2/l2_test.py @@ -1,8 +1,6 @@ -# RUN: just reset-high && pytest -s tests/scenariosv2/l2_test.py -## .logs files will be created in pwd - # stdlib import asyncio +import os import random # third party @@ -11,6 +9,7 @@ # syft absolute import syft as sy +from syft.orchestra import DeploymentType # relative from .flows.user_bigquery_api import bq_submit_query @@ -124,7 +123,11 @@ async def sim_l2_scenario(ctx: SimulatorContext): ] server_url = "http://localhost:8080" - launch_server(ctx, server_url, "syft-high") + deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON) + ctx.logger.info(f"Deployment type: {deployment_type}") + if deployment_type == DeploymentType.PYTHON: + server = launch_server(ctx, server_url, "syft-high") + admin_auth = { "url": server_url, "email": "info@openmined.org", @@ -136,6 +139,9 @@ async def sim_l2_scenario(ctx: SimulatorContext): *[user_flow(ctx, server_url, user) for user in users], ) + if deployment_type == DeploymentType.PYTHON: + server.land() + @pytest.mark.asyncio async def test_l2_scenario(request):