Skip to content

Commit

Permalink
Merge pull request #9215 from OpenMined/ionesio/fix_api_logs
Browse files Browse the repository at this point in the history
Add Twin API endpoint logs
  • Loading branch information
koenvanderveen authored Aug 28, 2024
2 parents 395dcad + 16e422c commit b802a15
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 25 deletions.
143 changes: 141 additions & 2 deletions notebooks/api/0.8/12-custom-api-endpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,154 @@
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Syft Function/API Logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@sy.api_endpoint_method()\n",
"def public_log_function(\n",
" context,\n",
") -> str:\n",
" print(\"Logging Public Function Call\")\n",
" return \"Public Function Execution\"\n",
"\n",
"\n",
"@sy.api_endpoint_method()\n",
"def private_log_function(\n",
" context,\n",
") -> str:\n",
" print(\"Logging Private Function Call\")\n",
" return \"Private Function Execution\"\n",
"\n",
"\n",
"new_endpoint = sy.TwinAPIEndpoint(\n",
" path=\"test.log\",\n",
" mock_function=public_log_function,\n",
" private_function=private_log_function,\n",
" description=\"Lore ipsulum ...\",\n",
")\n",
"\n",
"# # Add it to the server.\n",
"response = datasite_client.api.services.api.add(endpoint=new_endpoint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call_mock(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint.mock()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"@sy.syft_function_single_use(endpoint=datasite_client.api.services.test.log)\n",
"def test_log_call_private(endpoint): # noqa: F811\n",
" print(\"In Syft Function Context\")\n",
" endpoint.private()\n",
" print(\"After API endpoint call\")\n",
" return True\n",
"\n",
"\n",
"# Create a project\n",
"project = sy.Project(\n",
" name=\"My Cool Project\",\n",
" description=\"\"\"Hi, I want to calculate the mean of your private data,\\\n",
" pretty please!\"\"\",\n",
" members=[datasite_client],\n",
")\n",
"project.create_code_request(test_log_call, datasite_client)\n",
"project.create_code_request(test_log_call_mock, datasite_client)\n",
"project.create_code_request(test_log_call_private, datasite_client)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_call_job = datasite_client.code.test_log_call(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_call_mock_job = datasite_client.code.test_log_call_mock(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_call_private_job = datasite_client.code.test_log_call_private(\n",
" endpoint=datasite_client.api.services.test.log, blocking=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# stdlib\n",
"import time\n",
"\n",
"# Iterate over the Jobs waiting them to finish their pipelines.\n",
"job_pool = [\n",
" (log_call_job, \"Logging Private Function Call\"),\n",
" (log_call_mock_job, \"Logging Public Function Call\"),\n",
" (log_call_private_job, \"Logging Private Function Call\"),\n",
"]\n",
"for job, expected_log in job_pool:\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
" while updated_job.status.value != \"completed\":\n",
" updated_job = datasite_client.api.services.job.get(job.id)\n",
" time.sleep(1)\n",
" # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n",
" assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n",
" _print=False\n",
" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
Expand All @@ -552,7 +691,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
5 changes: 5 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@
"version": 1,
"hash": "c7addbaf2777707f3e91e5c1e092343476cd22efc4ec8617f39ccf76e61a5a14",
"action": "add"
},
"2": {
"version": 2,
"hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089",
"action": "add"
}
},
"DataSubject": {
Expand Down
7 changes: 5 additions & 2 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,7 @@ def add_api_endpoint_execution_to_queue(
credentials: SyftVerifyKey,
method: str,
path: str,
log_id: UID,
*args: Any,
worker_pool: str | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -1294,7 +1295,7 @@ def add_api_endpoint_execution_to_queue(
job_id=job_id,
worker_settings=worker_settings,
args=args,
kwargs={"path": path, **kwargs},
kwargs={"path": path, "log_id": log_id, **kwargs},
has_execute_permissions=True,
worker_pool=worker_pool_ref, # set worker pool reference as part of queue item
)
Expand Down Expand Up @@ -1385,9 +1386,11 @@ def add_queueitem_to_queue(
action: Action | None = None,
parent_job_id: UID | None = None,
user_id: UID | None = None,
log_id: UID | None = None,
job_type: JobType = JobType.JOB,
) -> Job:
log_id = UID()
if log_id is None:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(server=self, credentials=credentials, role=role)

Expand Down
38 changes: 35 additions & 3 deletions packages/syft/src/syft/service/action/action_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
from __future__ import annotations

# stdlib
from collections.abc import Callable
from enum import Enum
from enum import auto
from typing import Any

# relative
from ...serde.serializable import serializable
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default
from ...types.uid import UID
from ..context import AuthedServiceContext

Expand All @@ -21,15 +26,28 @@ class EXECUTION_MODE(Enum):


@serializable()
class CustomEndpointActionObject(SyftObject):
class CustomEndpointActionObjectV1(SyftObject):
__canonical_name__ = "CustomEndpointActionObject"
__version__ = SYFT_OBJECT_VERSION_1

endpoint_id: UID
context: AuthedServiceContext | None = None

def add_context(self, context: AuthedServiceContext) -> CustomEndpointActionObject:

@serializable()
class CustomEndpointActionObject(SyftObject):
__canonical_name__ = "CustomEndpointActionObject"
__version__ = SYFT_OBJECT_VERSION_2

endpoint_id: UID
context: AuthedServiceContext | None = None
log_id: UID | None = None

def add_context(
self, context: AuthedServiceContext, log_id: UID | None = None
) -> CustomEndpointActionObject:
self.context = context
self.log_id = log_id
return self

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -69,10 +87,24 @@ def __call_function(
__endpoint_mode = endpoint_service.execute_server_side_endpoint_by_id

return __endpoint_mode(
*args, context=self.context, endpoint_uid=self.endpoint_id, **kwargs
*args,
context=self.context,
endpoint_uid=self.endpoint_id,
log_id=self.log_id,
**kwargs,
).unwrap()

def __check_context(self) -> AuthedServiceContext:
if self.context is None:
raise Exception("No context provided to CustomEndpointActionObject")
return self.context


@migrate(CustomEndpointActionObjectV1, CustomEndpointActionObject)
def migrate_custom_endpoint_v1_to_v2() -> list[Callable]:
return [make_set_default("log_id", None)]


@migrate(CustomEndpointActionObject, CustomEndpointActionObjectV1)
def migrate_custom_endpoint_v2_to_v1() -> list[Callable]:
return [drop(["log_id"])]
Loading

0 comments on commit b802a15

Please sign in to comment.