-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9193 from OpenMined/shubham/custom-image-api-test
Test for Twin Endpoint with Custom Workers
- Loading branch information
Showing
1 changed file
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# stdlib | ||
from collections.abc import Callable | ||
import os | ||
import sys | ||
import time | ||
|
||
# third party | ||
from faker import Faker | ||
import pytest | ||
|
||
# syft absolute | ||
import syft as sy | ||
from syft.client.datasite_client import DatasiteClient | ||
from syft.service.api.api import TwinAPIEndpoint | ||
from syft.service.response import SyftError | ||
from syft.service.response import SyftSuccess | ||
|
||
JOB_TIMEOUT = 20 | ||
|
||
|
||
def get_external_registry() -> str: | ||
"""Get the external registry to use for the worker image.""" | ||
return os.environ.get("EXTERNAL_REGISTRY", "docker.io") | ||
|
||
|
||
def get_worker_tag() -> str: | ||
"""Get the worker tag to use for the worker image.""" | ||
return os.environ.get("PRE_BUILT_WORKER_TAG", f"openmined/backend:{sy.__version__}") | ||
|
||
|
||
def public_function( | ||
context, | ||
) -> str: | ||
return "Public Function Execution" | ||
|
||
|
||
def private_function( | ||
context, | ||
) -> str: | ||
return "Private Function Execution" | ||
|
||
|
||
def get_twin_api_endpoint(worker_pool_name: str) -> TwinAPIEndpoint: | ||
"""Get a twin API endpoint with a custom worker pool name.""" | ||
|
||
public_func = sy.api_endpoint_method(settings={"Hello": "Public"})(public_function) | ||
pvt_func = sy.api_endpoint_method(settings={"Hello": "Private"})(private_function) | ||
|
||
new_endpoint = sy.TwinAPIEndpoint( | ||
path="second.query", | ||
mock_function=public_func, | ||
private_function=pvt_func, | ||
description="Lore ipsulum ...", | ||
worker_pool=worker_pool_name, | ||
) | ||
|
||
return new_endpoint | ||
|
||
|
||
faker = Faker() | ||
|
||
|
||
def get_ds_client(client: DatasiteClient) -> DatasiteClient: | ||
"""Get a datasite client with a registered user.""" | ||
pwd = faker.password() | ||
email = faker.email() | ||
client.register( | ||
name=faker.name(), | ||
email=email, | ||
password=pwd, | ||
password_verify=pwd, | ||
) | ||
return client.login(email=email, password=pwd) | ||
|
||
|
||
def get_syft_function(worker_pool_name: str, endpoint: Callable) -> Callable: | ||
@sy.syft_function_single_use(endpoint=endpoint, worker_pool_name=worker_pool_name) | ||
def job_function(endpoint): | ||
return endpoint() | ||
|
||
return job_function | ||
|
||
|
||
def submit_project(ds_client: DatasiteClient, syft_function: Callable): | ||
# Create a new project | ||
new_project = sy.Project( | ||
name=f"Project - {faker.text(max_nb_chars=20)}", | ||
description="Hi, I want to calculate the trade volume in million's with my cool code.", | ||
members=[ds_client], | ||
) | ||
|
||
result = new_project.create_code_request(syft_function, ds_client) | ||
assert isinstance(result, SyftSuccess) | ||
|
||
|
||
@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") | ||
# @pytest.mark.local_server | ||
def test_twin_api_with_custom_worker(full_high_worker): | ||
high_client = full_high_worker.login( | ||
email="[email protected]", password="changethis" | ||
) | ||
|
||
worker_pool_name = "custom-worker-pool" | ||
|
||
external_registry = get_external_registry() | ||
worker_docker_tag = get_worker_tag() | ||
|
||
# Create pre-built worker image | ||
docker_config = sy.PrebuiltWorkerConfig( | ||
tag=f"{external_registry}/{worker_docker_tag}" | ||
) | ||
|
||
# Submit the worker image | ||
submit_result = high_client.api.services.worker_image.submit( | ||
worker_config=docker_config | ||
) | ||
|
||
# Check if the submission was successful | ||
assert not isinstance(submit_result, SyftError), submit_result | ||
|
||
# Get the worker image | ||
worker_image = high_client.images.get_all()[-1] | ||
|
||
launch_result = high_client.api.services.worker_pool.launch( | ||
pool_name=worker_pool_name, | ||
image_uid=worker_image.id, | ||
num_workers=2, | ||
) | ||
|
||
# Check if the worker pool was launched successfully | ||
assert not isinstance(launch_result, SyftError), launch_result | ||
|
||
# Add the twin API endpoint | ||
twin_api_endpoint = get_twin_api_endpoint(worker_pool_name) | ||
twin_endpoint_result = high_client.api.services.api.add(endpoint=twin_api_endpoint) | ||
|
||
# Check if the twin API endpoint was added successfully | ||
assert isinstance(twin_endpoint_result, SyftSuccess) | ||
|
||
# validate the number of endpoints | ||
assert len(high_client.api.services.api.api_endpoints()) == 1 | ||
|
||
# refresh the client | ||
high_client.refresh() | ||
|
||
# Get datasite client | ||
high_client_ds = get_ds_client(high_client) | ||
|
||
# Execute the public endpoint | ||
mock_endpoint_result = high_client_ds.api.services.second.query() | ||
assert mock_endpoint_result == "Public Function Execution" | ||
|
||
# Get the syft function | ||
custom_function = get_syft_function( | ||
worker_pool_name, high_client_ds.api.services.second.query | ||
) | ||
|
||
# Submit the project | ||
submit_project(high_client_ds, custom_function) | ||
|
||
ds_email = high_client_ds.logged_in_user | ||
|
||
# Approve the request | ||
for r in high_client.requests.get_all(): | ||
if r.requesting_user_email == ds_email: | ||
r.approve() | ||
|
||
private_func_result_job = high_client_ds.code.job_function( | ||
endpoint=high_client_ds.api.services.second.query, blocking=False | ||
) | ||
|
||
# Wait for the job to complete | ||
job_start_time = time.time() | ||
while True: | ||
# Check if the job is resolved | ||
_ = private_func_result_job.resolved | ||
|
||
if private_func_result_job.resolve: | ||
break | ||
|
||
# Check if the job is timed out | ||
if time.time() - job_start_time > JOB_TIMEOUT: | ||
raise TimeoutError(f"Job did not complete in given time: {JOB_TIMEOUT}") | ||
time.sleep(1) | ||
|
||
# Check if the job worker is the same as the worker pool name | ||
private_func_job = high_client_ds.jobs.get(private_func_result_job.id) | ||
|
||
assert private_func_job is not None | ||
|
||
# Check if job is assigned to a worker | ||
assert private_func_job.job_worker_id is not None | ||
|
||
# Check if the job worker is the same as the worker pool name | ||
assert private_func_job.worker.worker_pool_name == worker_pool_name | ||
|
||
# Check if the job was successful | ||
assert private_func_result_job.resolved | ||
private_func_result = private_func_result_job.result | ||
|
||
assert not isinstance(private_func_result, SyftError), private_func_result | ||
|
||
assert private_func_result.get() == "Private Function Execution" |