diff --git a/notebooks/Experimental/k8s.ipynb b/notebooks/Experimental/k8s.ipynb new file mode 100644 index 00000000000..3812f229bd0 --- /dev/null +++ b/notebooks/Experimental/k8s.ipynb @@ -0,0 +1,482 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e7da7ef7-09f2-431e-82f7-4e28d50488fa", + "metadata": {}, + "outputs": [], + "source": [ + "SYFT_VERSION = \">=0.8.2.b0,<0.9\"\n", + "package_string = f'\"syft{SYFT_VERSION}\"'\n", + "# %pip install {package_string} -q" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74f04dcf-5fd9-4526-990e-91ffbc0df8c6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "kj/filesystem-disk-unix.c++:1734: warning: PWD environment variable doesn't match current directory; pwd = /home/ionesio/workspace/PySyft\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ The installed version of syft==0.8.5b1 matches the requirement >=0.8.2b0 and the requirement <0.9\n" + ] + } + ], + "source": [ + "# syft absolute\n", + "import syft as sy\n", + "\n", + "sy.requires(SYFT_VERSION)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fac2594c-b5e2-40ae-89c0-83c687225103", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Staging Protocol Changes...\n", + "Starting test-domain-1 server on 0.0.0.0:19706\n", + "\n", + "WARNING: private key is based on node name: test-domain-1 in dev_mode. 
Don't run this in production.\n", + "SQLite Store Path:\n", + "!open file:///tmp/7bca415d13ed4ec881f0d0aede098dbb.sqlite\n", + "\n", + "Waiting for server to start.Creating default worker image with tag='local-dev'\n", + "Building default worker image with tag=local-dev\n", + "Setting up worker poolname=default-pool workers=0 image_uid=71ff9f4a9a034f1ab7295e37973236ca in_memory=True\n", + "Created default worker pool.\n", + "...Data Migrated to latest version !!!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Started server process [502803]\n", + "INFO: Waiting for application startup.\n", + "INFO: Application startup complete.\n", + "INFO: Uvicorn running on http://0.0.0.0:19706 (Press CTRL+C to quit)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:47166 - \"GET /api/v2/metadata HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47180 - \"GET /api/v2/metadata HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47180 - \"GET /api/v2/metadata HTTP/1.1\" 200 OK\n", + ". 
Done.\n", + "Logged into as GUEST\n", + "INFO: 127.0.0.1:47180 - \"POST /api/v2/login HTTP/1.1\" 200 OK\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ionesio/workspace/bq_collab/PySyft/packages/syft/src/syft/types/syft_object.py:599: TypeHintWarning: Skipping type check against 'ServiceRole'; this looks like a string-form forward reference imported from another module\n", + " check_type(value, var_annotation)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:47180 - \"GET /api/v2/api?verify_key=aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f&communication_protocol=dev HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47186 - \"POST /api/v2/api_call HTTP/1.1\" 200 OK\n", + "Logged into as \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ionesio/workspace/bq_collab/PySyft/packages/syft/src/syft/types/syft_object.py:599: TypeHintWarning: Skipping type check against 'ServiceRole'; this looks like a string-form forward reference imported from another module\n", + " check_type(value, var_annotation)\n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:47180 - \"POST /api/v2/register HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47202 - \"POST /api/v2/api_call HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47208 - \"GET /api/v2/metadata HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47208 - \"GET /api/v2/metadata HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47208 - \"POST /api/v2/login HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47208 - \"GET /api/v2/api?verify_key=9a25c8482066894f8770635ed2b17ba4d8cea6374b7ad4b6f1ac5132922146bb&communication_protocol=dev HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47218 - \"POST /api/v2/api_call HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:47230 - \"POST /api/v2/api_call HTTP/1.1\" 200 OK\n" + ] + } + ], + "source": [ + "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True, reset=True)\n", + "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4115e0d4-3aab-4509-acc7-40c86c9f8504", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: User 'New User' successfully registered! To see users, run `[your_client].users`

" + ], + "text/plain": [ + "SyftSuccess: User 'New User' successfully registered! To see users, run `[your_client].users`" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain_client.register(\n", + " email=\"user@openmined.org\",\n", + " password=\"verysecurepassword\",\n", + " password_verify=\"verysecurepassword\",\n", + " name=\"New User\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "42455389-19e5-4815-a62b-867d2aaa952b", + "metadata": {}, + "outputs": [], + "source": [ + "@sy.api_endpoint(path=\"project.query\")\n", + "def run_query_high_side(\n", + " context,\n", + " sql_query: str,\n", + ") -> str:\n", + " # third party\n", + " from google.cloud import bigquery\n", + " from google.oauth2 import service_account\n", + "\n", + " SERVICE_ACCOUNT_KEY = {}\n", + " credentials = service_account.Credentials.from_service_account_info(\n", + " SERVICE_ACCOUNT_KEY\n", + " )\n", + " scoped_credentials = credentials.with_scopes(\n", + " [\"https://www.googleapis.com/auth/cloud-platform\"]\n", + " )\n", + "\n", + " client = bigquery.Client(\n", + " credentials=scoped_credentials,\n", + " location=\"us-west1\",\n", + " )\n", + "\n", + " rows = client.query_and_wait(\n", + " sql_query,\n", + " project=SERVICE_ACCOUNT_KEY[\"project_id\"],\n", + " )\n", + "\n", + " return rows.to_dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "aa044bde-265b-4c28-b626-8f8ddf8c216a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: CustomAPIEndpoint added: syft.service.api.api.CustomAPIEndpoint

" + ], + "text/plain": [ + "SyftSuccess: CustomAPIEndpoint added: syft.service.api.api.CustomAPIEndpoint" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response = domain_client.api.services.api.set(endpoint=run_query_high_side)\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "10679327-d006-4140-b941-c93879a8ac87", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged into as GUEST\n", + "Logged into as \n" + ] + } + ], + "source": [ + "domain_guest = node.login(email=\"user@openmined.org\", password=\"verysecurepassword\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "172b496b-9dd7-42fa-943a-691a89c5ca86", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_name = \"\"\n", + "table_name = \"\"\n", + "sql_query = f\"SELECT * FROM {dataset_name}.{table_name} LIMIT 100\"\n", + "query_results = domain_guest.api.services.project.query(sql_query=sql_query)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b5b03512-8bf9-4c76-b557-b0b2d8095bd6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
int64_field_0idnamesubscribers_countpermalinknsfwspam
04t5_via1x/r/mylittlepony4323081/r//r/mylittleponyNaNFalse
15t5_cv9gn/r/polyamory2425929/r//r/polyamoryNaNFalse
210t5_8p2tq/r/Catholicism4062607/r//r/CatholicismNaNFalse
316t5_8fcro/r/cordcutters7543226/r//r/cordcuttersNaNFalse
417t5_td5of/r/stevenuniverse2692168/r//r/stevenuniverseNaNFalse
........................
95305t5_jgydw/r/cannabis7703201/r//r/cannabisNaNFalse
96311t5_3mfau/r/marvelmemes4288492/r//r/marvelmemesNaNFalse
97317t5_ub3c8/r/ghibli6029127/r//r/ghibliNaNFalse
98319t5_fbgo3/r/birdsarentreal3416317/r//r/birdsarentrealNaNFalse
99320t5_mue7v/r/polandball8894111/r//r/polandballNaNFalse
\n", + "

100 rows × 7 columns

\n", + "
" + ], + "text/plain": [ + " int64_field_0 id name subscribers_count \\\n", + "0 4 t5_via1x /r/mylittlepony 4323081 \n", + "1 5 t5_cv9gn /r/polyamory 2425929 \n", + "2 10 t5_8p2tq /r/Catholicism 4062607 \n", + "3 16 t5_8fcro /r/cordcutters 7543226 \n", + "4 17 t5_td5of /r/stevenuniverse 2692168 \n", + ".. ... ... ... ... \n", + "95 305 t5_jgydw /r/cannabis 7703201 \n", + "96 311 t5_3mfau /r/marvelmemes 4288492 \n", + "97 317 t5_ub3c8 /r/ghibli 6029127 \n", + "98 319 t5_fbgo3 /r/birdsarentreal 3416317 \n", + "99 320 t5_mue7v /r/polandball 8894111 \n", + "\n", + " permalink nsfw spam \n", + "0 /r//r/mylittlepony NaN False \n", + "1 /r//r/polyamory NaN False \n", + "2 /r//r/Catholicism NaN False \n", + "3 /r//r/cordcutters NaN False \n", + "4 /r//r/stevenuniverse NaN False \n", + ".. ... ... ... \n", + "95 /r//r/cannabis NaN False \n", + "96 /r//r/marvelmemes NaN False \n", + "97 /r//r/ghibli NaN False \n", + "98 /r//r/birdsarentreal NaN False \n", + "99 /r//r/polandball NaN False \n", + "\n", + "[100 rows x 7 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_results" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0rc1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 2cc80e6aa90..3177ae623b7 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -26,13 +26,15 @@ vars: CONTAINER_REGISTRY: "docker.io" NODE_NAME: "mynode" VERSION: "0.8.5-beta.6" + PLATFORM: $(uname -m | grep -q 'arm64' && echo "arm64" || echo "amd64") + # This is a list of `images` that DevSpace 
can build for this project # We recommend to skip image building during development (devspace dev) as much as possible images: backend: image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_BACKEND}" - buildKit: {} + buildKit: { args: ["--platform", "linux/${PLATFORM}"] } dockerfile: ./backend/backend.dockerfile context: ../ tags: @@ -40,7 +42,8 @@ images: frontend: image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_FRONTEND}" buildKit: - args: ["--target", "grid-ui-production"] + args: + ["--target", "grid-ui-production", "--platform", "linux/${PLATFORM}"] dockerfile: ./frontend/frontend.dockerfile target: "grid-ui-production" context: ./frontend @@ -48,7 +51,7 @@ images: - dev-${DEVSPACE_TIMESTAMP} seaweedfs: image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_SEAWEEDFS}" - buildKit: {} + buildKit: { args: ["--platform", "linux/${PLATFORM}"] } buildArgs: SEAWEEDFS_VERSION: ${SEAWEEDFS_VERSION} dockerfile: ./seaweedfs/seaweedfs.dockerfile diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index e6e865f53fb..efe43b298c8 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -47,6 +47,8 @@ from .service.action.action_object import ActionObject # noqa: F401 from .service.action.plan import Plan # noqa: F401 from .service.action.plan import planify # noqa: F401 +from .service.api.api import api_endpoint # noqa: F401 +from .service.api.api import create_new_api_endpoint as TwinAPIEndpoint # noqa: F401 from .service.code.user_code import UserCodeStatus # noqa: F401; noqa: F401 from .service.code.user_code import syft_function # noqa: F401; noqa: F401 from .service.code.user_code import syft_function_single_use # noqa: F401; noqa: F401 diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index d9a19dbb1a5..14a232ba211 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -621,6 +621,8 @@ def for_user( user_verify_key: SyftVerifyKey | None = None, 
) -> SyftAPI: # relative + from ..service.api.api_service import APIService + # TODO: Maybe there is a possibility of merging ServiceConfig and APIEndpoint from ..service.code.user_code_service import UserCodeService @@ -717,6 +719,26 @@ def for_user( ) endpoints[unique_path] = endpoint + # get admin defined custom api endpoints + method = node.get_method_with_context(APIService.get_endpoints, context) + custom_endpoints = method() + for custom_endpoint in custom_endpoints: + pre_kwargs = {"path": custom_endpoint.path} + service_path = "api.call" + path = custom_endpoint.path + api_end = custom_endpoint.path.split(".")[-1] + endpoint = APIEndpoint( + service_path=service_path, + module_path=path, + name=api_end, + description="", + doc_string="", + signature=custom_endpoint.signature, + has_self=False, + pre_kwargs=pre_kwargs, + ) + endpoints[path] = endpoint + return SyftAPI( node_name=node.name, node_uid=node.id, diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index ba2de258904..06e08bd870a 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -45,6 +45,7 @@ from ..service.action.action_store import DictActionStore from ..service.action.action_store import MongoActionStore from ..service.action.action_store import SQLiteActionStore +from ..service.api.api_service import APIService from ..service.blob_storage.service import BlobStorageService from ..service.code.status_service import UserCodeStatusService from ..service.code.user_code_service import UserCodeService @@ -371,6 +372,7 @@ def __init__( SyftWorkerImageService, SyftWorkerPoolService, SyftImageRegistryService, + APIService, SyncService, OutputService, UserCodeStatusService, @@ -990,6 +992,7 @@ def _construct_services(self) -> None: SyftWorkerImageService, SyftWorkerPoolService, SyftImageRegistryService, + APIService, SyncService, OutputService, UserCodeStatusService, diff --git a/packages/syft/src/syft/service/api/api.py 
# --- packages/syft/src/syft/service/api/api.py (new file) ---
# stdlib
import ast
from collections.abc import Callable
import inspect
from inspect import Signature
import keyword
import re
from typing import Any

