Skip to content

Commit

Permalink
Merge pull request #9199 from OpenMined/fix_api_logs
Browse files Browse the repository at this point in the history
Fix log service for twin api
  • Loading branch information
IonesioJunior authored Aug 26, 2024
2 parents 862a6a9 + a189b40 commit 03d3a0c
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 28 deletions.
132 changes: 128 additions & 4 deletions notebooks/api/0.8/12-custom-api-endpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"outputs": [],
"source": [
"# stdlib\n",
"from typing import Any\n",
"\n",
"# syft absolute\n",
"import syft as sy\n",
Expand Down Expand Up @@ -69,7 +68,7 @@
"def public_endpoint_method(\n",
" context,\n",
" query: str,\n",
") -> Any:\n",
") -> bool:\n",
" return context.settings[\"key\"] == \"value\"\n",
"\n",
"\n",
Expand Down Expand Up @@ -397,7 +396,7 @@
"def new_public_function(\n",
" context,\n",
" query: str,\n",
") -> Any:\n",
") -> bool:\n",
" return context.settings[\"key\"] == \"value\"\n",
"\n",
"\n",
Expand Down Expand Up @@ -562,6 +561,131 @@
")\n",
"assert isinstance(response, SyftError), response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Syft Function/API Logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@sy.api_endpoint_method()\n",
"def public_log_function(\n",
" context,\n",
") -> str:\n",
" print(\"Logging Public Function Call\")\n",
" return \"Public Function Execution\"\n",
"\n",
"\n",
"@sy.api_endpoint_method()\n",
"def private_log_function(\n",
" context,\n",
") -> str:\n",
" print(\"Logging Private Function Call\")\n",
" return \"Private Function Execution\"\n",
"\n",
"\n",
"new_endpoint = sy.TwinAPIEndpoint(\n",
" path=\"test.log\",\n",
" mock_function=public_log_function,\n",
" private_function=private_log_function,\n",
" description=\"Lore ipsulum ...\",\n",
")\n",
"\n",
"# # Add it to the server.\n",
"response = datasite_client.api.services.api.add(endpoint=new_endpoint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call_mock(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint.mock()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call_private(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint.private()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"# Create a project\n",
"project = sy.Project(\n",
" name=\"My Cool Project\",\n",
" description=\"\"\"Hi, I want to calculate the mean of your private data,\\\n",
" pretty please!\"\"\",\n",
" members=[datasite_client],\n",
")\n",
"project.create_code_request(test_log_call, datasite_client)\n",
"project.create_code_request(test_log_call_mock, datasite_client)\n",
"project.create_code_request(test_log_call_private, datasite_client)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_call_job = datasite_client.code.test_log_call(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")\n",
"log_call_mock_job = datasite_client.code.test_log_call_mock(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")\n",
"log_call_private_job = datasite_client.code.test_log_call_private(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# stdlib\n",
"import time\n",
"\n",
"# Iterate over the Jobs waiting them to finish their pipelines.\n",
"job_pool = [\n",
" (log_call_job, \"Logging Private Function Call\"),\n",
" (log_call_mock_job, \"Logging Public Function Call\"),\n",
" (log_call_private_job, \"Logging Private Function Call\"),\n",
"]\n",
"for job, expected_log in job_pool:\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
" while updated_job.status.value != \"completed\":\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
" time.sleep(1)\n",
" # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n",
" assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n",
" _print=False\n",
" )"
]
}
],
"metadata": {
Expand All @@ -580,7 +704,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@
"hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105",
"action": "add"
}
},
"CustomEndpointActionObject": {
"2": {
"version": 2,
"hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089",
"action": "add"
}
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,7 @@ def add_api_endpoint_execution_to_queue(
credentials: SyftVerifyKey,
method: str,
path: str,
log_id: UID,
*args: Any,
worker_pool: str | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -1266,7 +1267,7 @@ def add_api_endpoint_execution_to_queue(
job_id=job_id,
worker_settings=worker_settings,
args=args,
kwargs={"path": path, **kwargs},
kwargs={"path": path, "log_id": log_id, **kwargs},
has_execute_permissions=True,
worker_pool=worker_pool_ref, # set worker pool reference as part of queue item
)
Expand All @@ -1277,6 +1278,7 @@ def add_api_endpoint_execution_to_queue(
credentials=credentials,
action=action,
job_type=JobType.TWINAPIJOB,
log_id=log_id,
)

def get_worker_pool_ref_by_name(
Expand Down Expand Up @@ -1360,9 +1362,11 @@ def add_queueitem_to_queue(
action: Action | None = None,
parent_job_id: UID | None = None,
user_id: UID | None = None,
log_id: UID | None = None,
job_type: JobType = JobType.JOB,
) -> Job | SyftError:
log_id = UID()
if log_id is None:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(server=self, credentials=credentials, role=role)

Expand Down
39 changes: 36 additions & 3 deletions packages/syft/src/syft/service/action/action_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from __future__ import annotations

# stdlib
from collections.abc import Callable
from enum import Enum
from enum import auto
from typing import Any

# relative
from ...serde.serializable import serializable
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default
from ...types.uid import UID
from ..context import AuthedServiceContext

Expand All @@ -21,15 +26,28 @@ class EXECUTION_MODE(Enum):


@serializable()
class CustomEndpointActionObject(SyftObject):
class CustomEndpointActionObjectV1(SyftObject):
__canonical_name__ = "CustomEndpointActionObject"
__version__ = SYFT_OBJECT_VERSION_1

endpoint_id: UID
context: AuthedServiceContext | None = None

def add_context(self, context: AuthedServiceContext) -> CustomEndpointActionObject:

@serializable()
class CustomEndpointActionObject(SyftObject):
__canonical_name__ = "CustomEndpointActionObject"
__version__ = SYFT_OBJECT_VERSION_2

endpoint_id: UID
context: AuthedServiceContext | None = None
log_id: UID | None = None

def add_context(
self, context: AuthedServiceContext, log_id: UID | None = None
) -> CustomEndpointActionObject:
self.context = context
self.log_id = log_id
return self

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -69,10 +87,25 @@ def __call_function(
__endpoint_mode = endpoint_service.execute_server_side_endpoint_by_id

return __endpoint_mode(
*args, context=self.context, endpoint_uid=self.endpoint_id, **kwargs
*args,
context=self.context,
endpoint_uid=self.endpoint_id,
log_id=self.log_id,
**kwargs,
)

def __check_context(self) -> AuthedServiceContext:
if self.context is None:
raise Exception("No context provided to CustomEndpointActionObject")
return self.context


@migrate(CustomEndpointActionObjectV1, CustomEndpointActionObject)
def migrate_custom_endpoint_v1_to_v2() -> list[Callable]:
return [make_set_default("log_id", None)]


@migrate(CustomEndpointActionObject, CustomEndpointActionObjectV1)
def migrate_custom_endpoint_v2_to_v1() -> list[Callable]:
# Use drop function on "notifications_enabled" attrubute
return [drop(["log_id"])]
Loading

0 comments on commit 03d3a0c

Please sign in to comment.