Skip to content

Commit

Permalink
[tests/scenario] l0 and l2 bigquery tests work with python servers
Browse files Browse the repository at this point in the history
  • Loading branch information
khoaguin committed Sep 30, 2024
1 parent ff6f168 commit 40c9387
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 24 deletions.
9 changes: 8 additions & 1 deletion tests/scenariosv2/flows/admin_bigquery_pool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# stdlib
import os

# syft absolute
import syft as sy
from syft.orchestra import DeploymentType
from syft.util.test_helpers.worker_helpers import (
build_and_launch_worker_pool_from_docker_str,
)
Expand Down Expand Up @@ -33,8 +37,11 @@ def bq_create_pool(
)

ctx.logger.info(f"{msg} - Creating")

deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON)

build_and_launch_worker_pool_from_docker_str(
environment="remote",
environment=str(deployment_type),
client=admin_client,
worker_pool_name=worker_pool,
worker_dockerfile=worker_dockerfile,
Expand Down
36 changes: 19 additions & 17 deletions tests/scenariosv2/flows/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# stdlib
import os
from urllib.parse import urlparse

# syft absolute
import syft as sy
from syft.orchestra import DeploymentType
from syft.orchestra import ServerHandle

# relative
from ..sim.core import SimulatorContext
Expand All @@ -15,18 +14,21 @@ def server_info(client: sy.DatasiteClient) -> str:
return f"{client.name}(url={url}, side={client.metadata.server_side_type})"


def launch_server(ctx: SimulatorContext, server_url: str, server_name: str):
deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON)
ctx.logger.info(f"Deployment type: {deployment_type}")
if deployment_type == DeploymentType.PYTHON:
ctx.logger.info(f"Launching python server '{server_name}' at {server_url}")
parsed_url = urlparse(server_url)
port = parsed_url.port
sy.orchestra.launch(
name=server_name,
reset=True,
dev_mode=True,
port=port,
create_producer=True,
n_consumers=1,
)
def launch_server(
ctx: SimulatorContext,
server_url: str,
server_name: str,
server_side_type: str | None = "high",
) -> ServerHandle | None:
ctx.logger.info(f"Launching python server '{server_name}' at {server_url}")
parsed_url = urlparse(server_url)
port = parsed_url.port
return sy.orchestra.launch(
name=server_name,
server_side_type=server_side_type,
reset=True,
dev_mode=True,
port=port,
create_producer=True,
n_consumers=1,
)
25 changes: 23 additions & 2 deletions tests/scenariosv2/l0_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
import asyncio
from enum import auto
import os
import random

# third party
Expand All @@ -9,6 +10,7 @@

# syft absolute
import syft as sy
from syft.orchestra import DeploymentType
from syft.service.request.request import RequestStatus

# relative
Expand Down Expand Up @@ -368,16 +370,31 @@ async def sim_l0_scenario(ctx: SimulatorContext):
for _ in range(NUM_USERS)
]

deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON)
ctx.logger.info(f"Deployment type: {deployment_type}")

server_url_high = "http://localhost:8080"
launch_server(ctx, server_url_high, "syft-high")
if deployment_type == DeploymentType.PYTHON:
server_high = launch_server(
ctx=ctx,
server_url=server_url_high,
server_name="syft-high",
server_side_type="high",
)
admin_auth_high = dict( # noqa: C408
url=server_url_high,
email="[email protected]",
password="changethis",
)

server_url_low = "http://localhost:8081"
launch_server(ctx, server_url_low, "syft-low")
if deployment_type == DeploymentType.PYTHON:
server_low = launch_server(
ctx=ctx,
server_url=server_url_low,
server_name="syft-low",
server_side_type="low",
)
admin_auth_low = dict( # noqa: C408
url=server_url_low,
email="[email protected]",
Expand All @@ -395,6 +412,10 @@ async def sim_l0_scenario(ctx: SimulatorContext):
*[user_low_side_flow(ctx, server_url_low, user) for user in users],
)

if deployment_type == DeploymentType.PYTHON:
server_high.land()
server_low.land()


@pytest.mark.asyncio
async def test_l0_scenario(request):
Expand Down
14 changes: 10 additions & 4 deletions tests/scenariosv2/l2_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# RUN: just reset-high && pytest -s tests/scenariosv2/l2_test.py
## .logs files will be created in pwd

# stdlib
import asyncio
import os
import random

# third party
Expand All @@ -11,6 +9,7 @@

# syft absolute
import syft as sy
from syft.orchestra import DeploymentType

# relative
from .flows.user_bigquery_api import bq_submit_query
Expand Down Expand Up @@ -124,7 +123,11 @@ async def sim_l2_scenario(ctx: SimulatorContext):
]

server_url = "http://localhost:8080"
launch_server(ctx, server_url, "syft-high")
deployment_type = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", DeploymentType.PYTHON)
ctx.logger.info(f"Deployment type: {deployment_type}")
if deployment_type == DeploymentType.PYTHON:
server = launch_server(ctx, server_url, "syft-high")

admin_auth = {
"url": server_url,
"email": "[email protected]",
Expand All @@ -136,6 +139,9 @@ async def sim_l2_scenario(ctx: SimulatorContext):
*[user_flow(ctx, server_url, user) for user in users],
)

if deployment_type == DeploymentType.PYTHON:
server.land()


@pytest.mark.asyncio
async def test_l2_scenario(request):
Expand Down

0 comments on commit 40c9387

Please sign in to comment.