From 3feb5b77e26dcedb69438bc58f6b670fc7b21283 Mon Sep 17 00:00:00 2001
From: shivam6862
Date: Mon, 22 Apr 2024 23:45:19 +0530
Subject: [PATCH] Fix model caching, session setup, and the diskcache pin

---
 giza_actions/model.py | 54 +++++++++++++++++++++++-------------------------------
 pyproject.toml        |  2 +-
 2 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/giza_actions/model.py b/giza_actions/model.py
index d0cb9d1..a8f0ceb 100644
--- a/giza_actions/model.py
+++ b/giza_actions/model.py
@@ -1,13 +1,13 @@
 import logging
+import os
 from pathlib import Path
 from typing import Dict, Optional
-from diskcache import Cache
-import os
 
 import numpy as np
 import onnx
 import onnxruntime as ort
 import requests
+from diskcache import Cache
 from giza import API_HOST
 from giza.client import ApiClient, EndpointsClient, ModelsClient, VersionsClient
 from giza.utils.enums import Framework, VersionStatus
@@ -56,15 +56,13 @@ def __init__(
         output_path: Optional[str] = None,
     ):
         if model_path is None and id is None and version is None:
-            raise ValueError(
-                "Either model_path or id and version must be provided.")
+            raise ValueError("Either model_path or id and version must be provided.")
 
         if model_path is None and (id is None or version is None):
             raise ValueError("Both id and version must be provided.")
 
         if model_path and (id or version):
-            raise ValueError(
-                "Either model_path or id and version must be provided.")
+            raise ValueError("Either model_path or id and version must be provided.")
 
         if model_path and id and version:
             raise ValueError(
@@ -83,13 +81,14 @@ def __init__(
         self._get_credentials()
         self.model = self._get_model(id)
         self.version = self._get_version(version)
-        self.session = self._set_session()
         self.framework = self.version.framework
         self.uri = self._retrieve_uri()
         self.endpoint_id = self._get_endpoint_id()
+        # The cache must exist before _set_session/_download_model touch it.
+        self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
+        self.session = self._set_session(output_path) if output_path else None
         if output_path:
             self._download_model(output_path)
-        self.cache = Cache(os.getcwd() + '/tmp/cachedir')
 
     def _get_endpoint_id(self):
         """
@@ -154,7 +153,7 @@ def _get_version(self, version_id: int):
         """
         return self.version_client.get(self.model.id, version_id)
 
-    def _set_session(self):
+    def _set_session(self, output_path: str):
         """
         Set onnxruntime session for the model specified by model id.
 
@@ -169,15 +168,12 @@ def _set_session(self):
         try:
             cache_str = f"{self.model.id}_{self.version.version}_model"
 
-            if cache_str in self.cache:
-                file_path = self.cache.get(cache_str)
-                file_path = Path(file_path)
+            self._download_model(output_path)
+
+            if cache_str in self._cache:  # always true after _download_model
+                file_path = Path(self._cache.get(cache_str))
                 with open(file_path, "rb") as f:
                     onnx_model = f.read()
-            else:
-                onnx_model = self.version_client.download_original(
-                    self.model.id, self.version.version
-                )
 
             return ort.InferenceSession(onnx_model)
 
@@ -203,7 +199,7 @@ def _download_model(self, output_path: str):
 
         cache_str = f"{self.model.id}_{self.version.version}_model"
 
-        if cache_str not in self.cache:
+        if cache_str not in self._cache:
             onnx_model = self.version_client.download_original(
                 self.model.id, self.version.version
             )
@@ -218,7 +214,7 @@ def _download_model(self, output_path: str):
             with open(save_path, "wb") as f:
                 f.write(onnx_model)
 
-            self.cache[cache_str] = save_path
+            self._cache[cache_str] = save_path
 
             logger.info(f"ONNX model saved at: {save_path} ✅")
         else:
@@ -240,6 +236,7 @@ def predict(
         custom_output_dtype: Optional[str] = None,
         job_size: str = "M",
         dry_run: bool = False,
+        output_path: Optional[str] = None,
     ):
         """
         Makes a prediction using either a local ONNX session or a remote deployed model, depending on the
@@ -279,8 +276,7 @@ def predict(
             response.raise_for_status()
         except requests.exceptions.HTTPError as e:
             logger.error(f"An error occurred in predict: {e}")
-            error_message = f"Deployment predict error: {
-                response.text}"
+            error_message = f"Deployment predict error: {response.text}"
             logger.error(error_message)
             raise e
 
@@ -292,13 +288,12 @@ def predict(
             logger.info("Serialized: %s", serialized_output)
 
             if custom_output_dtype is None:
-                output_dtype = self._get_output_dtype()
+                output_dtype = self._get_output_dtype(output_path)
             else:
                 output_dtype = custom_output_dtype
             logger.debug("Output dtype: %s", output_dtype)
 
-            preds = self._parse_cairo_response(
-                serialized_output, output_dtype)
+            preds = self._parse_cairo_response(serialized_output, output_dtype)
         elif self.framework == Framework.EZKL:
             preds = np.array(serialized_output[0])
         return (preds, request_id)
@@ -409,7 +404,7 @@ def _parse_cairo_response(self, response, data_type: str):
         """
         return deserialize(response, data_type)
 
-    def _get_output_dtype(self):
+    def _get_output_dtype(self, output_path: Optional[str] = None):
         """
         Retrieve the Cairo output data type base on the operator type of the final node.
 
@@ -418,15 +413,12 @@ def _get_output_dtype(self):
         """
         cache_str = f"{self.model.id}_{self.version.version}_model"
 
-        if cache_str in self.cache:
-            file_path = self.cache.get(cache_str)
-            file_path = Path(file_path)
+        if output_path:
+            self._download_model(output_path)
+        if cache_str in self._cache:
+            file_path = Path(self._cache.get(cache_str))
             with open(file_path, "rb") as f:
                 file = f.read()
-        else:
-            file = self.version_client.download_original(
-                self.model.id, self.version.version
-            )
 
         model = onnx.load_model_from_string(file)
         graph = model.graph
diff --git a/pyproject.toml b/pyproject.toml
index 31a1435..dd710eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ license = "MIT"
 
 [tool.poetry.dependencies]
 python = ">=3.11,<4.0"
-diskcache == "5.6.3"
+diskcache = "^5.6.3"
 numpy = "^1.26.2"
 prefect = "2.14.6"
 onnx = "^1.15.0"
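
---
Reviewer note (not part of the patch to apply): a minimal usage sketch of
the GizaModel flow as changed above. The model id, version, and file name
are hypothetical placeholders, and input_feed/verifiable are assumed to be
the pre-existing predict arguments; this is a sketch, not a reference.

    import numpy as np
    from giza_actions.model import GizaModel

    # Passing output_path at construction downloads the ONNX file once,
    # records it in the diskcache under ./tmp/cachedir (keyed as
    # "<model_id>_<version>_model"), and builds a local onnxruntime session.
    model = GizaModel(id=42, version=1, output_path="model.onnx")

    # Verifiable predictions reuse the same cache entry: passing output_path
    # here lets _get_output_dtype read the cached file instead of hitting
    # the API again.
    preds, request_id = model.predict(
        input_feed={"input": np.random.rand(1, 3).astype(np.float32)},
        verifiable=True,
        output_path="model.onnx",
    )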