From f389274815b1e03aa55824a99bf2596c8515d720 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 16:22:29 +0100 Subject: [PATCH 1/5] init session from model id and version --- .../verifiable_mnist/verifiable_mnist.ipynb | 71 +++++++++++++------ giza_actions/model.py | 58 ++++++++++----- 2 files changed, 88 insertions(+), 41 deletions(-) diff --git a/examples/verifiable_mnist/verifiable_mnist.ipynb b/examples/verifiable_mnist/verifiable_mnist.ipynb index 8d1a939..47effee 100644 --- a/examples/verifiable_mnist/verifiable_mnist.ipynb +++ b/examples/verifiable_mnist/verifiable_mnist.ipynb @@ -41,29 +41,6 @@ "# Login to Giza and create a Workspace" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "https://api-dev.gizatech.xyz\n" - ] - } - ], - "source": [ - "# TODO: remove this cell when it deploys on main branch\n", - "\n", - "import os\n", - "import uuid\n", - "\n", - "os.environ['GIZA_API_HOST'] = 'https://api-dev.gizatech.xyz'\n", - "print(os.environ['GIZA_API_HOST'])\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1982,6 +1959,54 @@ "Now, let's make a prediction with the Cairo model (`veriable=True`)." ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torchvision\n", + "import numpy as np\n", + "import logging\n", + "from scipy.ndimage import zoom\n", + "from giza_actions.action import action, Action\n", + "from giza_actions.task import task\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "from giza_actions.model import GizaModel\n", + "\n", + "MODEL_ID = 296 # Update with your model ID\n", + "VERSION_ID = 1 # Update with your version ID\n", + "\n", + "def prediction(model_id, version_id):\n", + " model = GizaModel(id=model_id, version=version_id)\n", + "\n", + " print(model.session.get_outputs())\n", + "\n", + " # (result, request_id) = model.predict(\n", + " # input_feed={\"image\": image}, verifiable=True, output_dtype=\"Tensor\"\n", + " # )\n", + "\n", + " # # Convert result to a PyTorch tensor\n", + " # probabilities = torch.tensor(result)\n", + " # # Use argmax to get the predicted class\n", + " # predicted_class = torch.argmax(probabilities, dim=1)\n", + "\n", + " # return predicted_class, request_id\n", + "\n", + "prediction(MODEL_ID, VERSION_ID)" + ] + }, { "cell_type": "code", "execution_count": 4, diff --git a/giza_actions/model.py b/giza_actions/model.py index acd9a96..c49f976 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -50,13 +50,15 @@ def __init__( output_path: Optional[str] = None, ): if model_path is None and id is None and version is None: - raise ValueError("Either model_path or id and version must be provided.") + raise ValueError( + "Either model_path or id and version must be provided.") if model_path is None and (id is None or version is None): raise ValueError("Both id and version must be provided.") if model_path and (id or version): - raise ValueError("Either model_path or id and version must be provided.") + raise ValueError( + "Either model_path or id and version must be provided.") if model_path and id and version: raise ValueError( @@ -70,13 +72,13 @@ def __init__( self.version_client = VersionsClient(API_HOST) self.api_client = ApiClient(API_HOST) self._get_credentials() + self.model = self._get_model(id) self.version = self._get_version(id, version) - print(self.version) - self.session = None + self.session = ort.InferenceSession( + self._download_model(id, output_path) + ) self.framework = self.version.framework self.uri = self._retrieve_uri(id, version) - if output_path: - self._download_model(id, version, output_path) def _retrieve_uri(self, model_id: int, version_id: int): """ @@ -96,6 +98,18 @@ def _retrieve_uri(self, model_id: int, version_id: int): else: return f"{uri}/predict" + def _get_model(self, model_id: int): + """ + Retrieves the model specified by model_id. + + Args: + model_id (int): The unique identifier of the model. + + Returns: + The model. + """ + return self.model_client.get(model_id) + def _get_version(self, model_id: int, version_id: int): """ Retrieves the version of the model specified by model_id and version_id. @@ -109,13 +123,12 @@ def _get_version(self, model_id: int, version_id: int): """ return self.version_client.get(model_id, version_id) - def _download_model(self, model_id: int, version_id: int, output_path: str): + def _download_model(self, model_id: int, output_path: Optional[str]): """ Downloads the model specified by model_id and version_id to the given output_path. Args: model_id (int): The unique identifier of the model. - version_id (int): The version number of the model. output_path (str): The file path where the downloaded model should be saved. Raises: @@ -127,18 +140,26 @@ def _download_model(self, model_id: int, version_id: int, output_path: str): f"Model version status is not completed {self.version.status}" ) - print("ONNX model is ready, downloading! ✅") - onnx_model = self.api_client.download_original(model_id, self.version.version) + onnx_model = self.version_client.download_original( + model_id, self.version.version) + + if output_path is not None: + + print("ONNX model is ready, downloading! ✅") + + if ".onnx" in output_path: + save_path = Path(output_path) + else: + save_path = Path(f"{output_path}/{self.model.name}.onnx") - model_name = self.version.original_model_path.split("/")[-1] - save_path = Path(output_path) / model_name + with open(save_path, "wb") as f: + f.write(onnx_model) - with open(save_path, "wb") as f: - f.write(onnx_model) + print(f"ONNX model saved at: {save_path}") + self.session = ort.InferenceSession(save_path) + print("Model ready for inference with ONNX Runtime! ✅") - print(f"ONNX model saved at: {save_path}") - self.session = ort.InferenceSession(save_path) - print("Model ready for inference with ONNX Runtime! ✅") + return onnx_model def _get_credentials(self): """ @@ -209,7 +230,8 @@ def predict( if self.framework == Framework.CAIRO: logging.info("Serialized: ", serialized_output) - preds = self._parse_cairo_response(serialized_output, output_dtype) + preds = self._parse_cairo_response( + serialized_output, output_dtype) elif self.framework == Framework.EZKL: preds = np.array(serialized_output[0]) return (preds, request_id) From d1a4e323245f8d5b20a7aa3e9b1fa43b3f75890d Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 17:52:28 +0100 Subject: [PATCH 2/5] create inner function to set session --- .../verifiable_mnist/verifiable_mnist.ipynb | 32 ++++++----- giza_actions/model.py | 55 +++++++++++++------ 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/examples/verifiable_mnist/verifiable_mnist.ipynb b/examples/verifiable_mnist/verifiable_mnist.ipynb index 47effee..5423559 100644 --- a/examples/verifiable_mnist/verifiable_mnist.ipynb +++ b/examples/verifiable_mnist/verifiable_mnist.ipynb @@ -1965,10 +1965,17 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" + "ename": "TypeError", + "evalue": "Unable to load from type ''", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 27\u001b[0m\n\u001b[1;32m 23\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(onnx_model)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28mprint\u001b[39m(onnx_model)\n\u001b[0;32m---> 27\u001b[0m \u001b[43mprediction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL_ID\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mVERSION_ID\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[1], line 19\u001b[0m, in \u001b[0;36mprediction\u001b[0;34m(model_id, version_id)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprediction\u001b[39m(model_id, version_id):\n\u001b[0;32m---> 19\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGizaModel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mversion_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39monnx_model\n\u001b[1;32m 23\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(onnx_model)\n", + "File \u001b[0;32m~/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/model.py:78\u001b[0m, in \u001b[0;36mGizaModel.__init__\u001b[0;34m(self, model_path, id, version, output_path)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mversion \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_version(\u001b[38;5;28mid\u001b[39m, version)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39monnx_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msession \u001b[38;5;241m=\u001b[39m \u001b[43mort\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mframework \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mframework\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muri \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_retrieve_uri(\u001b[38;5;28mid\u001b[39m, version)\n", + "File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:405\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes \u001b[38;5;241m=\u001b[39m path_or_bytes \u001b[38;5;66;03m# TODO: This is bad as we're holding the memory indefinitely\u001b[39;00m\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 405\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnable to load from type \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(path_or_bytes)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sess_options \u001b[38;5;241m=\u001b[39m sess_options\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sess_options_initial \u001b[38;5;241m=\u001b[39m sess_options\n", + "\u001b[0;31mTypeError\u001b[0m: Unable to load from type ''" ] } ], @@ -1984,6 +1991,8 @@ "from giza_actions.task import task\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from giza_actions.model import GizaModel\n", + "import onnx\n", + "\n", "\n", "MODEL_ID = 296 # Update with your model ID\n", "VERSION_ID = 1 # Update with your version ID\n", @@ -1991,18 +2000,11 @@ "def prediction(model_id, version_id):\n", " model = GizaModel(id=model_id, version=version_id)\n", "\n", - " print(model.session.get_outputs())\n", - "\n", - " # (result, request_id) = model.predict(\n", - " # input_feed={\"image\": image}, verifiable=True, output_dtype=\"Tensor\"\n", - " # )\n", - "\n", - " # # Convert result to a PyTorch tensor\n", - " # probabilities = torch.tensor(result)\n", - " # # Use argmax to get the predicted class\n", - " # predicted_class = torch.argmax(probabilities, dim=1)\n", + " onnx_model = model.onnx_model\n", "\n", - " # return predicted_class, request_id\n", + " onnx_model = onnx.load(onnx_model)\n", + " \n", + " print(onnx_model)\n", "\n", "prediction(MODEL_ID, VERSION_ID)" ] diff --git a/giza_actions/model.py b/giza_actions/model.py index c49f976..dce9712 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -74,11 +74,11 @@ def __init__( self._get_credentials() self.model = self._get_model(id) self.version = self._get_version(id, version) - self.session = ort.InferenceSession( - self._download_model(id, output_path) - ) + self.session = self._set_session(id) self.framework = self.version.framework self.uri = self._retrieve_uri(id, version) + if output_path: + self._download_model(id, output_path) def _retrieve_uri(self, model_id: int, version_id: int): """ @@ -123,7 +123,32 @@ def _get_version(self, model_id: int, version_id: int): """ return self.version_client.get(model_id, version_id) - def _download_model(self, model_id: int, output_path: Optional[str]): + def _set_session(self, model_id: int): + """ + Set onnxruntime session for the model specified by model_id. + + Args: + model_id (int): The unique identifier of the model. + + Raises: + ValueError: If the model version status is not completed. + """ + + if self.version.status != VersionStatus.COMPLETED: + raise ValueError( + f"Model version status is not completed {self.version.status}" + ) + + try: + onnx_model = self.version_client.download_original( + model_id, self.version.version) + + return ort.InferenceSession(onnx_model) + + except: + return None + + def _download_model(self, model_id: int, output_path: str): """ Downloads the model specified by model_id and version_id to the given output_path. @@ -143,23 +168,17 @@ def _download_model(self, model_id: int, output_path: Optional[str]): onnx_model = self.version_client.download_original( model_id, self.version.version) - if output_path is not None: - - print("ONNX model is ready, downloading! ✅") + print("ONNX model is ready, downloading! ✅") - if ".onnx" in output_path: - save_path = Path(output_path) - else: - save_path = Path(f"{output_path}/{self.model.name}.onnx") - - with open(save_path, "wb") as f: - f.write(onnx_model) + if ".onnx" in output_path: + save_path = Path(output_path) + else: + save_path = Path(f"{output_path}/{self.model.name}.onnx") - print(f"ONNX model saved at: {save_path}") - self.session = ort.InferenceSession(save_path) - print("Model ready for inference with ONNX Runtime! ✅") + with open(save_path, "wb") as f: + f.write(onnx_model) - return onnx_model + print(f"ONNX model saved at: {save_path} ✅") def _get_credentials(self): """ From 058cc12e9fefde40dd0fbd9cf7110e7c6a32611e Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 18:32:24 +0100 Subject: [PATCH 3/5] get output dtype --- .../verifiable_mnist/verifiable_mnist.ipynb | 110 +++++------------- giza_actions/model.py | 47 +++++++- 2 files changed, 76 insertions(+), 81 deletions(-) diff --git a/examples/verifiable_mnist/verifiable_mnist.ipynb b/examples/verifiable_mnist/verifiable_mnist.ipynb index 5423559..dbb8c4e 100644 --- a/examples/verifiable_mnist/verifiable_mnist.ipynb +++ b/examples/verifiable_mnist/verifiable_mnist.ipynb @@ -1961,72 +1961,22 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "Unable to load from type ''", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 27\u001b[0m\n\u001b[1;32m 23\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(onnx_model)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28mprint\u001b[39m(onnx_model)\n\u001b[0;32m---> 27\u001b[0m \u001b[43mprediction\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL_ID\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mVERSION_ID\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[1], line 19\u001b[0m, in \u001b[0;36mprediction\u001b[0;34m(model_id, version_id)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprediction\u001b[39m(model_id, version_id):\n\u001b[0;32m---> 19\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGizaModel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mversion\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mversion_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39monnx_model\n\u001b[1;32m 23\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(onnx_model)\n", - "File \u001b[0;32m~/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/model.py:78\u001b[0m, in \u001b[0;36mGizaModel.__init__\u001b[0;34m(self, model_path, id, version, output_path)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mversion \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_version(\u001b[38;5;28mid\u001b[39m, version)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39monnx_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msession \u001b[38;5;241m=\u001b[39m \u001b[43mort\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mframework \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mversion\u001b[38;5;241m.\u001b[39mframework\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muri \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_retrieve_uri(\u001b[38;5;28mid\u001b[39m, version)\n", - "File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:405\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m 403\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes \u001b[38;5;241m=\u001b[39m path_or_bytes \u001b[38;5;66;03m# TODO: This is bad as we're holding the memory indefinitely\u001b[39;00m\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 405\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnable to load from type \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(path_or_bytes)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sess_options \u001b[38;5;241m=\u001b[39m sess_options\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sess_options_initial \u001b[38;5;241m=\u001b[39m sess_options\n", - "\u001b[0;31mTypeError\u001b[0m: Unable to load from type ''" - ] - } - ], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import torchvision\n", - "import numpy as np\n", - "import logging\n", - "from scipy.ndimage import zoom\n", - "from giza_actions.action import action, Action\n", - "from giza_actions.task import task\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "from giza_actions.model import GizaModel\n", - "import onnx\n", - "\n", - "\n", - "MODEL_ID = 296 # Update with your model ID\n", - "VERSION_ID = 1 # Update with your version ID\n", - "\n", - "def prediction(model_id, version_id):\n", - " model = GizaModel(id=model_id, version=version_id)\n", - "\n", - " onnx_model = model.onnx_model\n", - "\n", - " onnx_model = onnx.load(onnx_model)\n", - " \n", - " print(onnx_model)\n", - "\n", - "prediction(MODEL_ID, VERSION_ID)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:337: UserWarning: A task named 'Preprocess Image' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:7' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:332: UserWarning: A task named 'Preprocess Image' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:7' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", "\n", " `@task(name='my_unique_name', ...)`\n", " warnings.warn(\n", - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:337: UserWarning: A task named 'Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:20' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:332: UserWarning: A task named 'Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:20' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", "\n", " `@task(name='my_unique_name', ...)`\n", " warnings.warn(\n", - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/flows.py:337: UserWarning: A flow named 'Execution: Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/action.py:36' conflicts with another flow. Consider specifying a unique `name` parameter in the flow definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/flows.py:336: UserWarning: A flow named 'Execution: Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/action.py:36' conflicts with another flow. Consider specifying a unique `name` parameter in the flow definition:\n", "\n", " `@flow(name='my_unique_name', ...)`\n", " warnings.warn(\n" @@ -2035,11 +1985,11 @@ { "data": { "text/html": [ - "
11:08:42.535 | INFO    | Created flow run 'industrious-wrasse' for flow 'Execution: Prediction with Cairo'\n",
+       "
18:29:27.522 | INFO    | Created flow run 'shapeless-snail' for flow 'Execution: Prediction with Cairo'\n",
        "
\n" ], "text/plain": [ - "11:08:42.535 | \u001b[36mINFO\u001b[0m | Created flow run\u001b[35m 'industrious-wrasse'\u001b[0m for flow\u001b[1;35m 'Execution: Prediction with Cairo'\u001b[0m\n" + "18:29:27.522 | \u001b[36mINFO\u001b[0m | Created flow run\u001b[35m 'shapeless-snail'\u001b[0m for flow\u001b[1;35m 'Execution: Prediction with Cairo'\u001b[0m\n" ] }, "metadata": {}, @@ -2048,11 +1998,11 @@ { "data": { "text/html": [ - "
11:08:42.537 | INFO    | Action run 'industrious-wrasse' - View at https://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/40572d39-7a1f-46be-873f-4737742bff89\n",
+       "
18:29:27.523 | INFO    | Action run 'shapeless-snail' - View at https://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/5068c561-0a8c-4fcf-9bff-8cb7b2617d19\n",
        "
\n" ], "text/plain": [ - "11:08:42.537 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - View at \u001b[94mhttps://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/40572d39-7a1f-46be-873f-4737742bff89\u001b[0m\n" + "18:29:27.523 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - View at \u001b[94mhttps://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/5068c561-0a8c-4fcf-9bff-8cb7b2617d19\u001b[0m\n" ] }, "metadata": {}, @@ -2061,11 +2011,11 @@ { "data": { "text/html": [ - "
11:08:43.118 | INFO    | Action run 'industrious-wrasse' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n",
+       "
18:29:27.740 | INFO    | Action run 'shapeless-snail' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n",
        "
\n" ], "text/plain": [ - "11:08:43.118 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n" + "18:29:27.740 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n" ] }, "metadata": {}, @@ -2074,11 +2024,11 @@ { "data": { "text/html": [ - "
11:08:43.121 | INFO    | Action run 'industrious-wrasse' - Executing 'Preprocess Image-0' immediately...\n",
+       "
18:29:27.742 | INFO    | Action run 'shapeless-snail' - Executing 'Preprocess Image-0' immediately...\n",
        "
\n" ], "text/plain": [ - "11:08:43.121 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Executing 'Preprocess Image-0' immediately...\n" + "18:29:27.742 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Executing 'Preprocess Image-0' immediately...\n" ] }, "metadata": {}, @@ -2087,11 +2037,11 @@ { "data": { "text/html": [ - "
11:08:43.632 | INFO    | Task run 'Preprocess Image-0' - Finished in state Completed()\n",
+       "
18:29:28.006 | INFO    | Task run 'Preprocess Image-0' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:43.632 | \u001b[36mINFO\u001b[0m | Task run 'Preprocess Image-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:28.006 | \u001b[36mINFO\u001b[0m | Task run 'Preprocess Image-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2100,11 +2050,11 @@ { "data": { "text/html": [ - "
11:08:43.749 | INFO    | Action run 'industrious-wrasse' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n",
+       "
18:29:28.112 | INFO    | Action run 'shapeless-snail' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n",
        "
\n" ], "text/plain": [ - "11:08:43.749 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n" + "18:29:28.112 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n" ] }, "metadata": {}, @@ -2113,11 +2063,11 @@ { "data": { "text/html": [ - "
11:08:43.751 | INFO    | Action run 'industrious-wrasse' - Executing 'Prediction with Cairo-0' immediately...\n",
+       "
18:29:28.115 | INFO    | Action run 'shapeless-snail' - Executing 'Prediction with Cairo-0' immediately...\n",
        "
\n" ], "text/plain": [ - "11:08:43.751 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Executing 'Prediction with Cairo-0' immediately...\n" + "18:29:28.115 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Executing 'Prediction with Cairo-0' immediately...\n" ] }, "metadata": {}, @@ -2134,11 +2084,11 @@ { "data": { "text/html": [ - "
11:08:47.958 | INFO    | Task run 'Prediction with Cairo-0' - Finished in state Completed()\n",
+       "
18:29:32.420 | INFO    | Task run 'Prediction with Cairo-0' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:47.958 | \u001b[36mINFO\u001b[0m | Task run 'Prediction with Cairo-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:32.420 | \u001b[36mINFO\u001b[0m | Task run 'Prediction with Cairo-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2147,11 +2097,11 @@ { "data": { "text/html": [ - "
11:08:47.961 | INFO    | Action run 'industrious-wrasse' - Result:  tensor([0])\n",
+       "
18:29:32.421 | INFO    | Action run 'shapeless-snail' - Result:  tensor([0])\n",
        "
\n" ], "text/plain": [ - "11:08:47.961 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Result: tensor([0])\n" + "18:29:32.421 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Result: tensor([0])\n" ] }, "metadata": {}, @@ -2160,11 +2110,11 @@ { "data": { "text/html": [ - "
11:08:47.962 | INFO    | Action run 'industrious-wrasse' - Request id:  \"cd38a8593d2c429cb8c45f5e37939409\"\n",
+       "
18:29:32.422 | INFO    | Action run 'shapeless-snail' - Request id:  \"b2484bba4b5644df80ab7eed20f1c87b\"\n",
        "
\n" ], "text/plain": [ - "11:08:47.962 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Request id: \"cd38a8593d2c429cb8c45f5e37939409\"\n" + "18:29:32.422 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Request id: \"b2484bba4b5644df80ab7eed20f1c87b\"\n" ] }, "metadata": {}, @@ -2173,11 +2123,11 @@ { "data": { "text/html": [ - "
11:08:48.087 | INFO    | Action run 'industrious-wrasse' - Finished in state Completed()\n",
+       "
18:29:32.508 | INFO    | Action run 'shapeless-snail' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:48.087 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:32.508 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2186,10 +2136,10 @@ { "data": { "text/plain": [ - "(tensor([0]), '\"cd38a8593d2c429cb8c45f5e37939409\"')" + "(tensor([0]), '\"b2484bba4b5644df80ab7eed20f1c87b\"')" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -2201,6 +2151,7 @@ "MODEL_ID = 296 # Update with your model ID\n", "VERSION_ID = 1 # Update with your version ID\n", "\n", + "\n", "@task(name=f'Preprocess Image')\n", "def preprocess_image(image_path):\n", " from PIL import Image\n", @@ -2214,12 +2165,13 @@ " image = image.reshape(1, 196) # Reshape to (1, 196) for model input\n", " return image\n", "\n", + "\n", "@task(name=f'Prediction with Cairo')\n", "def prediction(image, model_id, version_id):\n", " model = GizaModel(id=model_id, version=version_id)\n", "\n", " (result, request_id) = model.predict(\n", - " input_feed={\"image\": image}, verifiable=True, output_dtype=\"Tensor\"\n", + " input_feed={\"image\": image}, verifiable=True\n", " )\n", "\n", " # Convert result to a PyTorch tensor\n", diff --git a/giza_actions/model.py b/giza_actions/model.py index dce9712..10c868b 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -5,6 +5,7 @@ import numpy as np import onnxruntime as ort +import onnx import requests from giza import API_HOST from giza.client import ApiClient, ModelsClient, VersionsClient @@ -193,7 +194,7 @@ def predict( input_feed: Optional[Dict] = None, verifiable: bool = False, fp_impl="FP16x16", - output_dtype: str = "tensor_fixed_point", + custom_output_dtype: Optional[str] = None, job_size: str = "M", ): """ @@ -205,7 +206,7 @@ def predict( input_feed (Optional[Dict]): A dictionary containing the input data for prediction. Defaults to None. verifiable (bool): A flag indicating whether to use the verifiable computation endpoint. Defaults to False. fp_impl (str): The fixed point implementation to use, when computed in verifiable mode. Defaults to "FP16x16". - output_dtype (str): The data type of the result when computed in verifiable mode. Defaults to "tensor_fixed_point". + custom_output_dtype (Optional[str]): Specify the data type of the result when computed in verifiable mode. Defaults to None. Returns: A tuple (predictions, request_id) where predictions is the result of the prediction and request_id @@ -249,6 +250,11 @@ def predict( if self.framework == Framework.CAIRO: logging.info("Serialized: ", serialized_output) + if custom_output_dtype is None: + output_dtype = self._get_output_dtype() + else: + output_dtype = custom_output_dtype + preds = self._parse_cairo_response( serialized_output, output_dtype) elif self.framework == Framework.EZKL: @@ -359,3 +365,40 @@ def _parse_cairo_response(self, response, data_type: str): The deserialized prediction result. """ return deserialize(response, data_type) + + def _get_output_dtype(self): + """ + Retrieve the Cairo output data type base on the operator type of the final node. + + Returns: + The output dtype as a string. + """ + + file = self.version_client.download_original( + self.model.id, self.version.version + ) + + model = onnx.load_model_from_string(file) + graph = model.graph + output_tensor_name = graph.output[0].name + + def find_producing_node(graph, tensor_name): + for node in graph.node: + if tensor_name in node.output: + return node + return None + + final_node = find_producing_node(graph, output_tensor_name) + optype = final_node.op_type + + match optype: + case "TreeEnsembleClassifier": + return "(Span, MutMatrix)" + + case "TreeEnsembleRegressor": + return "MutMatrix::" + + case "LinearClassifier": + return "(Span, Tensor)" + case _: + return "Tensor" From 4a4b156ac70610f88dff03f18345fc8ff30a62e8 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 18:44:41 +0100 Subject: [PATCH 4/5] fix bare except error --- giza_actions/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/giza_actions/model.py b/giza_actions/model.py index 10c868b..c7da8a8 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -146,7 +146,8 @@ def _set_session(self, model_id: int): return ort.InferenceSession(onnx_model) - except: + except Exception as e: + print(f"Could not download model: {e}") return None def _download_model(self, model_id: int, output_path: str): From 47a30aaff7ece207497e071fd0237e7f4a8ff359 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Tue, 19 Mar 2024 09:58:51 +0100 Subject: [PATCH 5/5] remove unnecessary model_id param --- giza_actions/model.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/giza_actions/model.py b/giza_actions/model.py index c7da8a8..143d501 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -74,26 +74,25 @@ def __init__( self.api_client = ApiClient(API_HOST) self._get_credentials() self.model = self._get_model(id) - self.version = self._get_version(id, version) - self.session = self._set_session(id) + self.version = self._get_version(version) + self.session = self._set_session() self.framework = self.version.framework - self.uri = self._retrieve_uri(id, version) + self.uri = self._retrieve_uri(version) if output_path: - self._download_model(id, output_path) + self._download_model(output_path) - def _retrieve_uri(self, model_id: int, version_id: int): + def _retrieve_uri(self, version_id: int): """ Retrieves the URI for making prediction requests to a deployed model. Args: - model_id (int): The unique identifier of the model. version_id (int): The version number of the model. Returns: The URI for making prediction requests to the deployed model. """ # Different URI per framework - uri = get_endpoint_uri(model_id, version_id) + uri = get_endpoint_uri(self.model.id, version_id) if self.framework == Framework.CAIRO: return f"{uri}/cairo_run" else: @@ -111,25 +110,21 @@ def _get_model(self, model_id: int): """ return self.model_client.get(model_id) - def _get_version(self, model_id: int, version_id: int): + def _get_version(self, version_id: int): """ - Retrieves the version of the model specified by model_id and version_id. + Retrieves the version of the model specified by model id and version id. Args: - model_id (int): The unique identifier of the model. version_id (int): The version number of the model. Returns: The version of the model. """ - return self.version_client.get(model_id, version_id) + return self.version_client.get(self.model.id, version_id) - def _set_session(self, model_id: int): + def _set_session(self): """ - Set onnxruntime session for the model specified by model_id. - - Args: - model_id (int): The unique identifier of the model. + Set onnxruntime session for the model specified by model id. Raises: ValueError: If the model version status is not completed. @@ -142,7 +137,7 @@ def _set_session(self, model_id: int): try: onnx_model = self.version_client.download_original( - model_id, self.version.version) + self.model.id, self.version.version) return ort.InferenceSession(onnx_model) @@ -150,12 +145,11 @@ def _set_session(self, model_id: int): print(f"Could not download model: {e}") return None - def _download_model(self, model_id: int, output_path: str): + def _download_model(self, output_path: str): """ - Downloads the model specified by model_id and version_id to the given output_path. + Downloads the model specified by model id and version id to the given output_path. Args: - model_id (int): The unique identifier of the model. output_path (str): The file path where the downloaded model should be saved. Raises: @@ -168,7 +162,7 @@ def _download_model(self, model_id: int, output_path: str): ) onnx_model = self.version_client.download_original( - model_id, self.version.version) + self.model.id, self.version.version) print("ONNX model is ready, downloading! ✅")