From bf70ade5fecdd13c783b91e29701fbc01702dcfa Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Tue, 16 Jul 2024 22:24:33 +0530 Subject: [PATCH 01/12] Optimization 1 --- .../src/syft/store/blob_storage/on_disk.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/store/blob_storage/on_disk.py b/packages/syft/src/syft/store/blob_storage/on_disk.py index 4369b46db4f..e89a604f59d 100644 --- a/packages/syft/src/syft/store/blob_storage/on_disk.py +++ b/packages/syft/src/syft/store/blob_storage/on_disk.py @@ -32,14 +32,28 @@ def write(self, data: BytesIO) -> SyftSuccess | SyftError: # relative from ...service.service import from_api_or_context - write_to_disk_method = from_api_or_context( - func_or_path="blob_storage.write_to_disk", + get_by_uid_method = from_api_or_context( + func_or_path="blob_storage.get_by_uid", syft_node_location=self.syft_node_location, syft_client_verify_key=self.syft_client_verify_key, ) - if write_to_disk_method is None: - return SyftError(message="write_to_disk_method is None") - return write_to_disk_method(data=data.read(), uid=self.blob_storage_entry_id) + if get_by_uid_method is None: + return SyftError(message="get_by_uid_method is None") + + obj = get_by_uid_method(uid=self.blob_storage_entry_id) + if isinstance(obj, SyftError): + return obj + if obj is None: + return SyftError( + message=f"No blob storage entry exists for uid: {self.blob_storage_entry_id}, " + "or you have no permissions to read it" + ) + + try: + Path(obj.location.path).write_bytes(data.read()) + return SyftSuccess(message="File successfully saved.") + except Exception as e: + return SyftError(message=f"Failed to write object to disk: {e}") class OnDiskBlobStorageConnection(BlobStorageConnection): From a636533e299a0ea9d14a792208df4f0ca74b5b5e Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Wed, 17 Jul 2024 00:02:36 +0530 Subject: [PATCH 02/12] Remove asset_list from model before saving it to the model stash --- packages/syft/src/syft/client/domain_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 5d0dbe6652d..7be3614125b 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -182,6 +182,7 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError: return valid # Step 4. Upload Model to Model Stash + model.asset_list = [] return self.api.services.model.add(model=model) def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: From c244eeeef9ed5a95b734a06bf56d83308c08e7a2 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Fri, 19 Jul 2024 12:08:16 +0530 Subject: [PATCH 03/12] upload digital signature benchmarks --- .../Digital Signatures/ds-benchmarks.ipynb | 326 ++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb diff --git a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb new file mode 100644 index 00000000000..83c12377581 --- /dev/null +++ b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb @@ -0,0 +1,326 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "721d9a24-aec7-4fbd-a9c2-a6146a4da291", + "metadata": {}, + "outputs": [], + "source": [ + "data = b\"A\" * (10**9) # 1GB message" + ] + }, + { + "cell_type": "markdown", + "id": "d9a2e0c0-ef3b-41be-a4e8-0d9f190a1106", + "metadata": {}, + "source": [ + "# Using PyNacl" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a4145072-a959-479b-8c80-da15f82946f3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time to hash with hashlib: 0.45 seconds\n", + "Time to sign hashed message with PyNaCl: 0.00 seconds\n", + "Total time (hash + sign): 0.45 seconds\n", + "Time to directly sign large message with PyNaCl: 22.56 seconds\n" + ] + } + ], + "source": [ + "# stdlib\n", + "import hashlib\n", + "import time\n", + "\n", + "# third party\n", + "from nacl.signing import SigningKey\n", + "\n", + "# Generate a new random signing key\n", + "signing_key = SigningKey.generate()\n", + "\n", + "# Example large message\n", + "large_message = data\n", + "\n", + "# Hash the message with SHA-256 using hashlib\n", + "start = time.time()\n", + "hash_object = hashlib.sha256()\n", + "hash_object.update(large_message)\n", + "hashed_message = hash_object.digest()\n", + "hash_time = time.time() - start\n", + "\n", + "# Sign the hashed message with PyNaCl\n", + "start = time.time()\n", + "signed_hash = signing_key.sign(hashed_message)\n", + "sign_time = time.time() - start\n", + "\n", + "# Directly sign the large message with PyNaCl\n", + "start = time.time()\n", + "signed_message = signing_key.sign(large_message)\n", + "direct_sign_time = time.time() - start\n", + "\n", + "print(f\"Time to hash with hashlib: {hash_time:.2f} seconds\")\n", + "print(f\"Time to sign hashed message with PyNaCl: {sign_time:.2f} seconds\")\n", + "print(f\"Total time (hash + sign): {hash_time + sign_time:.2f} seconds\")\n", + "print(\n", + " f\"Time to directly sign large message with PyNaCl: {direct_sign_time:.2f} seconds\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d8581767-bee2-42e1-a571-148cf0fb12a4", + "metadata": {}, + "source": [ + "# Using Cryptography library" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "8ea32e21-8987-4459-aa0f-6bc832376ab7", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install cryptography" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1618c35b-cb6e-4f28-a13c-a2e23497841c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.38 s, sys: 11 ms, total: 1.39 s\n", + "Wall time: 1.38 s\n" + ] + } + ], + "source": [ + "# third party\n", + "%%time\n", + "# third party\n", + "from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey\n", + "\n", + "private_key = Ed25519PrivateKey.generate()\n", + "signature = private_key.sign(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9abb35b3-1891-4074-8f0e-729de0c2e4a2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.22 s, sys: 23.2 ms, total: 2.24 s\n", + "Wall time: 2.25 s\n" + ] + } + ], + "source": [ + "# third party\n", + "%%time\n", + "# third party\n", + "from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey\n", + "\n", + "private_key = Ed448PrivateKey.generate()\n", + "signature = private_key.sign(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "66341fe5-94c3-4c8e-af34-a2e837a6957f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 485 ms, sys: 4.75 ms, total: 490 ms\n", + "Wall time: 489 ms\n" + ] + } + ], + "source": [ + "# third party\n", + "%%time\n", + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "from cryptography.hazmat.primitives.asymmetric import dsa\n", + "\n", + "private_key = dsa.generate_private_key(\n", + " key_size=1024,\n", + ")\n", + "signature = private_key.sign(data, hashes.SHA256())" + ] + }, + { + "cell_type": "markdown", + "id": "6fa46875-4405-47c6-855c-0b3f407aa26c", + "metadata": {}, + "source": [ + "# Hashing by PyNacl" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ea204831-482d-4d3a-988b-32920b7af285", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time taken for sha256 16.29426908493042\n", + "Time taken for sha512 11.238587856292725\n", + "Time taken for blake2b 7.366748094558716\n" + ] + } + ], + "source": [ + "# third party\n", + "import nacl.encoding\n", + "import nacl.hash\n", + "\n", + "methods = [\"sha256\", \"sha512\", \"blake2b\"]\n", + "\n", + "for hash_method in methods:\n", + " HASHER = getattr(nacl.hash, hash_method)\n", + "\n", + " start = time.time()\n", + " digest = HASHER(data, encoder=nacl.encoding.HexEncoder)\n", + " end = time.time()\n", + " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "markdown", + "id": "df81c37d-024e-4de8-a136-717f2e67e724", + "metadata": {}, + "source": [ + "# Hashing by cryptography library" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2a775385-6b57-46ab-9aed-51598a8c7592", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time taken for SHA256 0.43844008445739746\n", + "Time taken for SHA512 0.6953341960906982\n", + "Time taken for BLAKE2b 7.246281862258911\n" + ] + } + ], + "source": [ + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "\n", + "methods = [\"SHA256\", \"SHA512\", \"BLAKE2b\"]\n", + "\n", + "for hash_method in methods:\n", + " if hash_method == \"BLAKE2b\":\n", + " digest = hashes.Hash(getattr(hashes, hash_method)(64))\n", + " else:\n", + " digest = hashes.Hash(getattr(hashes, hash_method)())\n", + "\n", + " start = time.time()\n", + " digest.update(data)\n", + " digest.finalize()\n", + " end = time.time()\n", + " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "markdown", + "id": "086ab235-d9a0-4184-8270-bffb088bf1c3", + "metadata": {}, + "source": [ + "# Hashing by python hashlib" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b08d7f82-ea8f-4b24-ac09-669526894293", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time taken for sha256 0.4372677803039551\n", + "Time taken for sha512 0.6927249431610107\n", + "Time taken for blake2b 1.4543838500976562\n" + ] + } + ], + "source": [ + "methods = [\"sha256\", \"sha512\", \"blake2b\"]\n", + "\n", + "for hash_method in methods:\n", + " if hash_method == \"blake2b\":\n", + " m = getattr(hashlib, hash_method)(digest_size=64)\n", + " else:\n", + " m = getattr(hashlib, hash_method)()\n", + "\n", + " start = time.time()\n", + " m.update(data)\n", + " m.digest()\n", + " end = time.time()\n", + " print(f\"Time taken for {hash_method}\", end - start)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c6e1c84-e782-4965-b6fd-a53bbbc445ac", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 7fe10d24eb1e4836276a9c88a568ca6bf7be9367 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:16:18 +0530 Subject: [PATCH 04/12] upload digital signature benchmarks for ecc, rsa --- .../Digital Signatures/ds-benchmarks.ipynb | 185 ++++++++++-------- 1 file changed, 100 insertions(+), 85 deletions(-) diff --git a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb index 83c12377581..bf6fe811cb1 100644 --- a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb +++ b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb @@ -2,8 +2,8 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, - "id": "721d9a24-aec7-4fbd-a9c2-a6146a4da291", + "execution_count": null, + "id": "f272a63f-03a9-417d-88c3-11a98ad25c80", "metadata": {}, "outputs": [], "source": [ @@ -20,21 +20,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "a4145072-a959-479b-8c80-da15f82946f3", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time to hash with hashlib: 0.45 seconds\n", - "Time to sign hashed message with PyNaCl: 0.00 seconds\n", - "Total time (hash + sign): 0.45 seconds\n", - "Time to directly sign large message with PyNaCl: 22.56 seconds\n" - ] - } - ], + "outputs": [], "source": [ "# stdlib\n", "import hashlib\n", @@ -84,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "8ea32e21-8987-4459-aa0f-6bc832376ab7", "metadata": {}, "outputs": [], @@ -94,19 +83,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "1618c35b-cb6e-4f28-a13c-a2e23497841c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 1.38 s, sys: 11 ms, total: 1.39 s\n", - "Wall time: 1.38 s\n" - ] - } - ], + "outputs": [], "source": [ "# third party\n", "%%time\n", @@ -119,19 +99,10 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "9abb35b3-1891-4074-8f0e-729de0c2e4a2", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2.22 s, sys: 23.2 ms, total: 2.24 s\n", - "Wall time: 2.25 s\n" - ] - } - ], + "outputs": [], "source": [ "# third party\n", "%%time\n", @@ -144,19 +115,10 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "66341fe5-94c3-4c8e-af34-a2e837a6957f", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 485 ms, sys: 4.75 ms, total: 490 ms\n", - "Wall time: 489 ms\n" - ] - } - ], + "outputs": [], "source": [ "# third party\n", "%%time\n", @@ -170,6 +132,81 @@ "signature = private_key.sign(data, hashes.SHA256())" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "83362239-6376-46ee-8e70-d9a23ff5421b", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "from cryptography.hazmat.primitives.asymmetric import ec\n", + "\n", + "private_key = ec.generate_private_key(ec.SECP384R1())\n", + "\n", + "signature = private_key.sign(data, ec.ECDSA(hashes.SHA256()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd5d1781-666e-4d19-aee9-c0ad4b8f0756", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "public_key = private_key.public_key()\n", + "public_key.verify(signature, data, ec.ECDSA(hashes.SHA256()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "206369da-d2c7-424c-b5c6-b1d9b5202786", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "%%time\n", + "# third party\n", + "from cryptography.hazmat.primitives import hashes\n", + "from cryptography.hazmat.primitives.asymmetric import padding\n", + "from cryptography.hazmat.primitives.asymmetric import rsa\n", + "\n", + "private_key = rsa.generate_private_key(\n", + " public_exponent=65537,\n", + " key_size=2048,\n", + ")\n", + "\n", + "message = data\n", + "signature = private_key.sign(\n", + " message,\n", + " padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),\n", + " hashes.SHA256(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b222c11a-e2a2-4610-a9d8-95ee3343d466", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "public_key = private_key.public_key()\n", + "message = data\n", + "public_key.verify(\n", + " signature,\n", + " message,\n", + " padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),\n", + " hashes.SHA256(),\n", + ")" + ] + }, { "cell_type": "markdown", "id": "6fa46875-4405-47c6-855c-0b3f407aa26c", @@ -180,20 +217,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "ea204831-482d-4d3a-988b-32920b7af285", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time taken for sha256 16.29426908493042\n", - "Time taken for sha512 11.238587856292725\n", - "Time taken for blake2b 7.366748094558716\n" - ] - } - ], + "outputs": [], "source": [ "# third party\n", "import nacl.encoding\n", @@ -220,20 +247,10 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "2a775385-6b57-46ab-9aed-51598a8c7592", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time taken for SHA256 0.43844008445739746\n", - "Time taken for SHA512 0.6953341960906982\n", - "Time taken for BLAKE2b 7.246281862258911\n" - ] - } - ], + "outputs": [], "source": [ "# third party\n", "from cryptography.hazmat.primitives import hashes\n", @@ -263,20 +280,10 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "b08d7f82-ea8f-4b24-ac09-669526894293", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time taken for sha256 0.4372677803039551\n", - "Time taken for sha512 0.6927249431610107\n", - "Time taken for blake2b 1.4543838500976562\n" - ] - } - ], + "outputs": [], "source": [ "methods = [\"sha256\", \"sha512\", \"blake2b\"]\n", "\n", @@ -300,6 +307,14 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bf2843e-add6-4f65-a75b-5ef93093d347", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From e41d0eeffd11997b081e581b2b6b934a164dbfc3 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:19:58 +0530 Subject: [PATCH 05/12] Add benchmarks also for pycryptodome --- .../Digital Signatures/ds-benchmarks.ipynb | 100 ++++++++++++++++-- 1 file changed, 94 insertions(+), 6 deletions(-) diff --git a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb index bf6fe811cb1..5ff803cdce6 100644 --- a/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb +++ b/notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "f272a63f-03a9-417d-88c3-11a98ad25c80", "metadata": {}, "outputs": [], @@ -302,16 +302,104 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "0c6e1c84-e782-4965-b6fd-a53bbbc445ac", + "execution_count": 1, + "id": "9bf2843e-add6-4f65-a75b-5ef93093d347", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pycryptodome\n", + " Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl.metadata (3.4 kB)\n", + "Downloading pycryptodome-3.20.0-cp35-abi3-macosx_10_9_universal2.whl (2.4 MB)\n", + "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: pycryptodome\n", + "Successfully installed pycryptodome-3.20.0\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install pycryptodome" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4343bedd-308a-4caf-a4ff-56cdd3ca2433", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Public Key:\n", + "-----BEGIN PUBLIC KEY-----\n", + "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEz1vchLT61W1+TWg86POU/jsYS4IJ\n", + "IzeBv+mYc9Ehpn0MqCpri5l0+HbnIpLAdvO7KeYRGBRqFPJMjqt5rB30Aw==\n", + "-----END PUBLIC KEY-----\n", + "\n", + "Private Key:\n", + "-----BEGIN PRIVATE KEY-----\n", + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgSIn/SVjK1hLXs5XK\n", + "S7C+dB1YcSz9VqStzP1ytSL9y7ihRANCAATPW9yEtPrVbX5NaDzo85T+OxhLggkj\n", + "N4G/6Zhz0SGmfQyoKmuLmXT4duciksB287sp5hEYFGoU8kyOq3msHfQD\n", + "-----END PRIVATE KEY-----\n", + "\n", + "Signature:\n", + "108b92beb9b85840c39e217373c998fb6df71baabb6a39cae6088f4a1f920d66694b1a71df082d930f58d91e83b72eee6aaa77f865796a78671d5bb74d384866\n", + "CPU times: user 4.9 s, sys: 41.8 ms, total: 4.94 s\n", + "Wall time: 4.94 s\n" + ] + } + ], + "source": [ + "# third party\n", + "from Crypto.Hash import SHA256\n", + "\n", + "%%time\n", + "# third party\n", + "from Crypto.PublicKey import ECC\n", + "from Crypto.Signature import DSS\n", + "\n", + "# Generate a new ECC key pair\n", + "key = ECC.generate(curve=\"P-256\")\n", + "\n", + "# Export the public key in PEM format\n", + "public_key_pem = key.public_key().export_key(format=\"PEM\")\n", + "print(\"Public Key:\")\n", + "print(public_key_pem)\n", + "\n", + "# Export the private key in PEM format\n", + "private_key_pem = key.export_key(format=\"PEM\")\n", + "print(\"\\nPrivate Key:\")\n", + "print(private_key_pem)\n", + "\n", + "# Sign a message\n", + "message = data\n", + "hash_obj = SHA256.new(message)\n", + "signer = DSS.new(key, \"fips-186-3\")\n", + "signature = signer.sign(hash_obj)\n", + "print(\"\\nSignature:\")\n", + "print(signature.hex())\n", + "\n", + "# # Verify the signature\n", + "# public_key = ECC.import_key(public_key_pem)\n", + "# verifier = DSS.new(public_key, 'fips-186-3')\n", + "# try:\n", + "# verifier.verify(hash_obj, signature)\n", + "# print(\"\\nThe message is authentic.\")\n", + "# except ValueError:\n", + "# print(\"\\nThe message is not authentic.\")" + ] }, { "cell_type": "code", "execution_count": null, - "id": "9bf2843e-add6-4f65-a75b-5ef93093d347", + "id": "2034a8fd-c89e-461f-805b-5b37c4c7d395", "metadata": {}, "outputs": [], "source": [] From 2358ede3116ac276d17b1ff62bcacd72b20160bc Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Fri, 19 Jul 2024 16:22:46 +0530 Subject: [PATCH 06/12] clear up and data and mock references in asset before uploading to server , as they are referenced by id's --- packages/syft/src/syft/client/domain_client.py | 5 ++++- packages/syft/src/syft/service/blob_storage/service.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 7be3614125b..4ab321b108a 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -157,6 +157,10 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError: model_size += get_mb_size(asset.data) model_ref_action_ids.append(twin.id) + # Clear the Data and Mock , as they are uploaded as twin object + asset.data = None + asset.mock = None + # Update the progress bar and set the dynamic description pbar.set_description(f"Uploading: {asset.name}") pbar.update(1) @@ -182,7 +186,6 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError: return valid # Step 4. Upload Model to Model Stash - model.asset_list = [] return self.api.services.model.add(model=model) def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 7a6e1cf732e..702f448dd12 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -166,7 +166,9 @@ def get_files_from_bucket( return blob_files - @service_method(path="blob_storage.get_by_uid", name="get_by_uid") + @service_method( + path="blob_storage.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL + ) def get_blob_storage_entry_by_uid( self, context: AuthedServiceContext, uid: UID ) -> BlobStorageEntry | SyftError: From 05474aca5ee52cb666ffcdd89a0643e763bd9a2b Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:21:04 +0530 Subject: [PATCH 07/12] clear up action data cache when uploading to blob store --- .../service/enclave/domain_enclave_service.py | 13 +++++++++++- packages/syft/src/syft/service/model/model.py | 21 ++++++++++++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/enclave/domain_enclave_service.py b/packages/syft/src/syft/service/enclave/domain_enclave_service.py index a2e483939a5..c18bef2d229 100644 --- a/packages/syft/src/syft/service/enclave/domain_enclave_service.py +++ b/packages/syft/src/syft/service/enclave/domain_enclave_service.py @@ -1,6 +1,7 @@ # stdlib import itertools from typing import Any +from typing import cast # relative from ...serde.serializable import serializable @@ -13,6 +14,7 @@ from ...service.user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from ...store.document_store import DocumentStore from ...types.uid import UID +from ..action.action_object import ActionObject from ..code.user_code import UserCode from ..context import AuthedServiceContext from ..model.model import ModelRef @@ -145,7 +147,7 @@ def request_assets_upload( if node_identity.node_id == context.node.id ] asset_action_ids = tuple(itertools.chain.from_iterable(asset_action_ids_nested)) - action_objects = [ + action_objects: list[ActionObject] = [ context.node.get_service("actionservice") .get(context=root_context, uid=action_id) .ok() @@ -184,6 +186,15 @@ def request_assets_upload( _ = action_object.syft_action_data action_object.syft_blob_storage_entry_id = None blob_res = action_object._save_to_blob_storage(client=enclave_client) + + action_object.syft_blob_storage_entry_id = cast( + UID | None, action_object.syft_blob_storage_entry_id + ) + # For smaller data, we do not store in blob storage + # so for the cases, where we store in blob storage + # we need to clear the cache , to avoid sending the data again + if action_object.syft_blob_storage_entry_id: + action_object._clear_cache() if isinstance(blob_res, SyftError): return blob_res diff --git a/packages/syft/src/syft/service/model/model.py b/packages/syft/src/syft/service/model/model.py index e92e1ea10f8..5416c981a76 100644 --- a/packages/syft/src/syft/service/model/model.py +++ b/packages/syft/src/syft/service/model/model.py @@ -5,6 +5,7 @@ from textwrap import dedent from typing import Any from typing import ClassVar +from typing import cast # third party from IPython.display import display @@ -634,16 +635,26 @@ def load_data( asset_list = [] for asset_action_id in asset_action_ids: - res = admin_client.services.action.get(asset_action_id) - action_data = res.syft_action_data + action_object = admin_client.services.action.get(asset_action_id) + action_data = action_object.syft_action_data # Save to blob storage of remote client if provided if remote_client is not None: - res.syft_blob_storage_entry_id = None - blob_res = res._save_to_blob_storage(client=remote_client) + action_object.syft_blob_storage_entry_id = None + blob_res = action_object._save_to_blob_storage(client=remote_client) + # For smaller data, we do not store in blob storage + # so for the cases, where we store in blob storage + # we need to clear the cache , to avoid sending the data again + # stdlib + + action_object.syft_blob_storage_entry_id = cast( + UID | None, action_object.syft_blob_storage_entry_id + ) + if action_object.syft_blob_storage_entry_id: + action_object._clear_cache() if isinstance(blob_res, SyftError): return blob_res - asset_list.append(action_data if unwrap_action_data else res) + asset_list.append(action_data if unwrap_action_data else action_object) loaded_data = [model] + asset_list if wrap_ref_to_obj: From 483e845062d7d83eed4d3c73e6bb181492a678c9 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 22 Jul 2024 09:18:14 +0530 Subject: [PATCH 08/12] add pyinstrument profiling --- packages/syft/src/syft/node/routes.py | 14 ++-- packages/syft/src/syft/node/server.py | 93 +++++++++++++++++++++++++-- packages/syft/src/syft/orchestra.py | 13 ++++ 3 files changed, 109 insertions(+), 11 deletions(-) diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 37baaff90e8..814c52238e2 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -120,7 +120,7 @@ async def stream_upload(peer_uid: str, url_path: str, request: Request) -> Respo status_code=200, response_class=JSONResponse, ) - def root() -> dict[str, str]: + async def root() -> dict[str, str]: """ Currently, all service backends must satisfy either of the following requirements to pass the HTTP health checks sent to it from the GCE loadbalancer: 1. Respond with a @@ -131,11 +131,11 @@ def root() -> dict[str, str]: # provide information about the node in JSON @router.get("/metadata", response_class=JSONResponse) - def syft_metadata() -> JSONResponse: + async def syft_metadata() -> JSONResponse: return worker.metadata.to(NodeMetadataJSON) @router.get("/metadata_capnp") - def syft_metadata_capnp() -> Response: + async def syft_metadata_capnp() -> Response: result = worker.metadata return Response( serialize(result, to_bytes=True), @@ -154,7 +154,7 @@ def handle_syft_new_api( # get the SyftAPI object @router.get("/api") - def syft_new_api( + async def syft_new_api( request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE ) -> Response: user_verify_key: SyftVerifyKey = SyftVerifyKey.from_string(verify_key) @@ -178,7 +178,7 @@ def handle_new_api_call(data: bytes) -> Response: # make a request to the SyftAPI @router.post("/api_call") - def syft_new_api_call( + async def syft_new_api_call( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: if TRACE_MODE: @@ -241,7 +241,7 @@ def handle_register(data: bytes, node: AbstractNode) -> Response: # exchange email and password for a SyftSigningKey @router.post("/login", name="login", status_code=200) - def login( + async def login( request: Request, email: Annotated[str, Body(example="info@openmined.org")], password: Annotated[str, Body(example="changethis")], @@ -257,7 +257,7 @@ def login( return handle_login(email, password, worker) @router.post("/register", name="register", status_code=200) - def register( + async def register( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: if TRACE_MODE: diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index a3451f304a7..6b4177c146f 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from datetime import datetime import multiprocessing import multiprocessing.synchronize import os @@ -14,6 +15,8 @@ # third party from fastapi import APIRouter from fastapi import FastAPI +from fastapi import Request +from fastapi import Response from pydantic_settings import BaseSettings from pydantic_settings import SettingsConfigDict import requests @@ -40,6 +43,8 @@ multiprocessing.set_start_method("spawn", True) WAIT_TIME_SECONDS = 20 +# PYINSTRUMENT_ENABLED = os.getenv("PYINSTRUMENT_ENABLED", "False").lower() == "true" +PYINSTRUMENT_ENABLED = True class AppSettings(BaseSettings): @@ -57,6 +62,11 @@ class AppSettings(BaseSettings): association_request_auto_approval: bool = False background_tasks: bool = False + # Profiling inputs + profile: bool = False + profile_interval: float = 0.0001 + profile_dir: str | None = None + model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -72,21 +82,39 @@ def app_factory() -> FastAPI: raise NotImplementedError(f"node_type: {settings.node_type} is not supported") worker_class = worker_classes[settings.node_type] - kwargs = settings.model_dump() + worker_kwargs = settings.model_dump() + # Remove Profiling inputs + worker_kwargs.pop("profile") + worker_kwargs.pop("profile_interval") + worker_kwargs.pop("profile_dir") if settings.dev_mode: print( f"WARN: private key is based on node name: {settings.name} in dev_mode. " "Don't run this in production." ) - worker = worker_class.named(**kwargs) + worker = worker_class.named(**worker_kwargs) else: - worker = worker_class(**kwargs) + worker = worker_class(**worker_kwargs) app = FastAPI(title=settings.name) router = make_routes(worker=worker) api_router = APIRouter() api_router.include_router(router) app.include_router(api_router, prefix="/api/v2") + + # Register middlewares + _register_middlewares(app, settings) + + return app + + +def _register_middlewares(app: FastAPI, settings: AppSettings) -> None: + _register_cors_middleware(app) + if settings.profile: + _register_profiler(app, settings) + + +def _register_cors_middleware(app: FastAPI) -> None: app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -94,7 +122,53 @@ def app_factory() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) - return app + + +def _register_profiler(app: FastAPI, settings: AppSettings) -> None: + # third party + from pyinstrument import Profiler + + profiles_dir = ( + Path.cwd() / "profiles" + if settings.profile_dir is None + else Path(settings.profile_dir) / "profiles" + ) + + @app.middleware("http") + async def profile_request( + request: Request, call_next: Callable[[Request], Response] + ) -> Response: + with Profiler(interval=0.001, async_mode="enabled") as profiler: + response = await call_next(request) + + timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S") + profiler_output_html = profiler.output_html() + profiles_dir.mkdir(parents=True, exist_ok=True) + url_path = request.url.path.replace("/api/v2", "").replace("/", "-") + profile_output_path = ( + profiles_dir / f"{settings.name}-{timestamp}{url_path}.html" + ) + + with open(profile_output_path, "w") as f: + f.write(profiler_output_html) + + print( + f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds" + ) + + return response + + +def _load_pyinstrument_jupyter_extension() -> None: + try: + # third party + from IPython import get_ipython + + ipython = get_ipython() # noqa: F821 + ipython.run_line_magic("load_ext", "pyinstrument") + print("Pyinstrument Jupyter extension loaded") + except Exception as e: + print(f"Error loading pyinstrument jupyter extension: {e}") def attach_debugger() -> None: @@ -187,6 +261,10 @@ def serve_node( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, + # Profiling inputs + profile: bool = False, + profile_interval: float = 0.0001, + profile_dir: str | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() @@ -194,6 +272,10 @@ def serve_node( if dev_mode: enable_autoreload() + # Load the Pyinstrument Jupyter extension if profile is enabled. + if profile: + _load_pyinstrument_jupyter_extension() + server_process = multiprocessing.Process( target=run_uvicorn, kwargs={ @@ -214,6 +296,9 @@ def serve_node( "background_tasks": background_tasks, "debug": debug, "starting_uvicorn_event": starting_uvicorn_event, + "profile": profile, + "profile_interval": profile_interval, + "profile_dir": profile_dir, }, ) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index c07dce6a5d6..5bce4ba7a72 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -175,6 +175,9 @@ def deploy_to_python( background_tasks: bool = False, debug: bool = False, migrate: bool = False, + profile: bool = False, + profile_interval: float = 0.0001, + profile_dir: str | None = None, ) -> NodeHandle: worker_classes = { NodeType.DOMAIN: Domain, @@ -204,6 +207,9 @@ def deploy_to_python( "background_tasks": background_tasks, "debug": debug, "migrate": migrate, + "profile": profile, + "profile_interval": profile_interval, + "profile_dir": profile_dir, } if port: @@ -305,6 +311,10 @@ def launch( background_tasks: bool = False, debug: bool = False, migrate: bool = False, + # Profiling Related Input for in-memory fastapi server + profile: bool = False, + profile_interval: float = 0.0001, + profile_dir: str | None = None, ) -> NodeHandle: if dev_mode is True: thread_workers = True @@ -343,6 +353,9 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, + profile=profile, + profile_interval=profile_interval, + profile_dir=profile_dir, ) display( SyftInfo( From 3f012578f41d59e320dae0a4db50f6df15003b64 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 22 Jul 2024 12:00:36 +0530 Subject: [PATCH 09/12] cleanup profile settings --- packages/syft/src/syft/node/server.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 6b4177c146f..7e6eff632e0 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -43,8 +43,6 @@ multiprocessing.set_start_method("spawn", True) WAIT_TIME_SECONDS = 20 -# PYINSTRUMENT_ENABLED = os.getenv("PYINSTRUMENT_ENABLED", "False").lower() == "true" -PYINSTRUMENT_ENABLED = True class AppSettings(BaseSettings): @@ -138,23 +136,25 @@ def _register_profiler(app: FastAPI, settings: AppSettings) -> None: async def profile_request( request: Request, call_next: Callable[[Request], Response] ) -> Response: - with Profiler(interval=0.001, async_mode="enabled") as profiler: + with Profiler( + interval=settings.profile_interval, async_mode="enabled" + ) as profiler: response = await call_next(request) + # Profile File Name - Domain Name - Timestamp - URL Path timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S") - profiler_output_html = profiler.output_html() profiles_dir.mkdir(parents=True, exist_ok=True) url_path = request.url.path.replace("/api/v2", "").replace("/", "-") profile_output_path = ( profiles_dir / f"{settings.name}-{timestamp}{url_path}.html" ) - with open(profile_output_path, "w") as f: - f.write(profiler_output_html) + # Write the profile to a HTML file + profiler.write_html(profile_output_path) - print( - f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds" - ) + print( + f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds" + ) return response From c3c50fd57f1f56345f017e85b30a39e7d7120e63 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 22 Jul 2024 12:36:44 +0530 Subject: [PATCH 10/12] change default profile interval to be 0.001 --- packages/syft/src/syft/node/server.py | 4 ++-- packages/syft/src/syft/orchestra.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 7e6eff632e0..19e1beb3f33 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -62,7 +62,7 @@ class AppSettings(BaseSettings): # Profiling inputs profile: bool = False - profile_interval: float = 0.0001 + profile_interval: float = 0.001 profile_dir: str | None = None model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -263,7 +263,7 @@ def serve_node( debug: bool = False, # Profiling inputs profile: bool = False, - profile_interval: float = 0.0001, + profile_interval: float = 0.001, profile_dir: str | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 5bce4ba7a72..bff399ba99e 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -176,7 +176,7 @@ def deploy_to_python( debug: bool = False, migrate: bool = False, profile: bool = False, - profile_interval: float = 0.0001, + profile_interval: float = 0.001, profile_dir: str | None = None, ) -> NodeHandle: worker_classes = { @@ -313,7 +313,7 @@ def launch( migrate: bool = False, # Profiling Related Input for in-memory fastapi server profile: bool = False, - profile_interval: float = 0.0001, + profile_interval: float = 0.001, profile_dir: str | None = None, ) -> NodeHandle: if dev_mode is True: From feb8a8658d080930894daff8787252b3f8f051df Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:30:11 +0530 Subject: [PATCH 11/12] revert to sync routes --- packages/syft/src/syft/node/routes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 814c52238e2..37baaff90e8 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -120,7 +120,7 @@ async def stream_upload(peer_uid: str, url_path: str, request: Request) -> Respo status_code=200, response_class=JSONResponse, ) - async def root() -> dict[str, str]: + def root() -> dict[str, str]: """ Currently, all service backends must satisfy either of the following requirements to pass the HTTP health checks sent to it from the GCE loadbalancer: 1. Respond with a @@ -131,11 +131,11 @@ async def root() -> dict[str, str]: # provide information about the node in JSON @router.get("/metadata", response_class=JSONResponse) - async def syft_metadata() -> JSONResponse: + def syft_metadata() -> JSONResponse: return worker.metadata.to(NodeMetadataJSON) @router.get("/metadata_capnp") - async def syft_metadata_capnp() -> Response: + def syft_metadata_capnp() -> Response: result = worker.metadata return Response( serialize(result, to_bytes=True), @@ -154,7 +154,7 @@ def handle_syft_new_api( # get the SyftAPI object @router.get("/api") - async def syft_new_api( + def syft_new_api( request: Request, verify_key: str, communication_protocol: PROTOCOL_TYPE ) -> Response: user_verify_key: SyftVerifyKey = SyftVerifyKey.from_string(verify_key) @@ -178,7 +178,7 @@ def handle_new_api_call(data: bytes) -> Response: # make a request to the SyftAPI @router.post("/api_call") - async def syft_new_api_call( + def syft_new_api_call( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: if TRACE_MODE: @@ -241,7 +241,7 @@ def handle_register(data: bytes, node: AbstractNode) -> Response: # exchange email and password for a SyftSigningKey @router.post("/login", name="login", status_code=200) - async def login( + def login( request: Request, email: Annotated[str, Body(example="info@openmined.org")], password: Annotated[str, Body(example="changethis")], @@ -257,7 +257,7 @@ async def login( return handle_login(email, password, worker) @router.post("/register", name="register", status_code=200) - async def register( + def register( request: Request, data: Annotated[bytes, Depends(get_body)] ) -> Response: if TRACE_MODE: From 80166a56d52eaaab77dcb4ed326e2a4049420a41 Mon Sep 17 00:00:00 2001 From: rasswanth-s <43314053+rasswanth-s@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:17:30 +0530 Subject: [PATCH 12/12] shift pyinstrument to make_routes, as async routes are not fully supported, disable pyinstrument middlware temporarily re-name appsettings to serversettings --- packages/syft/src/syft/node/routes.py | 47 ++++++++++++++++- packages/syft/src/syft/node/server.py | 52 +++++++------------ .../syft/src/syft/node/server_settings.py | 30 +++++++++++ 3 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 packages/syft/src/syft/node/server_settings.py diff --git a/packages/syft/src/syft/node/routes.py b/packages/syft/src/syft/node/routes.py index 37baaff90e8..632f84b12d8 100644 --- a/packages/syft/src/syft/node/routes.py +++ b/packages/syft/src/syft/node/routes.py @@ -2,7 +2,10 @@ import base64 import binascii from collections.abc import AsyncGenerator +from collections.abc import Callable +from datetime import datetime import logging +from pathlib import Path from typing import Annotated # third party @@ -34,12 +37,13 @@ from ..util.telemetry import TRACE_MODE from .credentials import SyftVerifyKey from .credentials import UserLoginCredentials +from .server_settings import ServerSettings from .worker import Worker logger = logging.getLogger(__name__) -def make_routes(worker: Worker) -> APIRouter: +def make_routes(worker: Worker, settings: ServerSettings | None = None) -> APIRouter: if TRACE_MODE: # third party try: @@ -49,6 +53,34 @@ def make_routes(worker: Worker) -> APIRouter: except Exception as e: logger.error("Failed to import opentelemetry", exc_info=e) + def _handle_profile( + request: Request, handler_func: Callable, *args: list, **kwargs: dict + ) -> Response: + if not settings: + raise Exception("Server Settings are required to enable profiling") + # third party + from pyinstrument import Profiler # Lazy Load + + profiles_dir = Path(settings.profile_dir or Path.cwd()) / "profiles" + profiles_dir.mkdir(parents=True, exist_ok=True) + + with Profiler( + interval=settings.profile_interval, async_mode="enabled" + ) as profiler: + response = handler_func(*args, **kwargs) + + timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S") + url_path = request.url.path.replace("/api/v2", "").replace("/", "-") + profile_output_path = ( + profiles_dir / f"{settings.name}-{timestamp}{url_path}.html" + ) + profiler.write_html(profile_output_path) + + logger.info( + f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds" + ) + return response + router = APIRouter() async def get_body(request: Request) -> bytes: @@ -165,6 +197,13 @@ def syft_new_api( kind=trace.SpanKind.SERVER, ): return handle_syft_new_api(user_verify_key, communication_protocol) + elif settings and settings.profile: + return _handle_profile( + request, + handle_syft_new_api, + user_verify_key, + communication_protocol, + ) else: return handle_syft_new_api(user_verify_key, communication_protocol) @@ -188,6 +227,8 @@ def syft_new_api_call( kind=trace.SpanKind.SERVER, ): return handle_new_api_call(data) + elif settings and settings.profile: + return _handle_profile(request, handle_new_api_call, data) else: return handle_new_api_call(data) @@ -253,6 +294,8 @@ def login( kind=trace.SpanKind.SERVER, ): return handle_login(email, password, worker) + elif settings and settings.profile: + return _handle_profile(request, handle_login, email, password, worker) else: return handle_login(email, password, worker) @@ -267,6 +310,8 @@ def register( kind=trace.SpanKind.SERVER, ): return handle_register(data, worker) + elif settings and settings.profile: + return _handle_profile(request, handle_register, data, worker) else: return handle_register(data, worker) diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 19e1beb3f33..207860c8b83 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -17,8 +17,6 @@ from fastapi import FastAPI from fastapi import Request from fastapi import Response -from pydantic_settings import BaseSettings -from pydantic_settings import SettingsConfigDict import requests from starlette.middleware.cors import CORSMiddleware import uvicorn @@ -34,6 +32,7 @@ from .gateway import Gateway from .node import NodeType from .routes import make_routes +from .server_settings import ServerSettings from .utils import get_named_node_uid from .utils import remove_temp_dir_for_node @@ -45,31 +44,8 @@ WAIT_TIME_SECONDS = 20 -class AppSettings(BaseSettings): - name: str - node_type: NodeType = NodeType.DOMAIN - node_side_type: NodeSideType = NodeSideType.HIGH_SIDE - processes: int = 1 - reset: bool = False - dev_mode: bool = False - enable_warnings: bool = False - in_memory_workers: bool = True - queue_port: int | None = None - create_producer: bool = False - n_consumers: int = 0 - association_request_auto_approval: bool = False - background_tasks: bool = False - - # Profiling inputs - profile: bool = False - profile_interval: float = 0.001 - profile_dir: str | None = None - - model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") - - def app_factory() -> FastAPI: - settings = AppSettings() + settings = ServerSettings() worker_classes = { NodeType.DOMAIN: Domain, @@ -95,7 +71,7 @@ def app_factory() -> FastAPI: worker = worker_class(**worker_kwargs) app = FastAPI(title=settings.name) - router = make_routes(worker=worker) + router = make_routes(worker=worker, settings=settings) api_router = APIRouter() api_router.include_router(router) app.include_router(api_router, prefix="/api/v2") @@ -106,10 +82,20 @@ def app_factory() -> FastAPI: return app -def _register_middlewares(app: FastAPI, settings: AppSettings) -> None: +def _register_middlewares(app: FastAPI, settings: ServerSettings) -> None: _register_cors_middleware(app) - if settings.profile: - _register_profiler(app, settings) + + # As currently sync routes are not supported in pyinstrument + # we are not registering the profiler middleware for sync routes + # as currently most of our routes are sync routes in syft (routes.py) + # ex: syft_new_api, syft_new_api_call, login, register + # we should either convert these routes to async or + # wait until pyinstrument supports sync routes + # The reason we cannot our sync routes to async is because + # we have blocking IO operations, like the requests library, like if one route calls to + # itself, it will block the event loop and the server will hang + # if settings.profile: + # _register_profiler(app, settings) def _register_cors_middleware(app: FastAPI) -> None: @@ -122,7 +108,7 @@ def _register_cors_middleware(app: FastAPI) -> None: ) -def _register_profiler(app: FastAPI, settings: AppSettings) -> None: +def _register_profiler(app: FastAPI, settings: ServerSettings) -> None: # third party from pyinstrument import Profiler @@ -216,7 +202,7 @@ def run_uvicorn( attach_debugger() # Set up all kwargs as environment variables so that they can be accessed in the app_factory function. - env_prefix = AppSettings.model_config.get("env_prefix", "") + env_prefix = ServerSettings.model_config.get("env_prefix", "") for key, value in kwargs.items(): key_with_prefix = f"{env_prefix}{key.upper()}" os.environ[key_with_prefix] = str(value) @@ -275,6 +261,8 @@ def serve_node( # Load the Pyinstrument Jupyter extension if profile is enabled. if profile: _load_pyinstrument_jupyter_extension() + if profile_dir is None: + profile_dir = str(Path.cwd()) server_process = multiprocessing.Process( target=run_uvicorn, diff --git a/packages/syft/src/syft/node/server_settings.py b/packages/syft/src/syft/node/server_settings.py new file mode 100644 index 00000000000..3c57606ec02 --- /dev/null +++ b/packages/syft/src/syft/node/server_settings.py @@ -0,0 +1,30 @@ +# third party +from pydantic_settings import BaseSettings +from pydantic_settings import SettingsConfigDict + +# relative +from ..abstract_node import NodeSideType +from ..abstract_node import NodeType + + +class ServerSettings(BaseSettings): + name: str + node_type: NodeType = NodeType.DOMAIN + node_side_type: NodeSideType = NodeSideType.HIGH_SIDE + processes: int = 1 + reset: bool = False + dev_mode: bool = False + enable_warnings: bool = False + in_memory_workers: bool = True + queue_port: int | None = None + create_producer: bool = False + n_consumers: int = 0 + association_request_auto_approval: bool = False + background_tasks: bool = False + + # Profiling inputs + profile: bool = False + profile_interval: float = 0.001 + profile_dir: str | None = None + + model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")