diff --git a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb
new file mode 100644
index 00000000000..5ff803cdce6
--- /dev/null
+++ b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb
@@ -0,0 +1,429 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "f272a63f-03a9-417d-88c3-11a98ad25c80",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = b\"A\" * (10**9)  # 1 GB message"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d9a2e0c0-ef3b-41be-a4e8-0d9f190a1106",
+   "metadata": {},
+   "source": [
+    "# Using PyNaCl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4145072-a959-479b-8c80-da15f82946f3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# stdlib\n",
+    "import hashlib\n",
+    "import time\n",
+    "\n",
+    "# third party\n",
+    "from nacl.signing import SigningKey\n",
+    "\n",
+    "# Generate a new random signing key\n",
+    "signing_key = SigningKey.generate()\n",
+    "\n",
+    "# Example large message\n",
+    "large_message = data\n",
+    "\n",
+    "# Hash the message with SHA-256 using hashlib\n",
+    "start = time.time()\n",
+    "hash_object = hashlib.sha256()\n",
+    "hash_object.update(large_message)\n",
+    "hashed_message = hash_object.digest()\n",
+    "hash_time = time.time() - start\n",
+    "\n",
+    "# Sign the hashed message with PyNaCl\n",
+    "start = time.time()\n",
+    "signed_hash = signing_key.sign(hashed_message)\n",
+    "sign_time = time.time() - start\n",
+    "\n",
+    "# Directly sign the large message with PyNaCl\n",
+    "start = time.time()\n",
+    "signed_message = signing_key.sign(large_message)\n",
+    "direct_sign_time = time.time() - start\n",
+    "\n",
+    "print(f\"Time to hash with hashlib: {hash_time:.2f} seconds\")\n",
+    "print(f\"Time to sign hashed message with PyNaCl: {sign_time:.2f} seconds\")\n",
+    "print(f\"Total time (hash + sign): {hash_time + sign_time:.2f} seconds\")\n",
+    "print(\n",
+    "    f\"Time to directly sign large message with PyNaCl: {direct_sign_time:.2f} seconds\"\n",
+    ")"
+   ]
+  },
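+  {
+   "cell_type": "markdown",
+   "id": "b1e2f3a4-0000-4000-8000-000000000001",
+   "metadata": {},
+   "source": [
+    "A possible companion benchmark (not part of the original run; the cell ids here are placeholders): timing verification with the corresponding `VerifyKey`. Verifying the signed 32-byte digest should be near-instant, while verifying the directly signed message pays the full 1 GB cost again."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b1e2f3a4-0000-4000-8000-000000000002",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "verify_key = signing_key.verify_key\n",
+    "\n",
+    "# Verify the signature over the 32-byte digest\n",
+    "start = time.time()\n",
+    "verify_key.verify(signed_hash)\n",
+    "hash_verify_time = time.time() - start\n",
+    "\n",
+    "# Verify the signature over the full message\n",
+    "start = time.time()\n",
+    "verify_key.verify(signed_message)\n",
+    "direct_verify_time = time.time() - start\n",
+    "\n",
+    "print(f\"Time to verify hashed message with PyNaCl: {hash_verify_time:.2f} seconds\")\n",
+    "print(\n",
+    "    f\"Time to directly verify large message with PyNaCl: {direct_verify_time:.2f} seconds\"\n",
+    ")"
+   ]
+  },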
"private_key = dsa.generate_private_key(\n", + " key_size=1024,\n", + ")\n", + "signature = private_key.sign(data, hashes.SHA256())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83362239-6376-46ee-8e70-d9a23ff5421b", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "from cryptography.hazmat.primitives.asymmetric import ec\n", + "\n", + "private_key = ec.generate_private_key(ec.SECP384R1())\n", + "\n", + "signature = private_key.sign(data, ec.ECDSA(hashes.SHA256()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd5d1781-666e-4d19-aee9-c0ad4b8f0756", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "public_key = private_key.public_key()\n", + "public_key.verify(signature, data, ec.ECDSA(hashes.SHA256()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "206369da-d2c7-424c-b5c6-b1d9b5202786", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "%%time\n", + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "from cryptography.hazmat.primitives.asymmetric import padding\n", + "from cryptography.hazmat.primitives.asymmetric import rsa\n", + "\n", + "private_key = rsa.generate_private_key(\n", + " public_exponent=65537,\n", + " key_size=2048,\n", + ")\n", + "\n", + "message = data\n", + "signature = private_key.sign(\n", + " message,\n", + " padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),\n", + " hashes.SHA256(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b222c11a-e2a2-4610-a9d8-95ee3343d466", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "public_key = private_key.public_key()\n", + "message = data\n", + "public_key.verify(\n", + " signature,\n", + " message,\n", + " padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),\n", + " hashes.SHA256(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6fa46875-4405-47c6-855c-0b3f407aa26c", + "metadata": {}, + "source": [ + "# Hashing by PyNacl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea204831-482d-4d3a-988b-32920b7af285", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "import nacl.encoding\n", + "import nacl.hash\n", + "\n", + "methods = [\"sha256\", \"sha512\", \"blake2b\"]\n", + "\n", + "for hash_method in methods:\n", + " HASHER = getattr(nacl.hash, hash_method)\n", + "\n", + " start = time.time()\n", + " digest = HASHER(data, encoder=nacl.encoding.HexEncoder)\n", + " end = time.time()\n", + " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "markdown", + "id": "df81c37d-024e-4de8-a136-717f2e67e724", + "metadata": {}, + "source": [ + "# Hashing by cryptography library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a775385-6b57-46ab-9aed-51598a8c7592", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "\n", + "methods = [\"SHA256\", \"SHA512\", \"BLAKE2b\"]\n", + "\n", + "for hash_method in methods:\n", + " if hash_method == \"BLAKE2b\":\n", + " digest = hashes.Hash(getattr(hashes, hash_method)(64))\n", + " else:\n", + " digest = hashes.Hash(getattr(hashes, hash_method)())\n", + "\n", + " start = time.time()\n", + " digest.update(data)\n", + " digest.finalize()\n", + " end = time.time()\n", 
+ " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "markdown", + "id": "086ab235-d9a0-4184-8270-bffb088bf1c3", + "metadata": {}, + "source": [ + "# Hashing by python hashlib" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b08d7f82-ea8f-4b24-ac09-669526894293", + "metadata": {}, + "outputs": [], + "source": [ + "methods = [\"sha256\", \"sha512\", \"blake2b\"]\n", + "\n", + "for hash_method in methods:\n", + " if hash_method == \"blake2b\":\n", + " m = getattr(hashlib, hash_method)(digest_size=64)\n", + " else:\n", + " m = getattr(hashlib, hash_method)()\n", + "\n", + " start = time.time()\n", + " m.update(data)\n", + " m.digest()\n", + " end = time.time()\n", + " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9bf2843e-add6-4f65-a75b-5ef93093d347", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pycryptodome\n", + " Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl.metadata (3.4 kB)\n", + "Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl (2.4 MB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: pycryptodome\n", + "Successfully installed pycryptodome-3.20.0\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install pycryptodome" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4343bedd-308a-4caf-a4ff-56cdd3ca2433", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Public Key:\n", + "-----BEGIN PUBLIC KEY-----\n", + "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEz1vchLT61W1+TWg86POU/jsYS4IJ\n", + "IzeBv+mYc9Ehpn0MqCpri5l0+HbnIpLAdvO7KeYRGBRqFPJMjqt5rB30Aw==\n", + "-----END PUBLIC KEY-----\n", + "\n", + "Private Key:\n", + "-----BEGIN PRIVATE KEY-----\n", + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgSIn/SVjK1hLXs5XK\n", + "S7C+dB1YcSz9VqStzP1ytSL9y7ihRANCAATPW9yEtPrVbX5NaDzo85T+OxhLggkj\n", + "N4G/6Zhz0SGmfQyoKmuLmXT4duciksB287sp5hEYFGoU8kyOq3msHfQD\n", + "-----END PRIVATE KEY-----\n", + "\n", + "Signature:\n", + "108b92beb9b85840c39e217373c998fb6df71baabb6a39cae6088f4a1f920d66694b1a71df082d930f58d91e83b72eee6aaa77f865796a78671d5bb74d384866\n", + "CPU times: user 4.9 s, sys: 41.8 ms, total: 4.94 s\n", + "Wall time: 4.94 s\n" + ] + } + ], + "source": [ + "# third party\n", + "from Crypto.Hash import SHA256\n", + "\n", + "%%time\n", + "# third party\n", + "from Crypto.PublicKey import ECC\n", + "from Crypto.Signature import DSS\n", + "\n", + "# Generate a new ECC key pair\n", + "key = ECC.generate(curve=\"P-256\")\n", + "\n", + "# Export the public key in PEM format\n", + "public_key_pem = key.public_key().export_key(format=\"PEM\")\n", + "print(\"Public Key:\")\n", + "print(public_key_pem)\n", + "\n", + "# Export the private key in PEM format\n", + "private_key_pem = 
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9bf2843e-add6-4f65-a75b-5ef93093d347",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Collecting pycryptodome\n",
+      "  Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl.metadata (3.4 kB)\n",
+      "Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl (2.4 MB)\n",
+      "   2.4/2.4 MB 11.9 MB/s eta 0:00:00\n",
+      "Installing collected packages: pycryptodome\n",
+      "Successfully installed pycryptodome-3.20.0\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install pycryptodome"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "4343bedd-308a-4caf-a4ff-56cdd3ca2433",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Public Key:\n",
+      "-----BEGIN PUBLIC KEY-----\n",
+      "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEz1vchLT61W1+TWg86POU/jsYS4IJ\n",
+      "IzeBv+mYc9Ehpn0MqCpri5l0+HbnIpLAdvO7KeYRGBRqFPJMjqt5rB30Aw==\n",
+      "-----END PUBLIC KEY-----\n",
+      "\n",
+      "Private Key:\n",
+      "-----BEGIN PRIVATE KEY-----\n",
+      "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgSIn/SVjK1hLXs5XK\n",
+      "S7C+dB1YcSz9VqStzP1ytSL9y7ihRANCAATPW9yEtPrVbX5NaDzo85T+OxhLggkj\n",
+      "N4G/6Zhz0SGmfQyoKmuLmXT4duciksB287sp5hEYFGoU8kyOq3msHfQD\n",
+      "-----END PRIVATE KEY-----\n",
+      "\n",
+      "Signature:\n",
+      "108b92beb9b85840c39e217373c998fb6df71baabb6a39cae6088f4a1f920d66694b1a71df082d930f58d91e83b72eee6aaa77f865796a78671d5bb74d384866\n",
+      "CPU times: user 4.9 s, sys: 41.8 ms, total: 4.94 s\n",
+      "Wall time: 4.94 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "\n",
+    "# third party\n",
+    "from Crypto.Hash import SHA256\n",
+    "from Crypto.PublicKey import ECC\n",
+    "from Crypto.Signature import DSS\n",
+    "\n",
+    "# Generate a new ECC key pair\n",
+    "key = ECC.generate(curve=\"P-256\")\n",
+    "\n",
+    "# Export the public key in PEM format\n",
+    "public_key_pem = key.public_key().export_key(format=\"PEM\")\n",
+    "print(\"Public Key:\")\n",
+    "print(public_key_pem)\n",
+    "\n",
+    "# Export the private key in PEM format\n",
+    "private_key_pem = key.export_key(format=\"PEM\")\n",
+    "print(\"\\nPrivate Key:\")\n",
+    "print(private_key_pem)\n",
+    "\n",
+    "# Sign a message\n",
+    "message = data\n",
+    "hash_obj = SHA256.new(message)\n",
+    "signer = DSS.new(key, \"fips-186-3\")\n",
+    "signature = signer.sign(hash_obj)\n",
+    "print(\"\\nSignature:\")\n",
+    "print(signature.hex())\n",
+    "\n",
+    "# # Verify the signature\n",
+    "# public_key = ECC.import_key(public_key_pem)\n",
+    "# verifier = DSS.new(public_key, 'fips-186-3')\n",
+    "# try:\n",
+    "#     verifier.verify(hash_obj, signature)\n",
+    "#     print(\"\\nThe message is authentic.\")\n",
+    "# except ValueError:\n",
+    "#     print(\"\\nThe message is not authentic.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2034a8fd-c89e-461f-805b-5b37c4c7d395",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py
index 5d0dbe6652d..4ab321b108a 100644
--- a/packages/syft/src/syft/client/domain_client.py
+++ b/packages/syft/src/syft/client/domain_client.py
@@ -157,6 +157,10 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError:
                 model_size += get_mb_size(asset.data)
                 model_ref_action_ids.append(twin.id)
 
+                # Clear the data and mock, as they were uploaded as a twin object
+                asset.data = None
+                asset.mock = None
+
                 # Update the progress bar and set the dynamic description
                 pbar.set_description(f"Uploading: {asset.name}")
                 pbar.update(1)
diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py
index 37baaff90e8..632f84b12d8 100644
--- a/packages/syft/src/syft/node/routes.py
+++ b/packages/syft/src/syft/node/routes.py
@@ -2,7 +2,10 @@
 import base64
 import binascii
 from collections.abc import AsyncGenerator
+from collections.abc import Callable
+from datetime import datetime
 import logging
+from pathlib import Path
 from typing import Annotated
 
 # third party
@@ -34,12 +37,13 @@
 from ..util.telemetry import TRACE_MODE
 from .credentials import SyftVerifyKey
 from .credentials import UserLoginCredentials
+from .server_settings import ServerSettings
 from .worker import Worker
 
 logger = logging.getLogger(__name__)
 
 
-def make_routes(worker: Worker) -> APIRouter:
+def make_routes(worker: Worker, settings: ServerSettings | None = None) -> APIRouter:
     if TRACE_MODE:
         # third party
         try:
@@ -49,6 +53,34 @@ def make_routes(worker: Worker) -> APIRouter:
         except Exception as e:
             logger.error("Failed to import opentelemetry", exc_info=e)
 
+    def _handle_profile(
+        request: Request, handler_func: Callable, *args: list, **kwargs: dict
+    ) -> Response:
+        if not settings:
+            raise Exception("Server settings are required to enable profiling")
+        # third party
+        from pyinstrument import Profiler  # lazy load
+
+        profiles_dir = Path(settings.profile_dir or Path.cwd()) / "profiles"
+        profiles_dir.mkdir(parents=True, exist_ok=True)
+
+        with Profiler(
+            interval=settings.profile_interval, async_mode="enabled"
+        ) as profiler:
+            response = handler_func(*args, **kwargs)
+
+        timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
+        url_path = request.url.path.replace("/api/v2", "").replace("/", "-")
+        profile_output_path = (
+            profiles_dir / f"{settings.name}-{timestamp}{url_path}.html"
+        )
+        profiler.write_html(profile_output_path)
+
+        logger.info(
+            f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds"
+        )
+        return response
+
     router = APIRouter()
 
     async def get_body(request: Request) -> bytes:
@@ -165,6 +197,13 @@ def syft_new_api(
                 kind=trace.SpanKind.SERVER,
             ):
                 return handle_syft_new_api(user_verify_key, communication_protocol)
+        elif settings and settings.profile:
+            return _handle_profile(
+                request,
+                handle_syft_new_api,
+                user_verify_key,
+                communication_protocol,
+            )
         else:
             return handle_syft_new_api(user_verify_key, communication_protocol)
 
@@ -188,6 +227,8 @@ def syft_new_api_call(
                 kind=trace.SpanKind.SERVER,
             ):
                 return handle_new_api_call(data)
+        elif settings and settings.profile:
+            return _handle_profile(request, handle_new_api_call, data)
         else:
             return handle_new_api_call(data)
 
@@ -253,6 +294,8 @@ def login(
                 kind=trace.SpanKind.SERVER,
             ):
                 return handle_login(email, password, worker)
+        elif settings and settings.profile:
+            return _handle_profile(request, handle_login, email, password, worker)
         else:
             return handle_login(email, password, worker)
 
@@ -267,6 +310,8 @@ def register(
                 kind=trace.SpanKind.SERVER,
             ):
                 return handle_register(data, worker)
+        elif settings and settings.profile:
+            return _handle_profile(request, handle_register, data, worker)
         else:
             return handle_register(data, worker)
 
diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py
index a3451f304a7..207860c8b83 100644
--- a/packages/syft/src/syft/node/server.py
+++ b/packages/syft/src/syft/node/server.py
@@ -1,5 +1,6 @@
 # stdlib
 from collections.abc import Callable
+from datetime import datetime
 import multiprocessing
 import multiprocessing.synchronize
 import os
@@ -14,8 +15,8 @@
 # third party
 from fastapi import APIRouter
 from fastapi import FastAPI
-from pydantic_settings import BaseSettings
-from pydantic_settings import SettingsConfigDict
+from fastapi import Request
+from fastapi import Response
 import requests
 from starlette.middleware.cors import CORSMiddleware
 import uvicorn
@@ -31,6 +32,7 @@
 from .gateway import Gateway
 from .node import NodeType
 from .routes import make_routes
+from .server_settings import ServerSettings
 from .utils import get_named_node_uid
 from .utils import remove_temp_dir_for_node
@@ -42,26 +44,8 @@
 WAIT_TIME_SECONDS = 20
 
 
-class AppSettings(BaseSettings):
-    name: str
-    node_type: NodeType = NodeType.DOMAIN
-    node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
-    processes: int = 1
-    reset: bool = False
-    dev_mode: bool = False
-    enable_warnings: bool = False
-    in_memory_workers: bool = True
-    queue_port: int | None = None
-    create_producer: bool = False
-    n_consumers: int = 0
-    association_request_auto_approval: bool = False
-    background_tasks: bool = False
-
-    model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")
-
-
 def app_factory() -> FastAPI:
-    settings = AppSettings()
+    settings = ServerSettings()
 
     worker_classes = {
         NodeType.DOMAIN: Domain,
@@ -72,21 +56,49 @@ def app_factory() -> FastAPI:
         raise NotImplementedError(f"node_type: {settings.node_type} is not supported")
 
     worker_class = worker_classes[settings.node_type]
-    kwargs = settings.model_dump()
+    worker_kwargs = settings.model_dump()
+    # Remove profiling inputs, as they are not worker arguments
+    worker_kwargs.pop("profile")
+    worker_kwargs.pop("profile_interval")
+    worker_kwargs.pop("profile_dir")
     if settings.dev_mode:
         print(
             f"WARN: private key is based on node name: {settings.name} in dev_mode. "
             "Don't run this in production."
         )
-        worker = worker_class.named(**kwargs)
+        worker = worker_class.named(**worker_kwargs)
     else:
-        worker = worker_class(**kwargs)
+        worker = worker_class(**worker_kwargs)
 
     app = FastAPI(title=settings.name)
-    router = make_routes(worker=worker)
+    router = make_routes(worker=worker, settings=settings)
     api_router = APIRouter()
     api_router.include_router(router)
     app.include_router(api_router, prefix="/api/v2")
+
+    # Register middlewares
+    _register_middlewares(app, settings)
+
+    return app
+
+
+def _register_middlewares(app: FastAPI, settings: ServerSettings) -> None:
+    _register_cors_middleware(app)
+
+    # Sync routes are currently not supported by pyinstrument's middleware,
+    # and most of our routes in syft (routes.py) are sync
+    # (e.g. syft_new_api, syft_new_api_call, login, register), so we do not
+    # register the profiler middleware here. We should either convert these
+    # routes to async or wait until pyinstrument supports sync routes.
+    # We cannot simply convert our sync routes to async, because they perform
+    # blocking IO (e.g. via the requests library); if one route calls back
+    # into the server itself, it would block the event loop and hang the server.
+    # if settings.profile:
+    #     _register_profiler(app, settings)
+
+
+def _register_cors_middleware(app: FastAPI) -> None:
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
@@ -94,7 +106,55 @@ def app_factory() -> FastAPI:
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    return app
+
+
+def _register_profiler(app: FastAPI, settings: ServerSettings) -> None:
+    # third party
+    from pyinstrument import Profiler
+
+    profiles_dir = (
+        Path.cwd() / "profiles"
+        if settings.profile_dir is None
+        else Path(settings.profile_dir) / "profiles"
+    )
+
+    @app.middleware("http")
+    async def profile_request(
+        request: Request, call_next: Callable[[Request], Response]
+    ) -> Response:
+        with Profiler(
+            interval=settings.profile_interval, async_mode="enabled"
+        ) as profiler:
+            response = await call_next(request)
+
+        # Profile file name: domain name - timestamp - URL path
+        timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
+        profiles_dir.mkdir(parents=True, exist_ok=True)
+        url_path = request.url.path.replace("/api/v2", "").replace("/", "-")
+        profile_output_path = (
+            profiles_dir / f"{settings.name}-{timestamp}{url_path}.html"
+        )
+
+        # Write the profile to an HTML file
+        profiler.write_html(profile_output_path)
+
+        print(
+            f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds"
+        )
+
+        return response
+
+
+def _load_pyinstrument_jupyter_extension() -> None:
+    try:
+        # third party
+        from IPython import get_ipython
+
+        ipython = get_ipython()  # noqa: F821
+        ipython.run_line_magic("load_ext", "pyinstrument")
+        print("Pyinstrument Jupyter extension loaded")
+    except Exception as e:
+        print(f"Error loading pyinstrument jupyter extension: {e}")
 
 
 def attach_debugger() -> None:
@@ -142,7 +202,7 @@ def run_uvicorn(
         attach_debugger()
 
     # Set up all kwargs as environment variables so that they can be accessed in the app_factory function.
-    env_prefix = AppSettings.model_config.get("env_prefix", "")
+    env_prefix = ServerSettings.model_config.get("env_prefix", "")
     for key, value in kwargs.items():
         key_with_prefix = f"{env_prefix}{key.upper()}"
         os.environ[key_with_prefix] = str(value)
@@ -187,6 +247,10 @@ def serve_node(
     association_request_auto_approval: bool = False,
     background_tasks: bool = False,
     debug: bool = False,
+    # Profiling inputs
+    profile: bool = False,
+    profile_interval: float = 0.001,
+    profile_dir: str | None = None,
 ) -> tuple[Callable, Callable]:
     starting_uvicorn_event = multiprocessing.Event()
 
@@ -194,6 +258,12 @@ def serve_node(
     if dev_mode:
         enable_autoreload()
 
+    # Load the pyinstrument Jupyter extension if profiling is enabled.
+    if profile:
+        _load_pyinstrument_jupyter_extension()
+        if profile_dir is None:
+            profile_dir = str(Path.cwd())
+
     server_process = multiprocessing.Process(
         target=run_uvicorn,
         kwargs={
@@ -214,6 +284,9 @@ def serve_node(
             "background_tasks": background_tasks,
             "debug": debug,
             "starting_uvicorn_event": starting_uvicorn_event,
+            "profile": profile,
+            "profile_interval": profile_interval,
+            "profile_dir": profile_dir,
         },
     )
diff --git a/packages/syft/src/syft/node/server_settings.py b/packages/syft/src/syft/node/server_settings.py
new file mode 100644
index 00000000000..3c57606ec02
--- /dev/null
+++ b/packages/syft/src/syft/node/server_settings.py
@@ -0,0 +1,30 @@
+# third party
+from pydantic_settings import BaseSettings
+from pydantic_settings import SettingsConfigDict
+
+# relative
+from ..abstract_node import NodeSideType
+from ..abstract_node import NodeType
+
+
+class ServerSettings(BaseSettings):
+    name: str
+    node_type: NodeType = NodeType.DOMAIN
+    node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
+    processes: int = 1
+    reset: bool = False
+    dev_mode: bool = False
+    enable_warnings: bool = False
+    in_memory_workers: bool = True
+    queue_port: int | None = None
+    create_producer: bool = False
+    n_consumers: int = 0
+    association_request_auto_approval: bool = False
+    background_tasks: bool = False
+
+    # Profiling inputs
+    profile: bool = False
+    profile_interval: float = 0.001
+    profile_dir: str | None = None
+
+    model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")
diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py
index c07dce6a5d6..bff399ba99e 100644
--- a/packages/syft/src/syft/orchestra.py
+++ b/packages/syft/src/syft/orchestra.py
@@ -175,6 +175,9 @@ def deploy_to_python(
     background_tasks: bool = False,
     debug: bool = False,
     migrate: bool = False,
+    profile: bool = False,
+    profile_interval: float = 0.001,
+    profile_dir: str | None = None,
 ) -> NodeHandle:
     worker_classes = {
         NodeType.DOMAIN: Domain,
@@ -204,6 +207,9 @@ def deploy_to_python(
         "background_tasks": background_tasks,
         "debug": debug,
         "migrate": migrate,
+        "profile": profile,
+        "profile_interval": profile_interval,
+        "profile_dir": profile_dir,
     }
 
     if port:
@@ -305,6 +311,10 @@ def launch(
     background_tasks: bool = False,
     debug: bool = False,
     migrate: bool = False,
+    # Profiling-related inputs for the in-memory FastAPI server
+    profile: bool = False,
+    profile_interval: float = 0.001,
+    profile_dir: str | None = None,
 ) -> NodeHandle:
     if dev_mode is True:
         thread_workers = True
@@ -343,6 +353,9 @@ def launch(
         background_tasks=background_tasks,
         debug=debug,
         migrate=migrate,
+        profile=profile,
+        profile_interval=profile_interval,
+        profile_dir=profile_dir,
     )
     display(
         SyftInfo(
diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py
index 7a6e1cf732e..702f448dd12 100644
--- a/packages/syft/src/syft/service/blob_storage/service.py
+++ b/packages/syft/src/syft/service/blob_storage/service.py
@@ -166,7 +166,9 @@ def get_files_from_bucket(
 
         return blob_files
 
-    @service_method(path="blob_storage.get_by_uid", name="get_by_uid")
+    @service_method(
+        path="blob_storage.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL
+    )
     def get_blob_storage_entry_by_uid(
         self, context: AuthedServiceContext, uid: UID
     ) -> BlobStorageEntry | SyftError:
diff --git a/packages/syft/src/syft/service/enclave/domain_enclave_service.py b/packages/syft/src/syft/service/enclave/domain_enclave_service.py
index a2e483939a5..c18bef2d229 100644
--- a/packages/syft/src/syft/service/enclave/domain_enclave_service.py
+++ b/packages/syft/src/syft/service/enclave/domain_enclave_service.py
@@ -1,6 +1,7 @@
 # stdlib
 import itertools
 from typing import Any
+from typing import cast
 
 # relative
 from ...serde.serializable import serializable
@@ -13,6 +14,7 @@
 from ...service.user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
 from ...store.document_store import DocumentStore
 from ...types.uid import UID
+from ..action.action_object import ActionObject
 from ..code.user_code import UserCode
 from ..context import AuthedServiceContext
 from ..model.model import ModelRef
@@ -145,7 +147,7 @@ def request_assets_upload(
             if node_identity.node_id == context.node.id
         ]
         asset_action_ids = tuple(itertools.chain.from_iterable(asset_action_ids_nested))
-        action_objects = [
+        action_objects: list[ActionObject] = [
             context.node.get_service("actionservice")
             .get(context=root_context, uid=action_id)
             .ok()
@@ -184,6 +186,15 @@ def request_assets_upload(
             _ = action_object.syft_action_data
             action_object.syft_blob_storage_entry_id = None
             blob_res = action_object._save_to_blob_storage(client=enclave_client)
+
+            action_object.syft_blob_storage_entry_id = cast(
+                UID | None, action_object.syft_blob_storage_entry_id
+            )
+            # Smaller data is not stored in blob storage, so clear the cache
+            # only when the object was stored in blob storage, to avoid
+            # sending the data again
+            if action_object.syft_blob_storage_entry_id:
+                action_object._clear_cache()
             if isinstance(blob_res, SyftError):
                 return blob_res
diff --git a/packages/syft/src/syft/service/model/model.py b/packages/syft/src/syft/service/model/model.py
index e92e1ea10f8..5416c981a76 100644
--- a/packages/syft/src/syft/service/model/model.py
+++ b/packages/syft/src/syft/service/model/model.py
@@ -5,6 +5,7 @@
 from textwrap import dedent
 from typing import Any
 from typing import ClassVar
+from typing import cast
 
 # third party
 from IPython.display import display
@@ -634,16 +635,26 @@ def load_data(
 
         asset_list = []
         for asset_action_id in asset_action_ids:
-            res = admin_client.services.action.get(asset_action_id)
-            action_data = res.syft_action_data
+            action_object = admin_client.services.action.get(asset_action_id)
+            action_data = action_object.syft_action_data
 
             # Save to blob storage of remote client if provided
             if remote_client is not None:
-                res.syft_blob_storage_entry_id = None
-                blob_res = res._save_to_blob_storage(client=remote_client)
+                action_object.syft_blob_storage_entry_id = None
+                blob_res = action_object._save_to_blob_storage(client=remote_client)
+                # Smaller data is not stored in blob storage, so clear the
+                # cache only when the object was stored in blob storage, to
+                # avoid sending the data again
+                action_object.syft_blob_storage_entry_id = cast(
+                    UID | None, action_object.syft_blob_storage_entry_id
+                )
+                if action_object.syft_blob_storage_entry_id:
+                    action_object._clear_cache()
                 if isinstance(blob_res, SyftError):
                     return blob_res
 
-            asset_list.append(action_data if unwrap_action_data else res)
+            asset_list.append(action_data if unwrap_action_data else action_object)
 
         loaded_data = [model] + asset_list
         if wrap_ref_to_obj:
diff --git a/packages/syft/src/syft/store/blob_storage/on_disk.py b/packages/syft/src/syft/store/blob_storage/on_disk.py
index 4369b46db4f..e89a604f59d 100644
--- a/packages/syft/src/syft/store/blob_storage/on_disk.py
+++ b/packages/syft/src/syft/store/blob_storage/on_disk.py
@@ -32,14 +32,28 @@ def write(self, data: BytesIO) -> SyftSuccess | SyftError:
         # relative
         from ...service.service import from_api_or_context
 
-        write_to_disk_method = from_api_or_context(
-            func_or_path="blob_storage.write_to_disk",
+        get_by_uid_method = from_api_or_context(
+            func_or_path="blob_storage.get_by_uid",
             syft_node_location=self.syft_node_location,
             syft_client_verify_key=self.syft_client_verify_key,
         )
-        if write_to_disk_method is None:
-            return SyftError(message="write_to_disk_method is None")
-        return write_to_disk_method(data=data.read(), uid=self.blob_storage_entry_id)
+        if get_by_uid_method is None:
+            return SyftError(message="get_by_uid_method is None")
+
+        obj = get_by_uid_method(uid=self.blob_storage_entry_id)
+        if isinstance(obj, SyftError):
+            return obj
+        if obj is None:
+            return SyftError(
+                message=f"No blob storage entry exists for uid: {self.blob_storage_entry_id}, "
+                "or you have no permissions to read it"
+            )
+
+        try:
+            Path(obj.location.path).write_bytes(data.read())
+            return SyftSuccess(message="File successfully saved.")
+        except Exception as e:
+            return SyftError(message=f"Failed to write object to disk: {e}")
 
 
 class OnDiskBlobStorageConnection(BlobStorageConnection):