# third party
from pydantic import ValidationError
from pydantic import field_validator
from pydantic import model_validator
from result import Err
from result import Ok
from result import Result

# relative
from ...serde.serializable import serializable
from ...serde.signature import signature_remove_context
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import drop
from ...types.transforms import generate_id
from ...types.transforms import transform
from ..context import AuthedServiceContext
from ..response import SyftError


def get_signature(func: Callable) -> Signature:
    """Return *func*'s signature with the injected ``context`` parameter removed."""
    sig = inspect.signature(func)
    sig = signature_remove_context(sig)
    return sig


@serializable()
class TwinAPIEndpointView(SyftObject):
    """Read-only view of a twin endpoint exposed to users (no source code)."""

    # version
    __canonical_name__ = "CustomAPIView"
    __version__ = SYFT_OBJECT_VERSION_1

    path: str
    signature: Signature
    access: str = "Public"

    __repr_attrs__ = [
        "path",
        "signature",
    ]

    def _coll_repr_(self) -> dict[str, Any]:
        """Columns shown when this view is rendered in a collection/table."""
        return {
            "API path": self.path,
            "Signature": self.path + str(self.signature),
            "Access": self.access,
        }


class Endpoint(SyftObject):
    """Base class to perform basic Endpoint validation for both public/private endpoints.

    The validators use ``check_fields=False`` because the validated fields
    (``api_code``, ``func_name``, ``context_vars``) are declared on subclasses.
    """

    # version
    __canonical_name__ = "CustomApiEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    @field_validator("api_code", check_fields=False)
    @classmethod
    def validate_api_code(cls, api_code: str) -> str:
        """Reject code that is not syntactically valid Python."""
        try:
            ast.parse(api_code)
        except SyntaxError:
            raise ValueError("Code must be a valid Python function.")
        return api_code

    @field_validator("func_name", check_fields=False)
    @classmethod
    def validate_func_name(cls, func_name: str) -> str:
        """Require a valid, non-keyword Python identifier as the function name."""
        if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", func_name) or keyword.iskeyword(
            func_name
        ):
            raise ValueError("Invalid function name.")
        return func_name

    @field_validator("context_vars", check_fields=False)
    @classmethod
    def validate_context_vars(
        cls, context_vars: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        # No constraints yet; hook kept for future validation.
        return context_vars


@serializable()
class PrivateAPIEndpoint(Endpoint):
    """Source + metadata for the private (real) implementation of an endpoint."""

    # version
    __canonical_name__ = "PrivateAPIEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    api_code: str
    func_name: str
    context_vars: dict[str, Any] | None = None


@serializable()
class PublicAPIEndpoint(Endpoint):
    """Source + metadata for the public (mock) implementation of an endpoint."""

    # version
    __canonical_name__ = "PublicAPIEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    api_code: str
    func_name: str
    context_vars: dict[str, Any] | None = None


@serializable()
class UpdateTwinAPIEndpoint(PartialSyftObject):
    """Partial payload for updating an existing twin endpoint."""

    # version
    __canonical_name__ = "UpdateTwinAPIEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    path: str
    private_code: PrivateAPIEndpoint
    public_code: PublicAPIEndpoint


@serializable()
class CreateTwinAPIEndpoint(SyftObject):
    """Payload for registering a new twin endpoint (private code required)."""

    # version
    __canonical_name__ = "CreateTwinAPIEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    path: str
    private_code: PrivateAPIEndpoint
    public_code: PublicAPIEndpoint | None = None
    signature: Signature

    @model_validator(mode="before")
    @classmethod
    def validate_signature(cls, data: dict[str, Any]) -> dict[str, Any]:
        # TODO: Implement a real signature comparison between public and
        # private code; `mismatch_signatures` is a placeholder that never fires.
        mismatch_signatures = False
        if data.get("public_code") is not None and mismatch_signatures:
            raise ValueError(
                "Public and Private API Endpoints must have the same signature."
            )

        return data

    @field_validator("path")
    @classmethod
    def validate_path(cls, path: str) -> str:
        """Require a dotted, lowercase path such as ``"project.query"``."""
        if not re.match(r"^[a-z]+(\.[a-z]+)*$", path):
            raise ValueError('String must be a path-like string (e.g., "new.endpoint")')
        return path

    @field_validator("private_code")
    @classmethod
    def validate_private_code(
        cls, private_code: PrivateAPIEndpoint
    ) -> PrivateAPIEndpoint:
        # No constraints yet; hook kept for future validation.
        return private_code

    @field_validator("public_code")
    @classmethod
    def validate_public_code(
        cls, public_code: PublicAPIEndpoint | None
    ) -> PublicAPIEndpoint | None:
        # No constraints yet; hook kept for future validation.
        return public_code


@serializable()
class TwinAPIEndpoint(SyftObject):
    """Stored twin endpoint: a private implementation plus an optional mock."""

    # version
    __canonical_name__ = "TwinAPIEndpoint"
    __version__ = SYFT_OBJECT_VERSION_1

    path: str
    private_code: PrivateAPIEndpoint
    public_code: PublicAPIEndpoint | None = None
    signature: Signature

    __attr_searchable__ = ["path"]
    __attr_unique__ = ["path"]

    def has_mock(self) -> bool:
        """Return True when a public (mock) implementation is registered.

        BUG FIX: the previous body read ``self.api_mock_code``, an attribute
        that does not exist on this class and would raise AttributeError.
        """
        return self.public_code is not None

    def select_code(self, context: AuthedServiceContext) -> Result[Ok, Err]:
        """Pick which implementation the caller may run.

        role value 128 appears to be the admin role — TODO confirm against
        ServiceRole and prefer comparing to the enum member rather than 128.
        """
        if context.role.value == 128:
            return Ok(self.private_code)

        if self.public_code:
            return Ok(self.public_code)

        return Err("No public code available")

    def exec(self, context: AuthedServiceContext, *args: Any, **kwargs: Any) -> Any:
        """Compile and run the selected implementation.

        Always returns a ``(context, result)`` pair so callers can unpack
        unconditionally; on failure ``result`` is a SyftError.
        """
        try:
            executable_code = self.select_code(context)
            if executable_code.is_err():
                return context, SyftError(message=executable_code.err())

            executable_code = executable_code.ok()

            inner_function = ast.parse(executable_code.api_code).body[0]
            # Strip decorators so only the plain function is executed.
            inner_function.decorator_list = []
            # compile the function
            raw_byte_code = compile(ast.unparse(inner_function), "<api>", "exec")
            # SECURITY: exec/eval of stored code — acceptable only because
            # endpoints are registered by admins; do not relax that restriction.
            exec(raw_byte_code)  # nosec
            # execute it
            evil_string = f"{executable_code.func_name}(context, *args, **kwargs)"
            result = eval(evil_string, None, locals())  # nosec
            # return the results
            return context, result
        except Exception as e:
            print(f"Failed to run CustomAPIEndpoint Code. {e}")
            # BUG FIX: previously returned a bare SyftError here, which broke
            # callers that unpack ``context, result``; also stringify the error.
            return context, SyftError(message=str(e))


def set_access_type(context: TransformContext) -> TransformContext:
    """Transform step: label the view Public/Private based on mock presence."""
    if context.output is not None and context.obj is not None:
        if context.obj.public_code is not None:
            context.output["access"] = "Public"
        else:
            context.output["access"] = "Private"
    return context


@transform(CreateTwinAPIEndpoint, TwinAPIEndpoint)
def endpoint_create_to_twin_endpoint() -> list[Callable]:
    return [generate_id]


@transform(TwinAPIEndpoint, TwinAPIEndpointView)
def twin_endpoint_to_view() -> list[Callable]:
    # Never leak source code into the user-facing view.
    return [
        set_access_type,
        drop("private_code"),
        drop("public_code"),
    ]


def api_endpoint(path: str) -> Callable[..., TwinAPIEndpoint | SyftError]:
    """Decorator registering a single (private-only) endpoint at *path*."""

    def decorator(f: Callable) -> TwinAPIEndpoint | SyftError:
        try:
            res = CreateTwinAPIEndpoint(
                path=path,
                private_code=PrivateAPIEndpoint(
                    api_code=inspect.getsource(f),
                    func_name=f.__name__,
                ),
                signature=get_signature(f),
            )
        except ValidationError as e:
            # BUG FIX: error_msg could be unbound when e.errors() is empty.
            errors = e.errors()
            error_msg = errors[-1]["msg"] if errors else str(e)
            res = SyftError(message=error_msg)
        return res

    return decorator


def create_new_api_endpoint(
    path: str,
    private: Callable[..., Any],
    description: str | None = None,
    public: Callable[..., Any] | None = None,
    private_configs: dict[str, Any] | None = None,
    public_configs: dict[str, Any] | None = None,
) -> CreateTwinAPIEndpoint | SyftError:
    """Build a CreateTwinAPIEndpoint from private (and optional public) callables."""
    try:
        if public is not None:
            return CreateTwinAPIEndpoint(
                path=path,
                private_code=PrivateAPIEndpoint(
                    api_code=inspect.getsource(private),
                    func_name=private.__name__,
                    context_vars=private_configs,
                ),
                public_code=PublicAPIEndpoint(
                    api_code=inspect.getsource(public),
                    func_name=public.__name__,
                    context_vars=public_configs,
                ),
                signature=get_signature(private),
            )

        return CreateTwinAPIEndpoint(
            path=path,
            private_code=PrivateAPIEndpoint(
                api_code=inspect.getsource(private),
                func_name=private.__name__,
                context_vars=private_configs,
            ),
            signature=get_signature(private),
        )
    except ValidationError as e:
        # BUG FIX: error_msg could be unbound when e.errors() is empty.
        errors = e.errors()
        error_msg = errors[-1]["msg"] if errors else str(e)
        return SyftError(message=error_msg)


# --- packages/syft/src/syft/service/api/api_service.py (new file) ---
# stdlib
from typing import cast

# relative
from ...abstract_node import AbstractNode
from ...store.document_store import DocumentStore
from ...types.uid import UID
from ...util.telemetry import instrument
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL
from .api_stash import TwinAPIEndpointStash


@instrument
@serializable()
class APIService(AbstractService):
    """Node service managing admin-defined custom (twin) API endpoints."""

    store: DocumentStore
    stash: TwinAPIEndpointStash

    def __init__(self, store: DocumentStore) -> None:
        self.store = store
        self.stash = TwinAPIEndpointStash(store=store)

    @service_method(
        path="api.add",
        name="add",
        roles=ADMIN_ROLE_LEVEL,
    )
    def set(
        self, context: AuthedServiceContext, endpoint: CreateTwinAPIEndpoint
    ) -> SyftSuccess | SyftError:
        """Register a CustomAPIEndpoint."""
new_endpoint = endpoint.to(TwinAPIEndpoint) + + existent_endpoint = self.stash.get_by_path( + context.credentials, new_endpoint.path + ) + if existent_endpoint.is_ok() and existent_endpoint.ok(): + return SyftError( + message="An API endpoint already exists at the given path." + ) + + result = self.stash.update(context.credentials, endpoint=new_endpoint) + if result.is_err(): + return SyftError(message=result.err()) + + return SyftSuccess(message="Endpoint successfully created.") + + @service_method( + path="api.api_endpoints", + name="api_endpoints", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def api_endpoints( + self, + context: AuthedServiceContext, + ) -> list[TwinAPIEndpointView] | SyftError: + """Retrieves a list of available API endpoints view available to the user.""" + context.node = cast(AbstractNode, context.node) + admin_key = context.node.get_service("userservice").admin_verify_key() + result = self.stash.get_all(admin_key) + if result.is_err(): + return SyftError(message=result.err()) + + all_api_endpoints = result.ok() + api_endpoint_view = [] + for api_endpoint in all_api_endpoints: + api_endpoint_view.append(api_endpoint.to(TwinAPIEndpointView)) + + return api_endpoint_view + + @service_method( + path="api.update", + name="update", + roles=ADMIN_ROLE_LEVEL, + ) + def update( + self, + context: AuthedServiceContext, + uid: UID, + updated_api: UpdateTwinAPIEndpoint, + ) -> SyftSuccess | SyftError: + """Updates an specific API endpoint.""" + return SyftError(message="This is not implemented yet.") + + @service_method( + path="api.delete", + name="delete", + roles=ADMIN_ROLE_LEVEL, + ) + def delete( + self, context: AuthedServiceContext, path: str + ) -> SyftSuccess | SyftError: + """Deletes an specific API endpoint.""" + return SyftError(message="This is not implemented yet.") + + @service_method( + path="api.schema", + name="schema", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def api_schema(self, context: AuthedServiceContext, uid: UID) -> 
TwinAPIEndpoint: + """Show a view of an API endpoint. This must be smart enough to check if + the user has access to the endpoint.""" + return SyftError(message="This is not implemented yet.") + + @service_method(path="api.call", name="call", roles=GUEST_ROLE_LEVEL) + def call( + self, + context: AuthedServiceContext, + path: str, + *args: Any, + **kwargs: Any, + ) -> SyftSuccess | SyftError: + """Call a Custom API Method""" + context.node = cast(AbstractNode, context.node) + result = self.stash.get_by_path(context.node.verify_key, path=path) + if not result.is_ok(): + return SyftError(message=f"CustomAPIEndpoint: {path} does not exist") + custom_endpoint = result.ok() + custom_endpoint = custom_endpoint[-1] + if result: + context, result = custom_endpoint.exec(context, *args, **kwargs) + return result + + def get_endpoints( + self, context: AuthedServiceContext + ) -> list[TwinAPIEndpoint] | SyftError: + # TODO: Add ability to specify which roles see which endpoints + # for now skip auth + context.node = cast(AbstractNode, context.node) + results = self.stash.get_all(context.node.verify_key) + if results.is_ok(): + return results.ok() + return SyftError(messages="Unable to get CustomAPIEndpoint") diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py new file mode 100644 index 00000000000..991a9ea5964 --- /dev/null +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -0,0 +1,65 @@ +# stdlib + +# third party +from result import Ok +from result import Result + +# relative +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseUIDStoreStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionSettings +from .api import TwinAPIEndpoint + + +@serializable() +class TwinAPIEndpointStash(BaseUIDStoreStash): + object_type = TwinAPIEndpoint + settings: PartitionSettings = PartitionSettings( + 
name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint + ) + + def __init__(self, store: DocumentStore) -> None: + super().__init__(store=store) + + def get_by_path( + self, credentials: SyftVerifyKey, path: str + ) -> Result[TwinAPIEndpoint, str]: + endpoint_results = self.get_all(credentials=credentials) + + if endpoint_results.is_err(): + return endpoint_results + + endpoint_by_path = None + + for endpoint in endpoint_results.ok(): + if endpoint.path == path: + endpoint_by_path = endpoint + break + + return Ok(endpoint_by_path) + + def update( + self, + credentials: SyftVerifyKey, + endpoint: TwinAPIEndpoint, + has_permission: bool = False, + ) -> Result[TwinAPIEndpoint, str]: + res = self.check_type(endpoint, TwinAPIEndpoint) + if res.is_err(): + return res + old_endpoint = self.get_by_path(credentials=credentials, path=endpoint.path) + if old_endpoint and old_endpoint.ok(): + old_endpoint = old_endpoint.ok() + old_endpoint = old_endpoint[0] + + if old_endpoint == endpoint: + return Ok(endpoint) + else: + super().delete_by_uid(credentials=credentials, uid=old_endpoint.id) + + result = super().set( + credentials=credentials, obj=res.ok(), ignore_duplicates=True + ) + return result