diff --git a/examples/verifiable_mnist/verifiable_mnist.ipynb b/examples/verifiable_mnist/verifiable_mnist.ipynb index 8d1a939..dbb8c4e 100644 --- a/examples/verifiable_mnist/verifiable_mnist.ipynb +++ b/examples/verifiable_mnist/verifiable_mnist.ipynb @@ -41,29 +41,6 @@ "# Login to Giza and create a Workspace" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "https://api-dev.gizatech.xyz\n" - ] - } - ], - "source": [ - "# TODO: remove this cell when it deploys on main branch\n", - "\n", - "import os\n", - "import uuid\n", - "\n", - "os.environ['GIZA_API_HOST'] = 'https://api-dev.gizatech.xyz'\n", - "print(os.environ['GIZA_API_HOST'])\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1984,22 +1961,22 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:337: UserWarning: A task named 'Preprocess Image' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:7' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:332: UserWarning: A task named 'Preprocess Image' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:7' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", "\n", " `@task(name='my_unique_name', ...)`\n", " warnings.warn(\n", - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:337: UserWarning: A task named 'Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:20' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/tasks.py:332: UserWarning: A task named 'Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/task.py:20' conflicts with another task. Consider specifying a unique `name` parameter in the task definition:\n", "\n", " `@task(name='my_unique_name', ...)`\n", " warnings.warn(\n", - "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/flows.py:337: UserWarning: A flow named 'Execution: Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/action.py:36' conflicts with another flow. Consider specifying a unique `name` parameter in the flow definition:\n", + "/Users/raphaeldoukhan/Library/Caches/pypoetry/virtualenvs/giza-actions-mYf3m_Lk-py3.11/lib/python3.11/site-packages/prefect/flows.py:336: UserWarning: A flow named 'Execution: Prediction with Cairo' and defined at '/Users/raphaeldoukhan/Desktop/Orion-Giza/Tools/actions-sdk/giza_actions/action.py:36' conflicts with another flow. 
Consider specifying a unique `name` parameter in the flow definition:\n", "\n", " `@flow(name='my_unique_name', ...)`\n", " warnings.warn(\n" @@ -2008,11 +1985,11 @@ { "data": { "text/html": [ - "
11:08:42.535 | INFO    | Created flow run 'industrious-wrasse' for flow 'Execution: Prediction with Cairo'\n",
+       "
18:29:27.522 | INFO    | Created flow run 'shapeless-snail' for flow 'Execution: Prediction with Cairo'\n",
        "
\n" ], "text/plain": [ - "11:08:42.535 | \u001b[36mINFO\u001b[0m | Created flow run\u001b[35m 'industrious-wrasse'\u001b[0m for flow\u001b[1;35m 'Execution: Prediction with Cairo'\u001b[0m\n" + "18:29:27.522 | \u001b[36mINFO\u001b[0m | Created flow run\u001b[35m 'shapeless-snail'\u001b[0m for flow\u001b[1;35m 'Execution: Prediction with Cairo'\u001b[0m\n" ] }, "metadata": {}, @@ -2021,11 +1998,11 @@ { "data": { "text/html": [ - "
11:08:42.537 | INFO    | Action run 'industrious-wrasse' - View at https://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/40572d39-7a1f-46be-873f-4737742bff89\n",
+       "
18:29:27.523 | INFO    | Action run 'shapeless-snail' - View at https://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/5068c561-0a8c-4fcf-9bff-8cb7b2617d19\n",
        "
\n" ], "text/plain": [ - "11:08:42.537 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - View at \u001b[94mhttps://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/40572d39-7a1f-46be-873f-4737742bff89\u001b[0m\n" + "18:29:27.523 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - View at \u001b[94mhttps://actions-server-raphael-doukhan-dblzzhtf5q-ew.a.run.app/flow-runs/flow-run/5068c561-0a8c-4fcf-9bff-8cb7b2617d19\u001b[0m\n" ] }, "metadata": {}, @@ -2034,11 +2011,11 @@ { "data": { "text/html": [ - "
11:08:43.118 | INFO    | Action run 'industrious-wrasse' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n",
+       "
18:29:27.740 | INFO    | Action run 'shapeless-snail' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n",
        "
\n" ], "text/plain": [ - "11:08:43.118 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n" + "18:29:27.740 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Created task run 'Preprocess Image-0' for task 'Preprocess Image'\n" ] }, "metadata": {}, @@ -2047,11 +2024,11 @@ { "data": { "text/html": [ - "
11:08:43.121 | INFO    | Action run 'industrious-wrasse' - Executing 'Preprocess Image-0' immediately...\n",
+       "
18:29:27.742 | INFO    | Action run 'shapeless-snail' - Executing 'Preprocess Image-0' immediately...\n",
        "
\n" ], "text/plain": [ - "11:08:43.121 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Executing 'Preprocess Image-0' immediately...\n" + "18:29:27.742 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Executing 'Preprocess Image-0' immediately...\n" ] }, "metadata": {}, @@ -2060,11 +2037,11 @@ { "data": { "text/html": [ - "
11:08:43.632 | INFO    | Task run 'Preprocess Image-0' - Finished in state Completed()\n",
+       "
18:29:28.006 | INFO    | Task run 'Preprocess Image-0' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:43.632 | \u001b[36mINFO\u001b[0m | Task run 'Preprocess Image-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:28.006 | \u001b[36mINFO\u001b[0m | Task run 'Preprocess Image-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2073,11 +2050,11 @@ { "data": { "text/html": [ - "
11:08:43.749 | INFO    | Action run 'industrious-wrasse' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n",
+       "
18:29:28.112 | INFO    | Action run 'shapeless-snail' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n",
        "
\n" ], "text/plain": [ - "11:08:43.749 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n" + "18:29:28.112 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Created task run 'Prediction with Cairo-0' for task 'Prediction with Cairo'\n" ] }, "metadata": {}, @@ -2086,11 +2063,11 @@ { "data": { "text/html": [ - "
11:08:43.751 | INFO    | Action run 'industrious-wrasse' - Executing 'Prediction with Cairo-0' immediately...\n",
+       "
18:29:28.115 | INFO    | Action run 'shapeless-snail' - Executing 'Prediction with Cairo-0' immediately...\n",
        "
\n" ], "text/plain": [ - "11:08:43.751 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Executing 'Prediction with Cairo-0' immediately...\n" + "18:29:28.115 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Executing 'Prediction with Cairo-0' immediately...\n" ] }, "metadata": {}, @@ -2107,11 +2084,11 @@ { "data": { "text/html": [ - "
11:08:47.958 | INFO    | Task run 'Prediction with Cairo-0' - Finished in state Completed()\n",
+       "
18:29:32.420 | INFO    | Task run 'Prediction with Cairo-0' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:47.958 | \u001b[36mINFO\u001b[0m | Task run 'Prediction with Cairo-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:32.420 | \u001b[36mINFO\u001b[0m | Task run 'Prediction with Cairo-0' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2120,11 +2097,11 @@ { "data": { "text/html": [ - "
11:08:47.961 | INFO    | Action run 'industrious-wrasse' - Result:  tensor([0])\n",
+       "
18:29:32.421 | INFO    | Action run 'shapeless-snail' - Result:  tensor([0])\n",
        "
\n" ], "text/plain": [ - "11:08:47.961 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Result: tensor([0])\n" + "18:29:32.421 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Result: tensor([0])\n" ] }, "metadata": {}, @@ -2133,11 +2110,11 @@ { "data": { "text/html": [ - "
11:08:47.962 | INFO    | Action run 'industrious-wrasse' - Request id:  \"cd38a8593d2c429cb8c45f5e37939409\"\n",
+       "
18:29:32.422 | INFO    | Action run 'shapeless-snail' - Request id:  \"b2484bba4b5644df80ab7eed20f1c87b\"\n",
        "
\n" ], "text/plain": [ - "11:08:47.962 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Request id: \"cd38a8593d2c429cb8c45f5e37939409\"\n" + "18:29:32.422 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Request id: \"b2484bba4b5644df80ab7eed20f1c87b\"\n" ] }, "metadata": {}, @@ -2146,11 +2123,11 @@ { "data": { "text/html": [ - "
11:08:48.087 | INFO    | Action run 'industrious-wrasse' - Finished in state Completed()\n",
+       "
18:29:32.508 | INFO    | Action run 'shapeless-snail' - Finished in state Completed()\n",
        "
\n" ], "text/plain": [ - "11:08:48.087 | \u001b[36mINFO\u001b[0m | Action run 'industrious-wrasse' - Finished in state \u001b[32mCompleted\u001b[0m()\n" + "18:29:32.508 | \u001b[36mINFO\u001b[0m | Action run 'shapeless-snail' - Finished in state \u001b[32mCompleted\u001b[0m()\n" ] }, "metadata": {}, @@ -2159,10 +2136,10 @@ { "data": { "text/plain": [ - "(tensor([0]), '\"cd38a8593d2c429cb8c45f5e37939409\"')" + "(tensor([0]), '\"b2484bba4b5644df80ab7eed20f1c87b\"')" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -2174,6 +2151,7 @@ "MODEL_ID = 296 # Update with your model ID\n", "VERSION_ID = 1 # Update with your version ID\n", "\n", + "\n", "@task(name=f'Preprocess Image')\n", "def preprocess_image(image_path):\n", " from PIL import Image\n", @@ -2187,12 +2165,13 @@ " image = image.reshape(1, 196) # Reshape to (1, 196) for model input\n", " return image\n", "\n", + "\n", "@task(name=f'Prediction with Cairo')\n", "def prediction(image, model_id, version_id):\n", " model = GizaModel(id=model_id, version=version_id)\n", "\n", " (result, request_id) = model.predict(\n", - " input_feed={\"image\": image}, verifiable=True, output_dtype=\"Tensor\"\n", + " input_feed={\"image\": image}, verifiable=True\n", " )\n", "\n", " # Convert result to a PyTorch tensor\n", diff --git a/giza_actions/model.py b/giza_actions/model.py index acd9a96..143d501 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -5,6 +5,7 @@ import numpy as np import onnxruntime as ort +import onnx import requests from giza import API_HOST from giza.client import ApiClient, ModelsClient, VersionsClient @@ -50,13 +51,15 @@ def __init__( output_path: Optional[str] = None, ): if model_path is None and id is None and version is None: - raise ValueError("Either model_path or id and version must be provided.") + raise ValueError( + "Either model_path or id and version must be provided.") if model_path is None and (id is None or version is None): raise ValueError("Both id and version must be provided.") if model_path and (id or version): - raise ValueError("Either model_path or id and version must be provided.") + raise ValueError( + "Either model_path or id and version must be provided.") if model_path and id and version: raise ValueError( @@ -70,52 +73,83 @@ def __init__( self.version_client = VersionsClient(API_HOST) self.api_client = ApiClient(API_HOST) self._get_credentials() - self.version = self._get_version(id, version) - print(self.version) - self.session = None + self.model = self._get_model(id) + self.version = self._get_version(version) + self.session = self._set_session() self.framework = self.version.framework - self.uri = self._retrieve_uri(id, version) + self.uri = self._retrieve_uri(version) if output_path: - self._download_model(id, version, output_path) + self._download_model(output_path) - def _retrieve_uri(self, model_id: int, version_id: int): + def _retrieve_uri(self, version_id: int): """ Retrieves the URI for making prediction requests to a deployed model. Args: - model_id (int): The unique identifier of the model. version_id (int): The version number of the model. Returns: The URI for making prediction requests to the deployed model. 
""" # Different URI per framework - uri = get_endpoint_uri(model_id, version_id) + uri = get_endpoint_uri(self.model.id, version_id) if self.framework == Framework.CAIRO: return f"{uri}/cairo_run" else: return f"{uri}/predict" - def _get_version(self, model_id: int, version_id: int): + def _get_model(self, model_id: int): """ - Retrieves the version of the model specified by model_id and version_id. + Retrieves the model specified by model_id. Args: model_id (int): The unique identifier of the model. + + Returns: + The model. + """ + return self.model_client.get(model_id) + + def _get_version(self, version_id: int): + """ + Retrieves the version of the model specified by model id and version id. + + Args: version_id (int): The version number of the model. Returns: The version of the model. """ - return self.version_client.get(model_id, version_id) + return self.version_client.get(self.model.id, version_id) + + def _set_session(self): + """ + Set onnxruntime session for the model specified by model id. + + Raises: + ValueError: If the model version status is not completed. + """ + + if self.version.status != VersionStatus.COMPLETED: + raise ValueError( + f"Model version status is not completed {self.version.status}" + ) + + try: + onnx_model = self.version_client.download_original( + self.model.id, self.version.version) + + return ort.InferenceSession(onnx_model) + + except Exception as e: + print(f"Could not download model: {e}") + return None - def _download_model(self, model_id: int, version_id: int, output_path: str): + def _download_model(self, output_path: str): """ - Downloads the model specified by model_id and version_id to the given output_path. + Downloads the model specified by model id and version id to the given output_path. Args: - model_id (int): The unique identifier of the model. - version_id (int): The version number of the model. output_path (str): The file path where the downloaded model should be saved. Raises: @@ -127,18 +161,20 @@ def _download_model(self, model_id: int, version_id: int, output_path: str): f"Model version status is not completed {self.version.status}" ) + onnx_model = self.version_client.download_original( + self.model.id, self.version.version) + print("ONNX model is ready, downloading! ✅") - onnx_model = self.api_client.download_original(model_id, self.version.version) - model_name = self.version.original_model_path.split("/")[-1] - save_path = Path(output_path) / model_name + if ".onnx" in output_path: + save_path = Path(output_path) + else: + save_path = Path(f"{output_path}/{self.model.name}.onnx") with open(save_path, "wb") as f: f.write(onnx_model) - print(f"ONNX model saved at: {save_path}") - self.session = ort.InferenceSession(save_path) - print("Model ready for inference with ONNX Runtime! ✅") + print(f"ONNX model saved at: {save_path} ✅") def _get_credentials(self): """ @@ -153,7 +189,7 @@ def predict( input_feed: Optional[Dict] = None, verifiable: bool = False, fp_impl="FP16x16", - output_dtype: str = "tensor_fixed_point", + custom_output_dtype: Optional[str] = None, job_size: str = "M", ): """ @@ -165,7 +201,7 @@ def predict( input_feed (Optional[Dict]): A dictionary containing the input data for prediction. Defaults to None. verifiable (bool): A flag indicating whether to use the verifiable computation endpoint. Defaults to False. fp_impl (str): The fixed point implementation to use, when computed in verifiable mode. Defaults to "FP16x16". - output_dtype (str): The data type of the result when computed in verifiable mode. 
Defaults to "tensor_fixed_point". + custom_output_dtype (Optional[str]): Specify the data type of the result when computed in verifiable mode. Defaults to None. Returns: A tuple (predictions, request_id) where predictions is the result of the prediction and request_id @@ -209,7 +245,13 @@ def predict( if self.framework == Framework.CAIRO: logging.info("Serialized: ", serialized_output) - preds = self._parse_cairo_response(serialized_output, output_dtype) + if custom_output_dtype is None: + output_dtype = self._get_output_dtype() + else: + output_dtype = custom_output_dtype + + preds = self._parse_cairo_response( + serialized_output, output_dtype) elif self.framework == Framework.EZKL: preds = np.array(serialized_output[0]) return (preds, request_id) @@ -318,3 +360,40 @@ def _parse_cairo_response(self, response, data_type: str): The deserialized prediction result. """ return deserialize(response, data_type) + + def _get_output_dtype(self): + """ + Retrieve the Cairo output data type base on the operator type of the final node. + + Returns: + The output dtype as a string. + """ + + file = self.version_client.download_original( + self.model.id, self.version.version + ) + + model = onnx.load_model_from_string(file) + graph = model.graph + output_tensor_name = graph.output[0].name + + def find_producing_node(graph, tensor_name): + for node in graph.node: + if tensor_name in node.output: + return node + return None + + final_node = find_producing_node(graph, output_tensor_name) + optype = final_node.op_type + + match optype: + case "TreeEnsembleClassifier": + return "(Span, MutMatrix)" + + case "TreeEnsembleRegressor": + return "MutMatrix::" + + case "LinearClassifier": + return "(Span, Tensor)" + case _: + return "Tensor"