
Commit
fixed errors
shivam6862 committed Apr 22, 2024
1 parent 2843ce3 commit 3feb5b7
Showing 2 changed files with 24 additions and 32 deletions.
54 changes: 23 additions & 31 deletions giza_actions/model.py
@@ -1,13 +1,13 @@
 import logging
+import os
 from pathlib import Path
 from typing import Dict, Optional
-from diskcache import Cache
-import os
 
 import numpy as np
 import onnx
 import onnxruntime as ort
 import requests
+from diskcache import Cache
 from giza import API_HOST
 from giza.client import ApiClient, EndpointsClient, ModelsClient, VersionsClient
 from giza.utils.enums import Framework, VersionStatus
@@ -56,15 +56,13 @@ def __init__(
         output_path: Optional[str] = None,
     ):
         if model_path is None and id is None and version is None:
-            raise ValueError(
-                "Either model_path or id and version must be provided.")
+            raise ValueError("Either model_path or id and version must be provided.")
 
         if model_path is None and (id is None or version is None):
             raise ValueError("Both id and version must be provided.")
 
         if model_path and (id or version):
-            raise ValueError(
-                "Either model_path or id and version must be provided.")
+            raise ValueError("Either model_path or id and version must be provided.")
 
         if model_path and id and version:
             raise ValueError(
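
For orientation, a minimal usage sketch of the argument contract enforced above (the class name GizaModel and all values are illustrative assumptions; the class definition itself is not shown in this diff):

    # Sketch under assumed names; GizaModel is taken to be the class
    # defined in giza_actions/model.py.
    from giza_actions.model import GizaModel

    remote = GizaModel(id=1, version=1)         # ok: id and version together
    local = GizaModel(model_path="model.onnx")  # ok: a local ONNX file instead
    # GizaModel(id=1)                           # ValueError: version missing
    # GizaModel(model_path="model.onnx", id=1)  # ValueError: mutually exclusive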
@@ -83,13 +81,14 @@ def __init__(
             self._get_credentials()
         self.model = self._get_model(id)
         self.version = self._get_version(version)
-        self.session = self._set_session()
         self.framework = self.version.framework
         self.uri = self._retrieve_uri()
         self.endpoint_id = self._get_endpoint_id()
-        if output_path:
+        self.session = self._set_session(output_path)
+        if output_path:
             self._download_model(output_path)
-        self.cache = Cache(os.getcwd() + '/tmp/cachedir')
+        self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
 
     def _get_endpoint_id(self):
         """
@@ -154,7 +153,7 @@ def _get_version(self, version_id: int):
"""
return self.version_client.get(self.model.id, version_id)

def _set_session(self):
def _set_session(self, output_path: str):
"""
Set onnxruntime session for the model specified by model id.
@@ -169,15 +168,12 @@ def _set_session(self):

         try:
             cache_str = f"{self.model.id}_{self.version.version}_model"
-            if cache_str in self.cache:
-                file_path = self.cache.get(cache_str)
-                file_path = Path(file_path)
+            self._download_model(output_path)
+
+            if cache_str in self._cache:
+                file_path = Path(self._cache.get(cache_str))
                 with open(file_path, "rb") as f:
                     onnx_model = f.read()
-            else:
-                onnx_model = self.version_client.download_original(
-                    self.model.id, self.version.version
-                )
 
             return ort.InferenceSession(onnx_model)

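
ort.InferenceSession accepts either a model path or the serialized model bytes, which is what the cached read above passes in. A sketch of the load-and-run flow (file name, input name, and shape are assumptions):

    import numpy as np
    import onnxruntime as ort

    with open("tmp/model/model.onnx", "rb") as f:
        onnx_bytes = f.read()

    session = ort.InferenceSession(onnx_bytes)
    input_name = session.get_inputs()[0].name     # first graph input
    x = np.zeros((1, 3), dtype=np.float32)        # shape must match the model
    outputs = session.run(None, {input_name: x})  # None selects all outputs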
@@ -203,7 +199,7 @@ def _download_model(self, output_path: str):

         cache_str = f"{self.model.id}_{self.version.version}_model"
 
-        if cache_str not in self.cache:
+        if cache_str not in self._cache:
             onnx_model = self.version_client.download_original(
                 self.model.id, self.version.version
             )
@@ -218,7 +214,7 @@ def _download_model(self, output_path: str):
             with open(save_path, "wb") as f:
                 f.write(onnx_model)
 
-            self.cache[cache_str] = save_path
+            self._cache[cache_str] = save_path
 
             logger.info(f"ONNX model saved at: {save_path} ✅")
         else:
@@ -240,6 +236,7 @@ def predict(
         custom_output_dtype: Optional[str] = None,
         job_size: str = "M",
         dry_run: bool = False,
+        output_path: Optional[str] = None,
     ):
         """
         Makes a prediction using either a local ONNX session or a remote deployed model, depending on the
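
A hypothetical call exercising the new output_path parameter (input_feed and verifiable are assumed keyword names, since the rest of the signature is collapsed in this view):

    import numpy as np

    # Reuses the `remote` instance from the constructor sketch above;
    # all argument values are placeholders.
    preds, request_id = remote.predict(
        input_feed={"input": np.zeros((1, 3), dtype=np.float32)},
        verifiable=True,
        job_size="M",
        output_path="tmp/model",  # forwarded to _get_output_dtype below
    )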
@@ -279,8 +276,7 @@ def predict(
                 response.raise_for_status()
             except requests.exceptions.HTTPError as e:
                 logger.error(f"An error occurred in predict: {e}")
-                error_message = f"Deployment predict error: {
-                    response.text}"
+                error_message = f"Deployment predict error: {response.text}"
                 logger.error(error_message)
                 raise e

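A note on the change above: a line break inside an f-string replacement field is only valid syntax from Python 3.12 onward (PEP 701). Since pyproject.toml pins `python = ">=3.11,<4.0"`, the wrapped form was a SyntaxError on Python 3.11, so joining it onto one line is a genuine fix rather than a style tweak.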
@@ -292,13 +288,12 @@ def predict(
                     logger.info("Serialized: %s", serialized_output)
 
                     if custom_output_dtype is None:
-                        output_dtype = self._get_output_dtype()
+                        output_dtype = self._get_output_dtype(output_path)
                     else:
                         output_dtype = custom_output_dtype
 
                     logger.debug("Output dtype: %s", output_dtype)
-                    preds = self._parse_cairo_response(
-                        serialized_output, output_dtype)
+                    preds = self._parse_cairo_response(serialized_output, output_dtype)
                 elif self.framework == Framework.EZKL:
                     preds = np.array(serialized_output[0])
                 return (preds, request_id)
@@ -409,7 +404,7 @@ def _parse_cairo_response(self, response, data_type: str):
"""
return deserialize(response, data_type)

def _get_output_dtype(self):
def _get_output_dtype(self, output_path: str):
"""
Retrieve the Cairo output data type base on the operator type of the final node.
@@ -418,15 +413,12 @@ def _get_output_dtype(self):
"""

cache_str = f"{self.model.id}_{self.version.version}_model"
if cache_str in self.cache:
file_path = self.cache.get(cache_str)
file_path = Path(file_path)
self._download_model(output_path)

if cache_str in self._cache:
file_path = Path(self._cache.get(cache_str))
with open(file_path, "rb") as f:
file = f.read()
else:
file = self.version_client.download_original(
self.model.id, self.version.version
)

model = onnx.load_model_from_string(file)
graph = model.graph
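
_get_output_dtype then inspects the operator type of the final graph node; a sketch of that inspection with the onnx API (the file path is an illustrative assumption):

    import onnx

    # Load serialized ONNX bytes, as the cached read above produces them.
    with open("tmp/model/model.onnx", "rb") as f:
        model = onnx.load_model_from_string(f.read())

    graph = model.graph
    final_op = graph.node[-1].op_type  # e.g. "Gemm" or "Softmax"; drives the dtype choice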
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.11,<4.0"
-diskcache == "5.6.3"
+diskcache = "^5.6.3"
 numpy = "^1.26.2"
 prefect = "2.14.6"
 onnx = "^1.15.0"
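
The removed line was invalid TOML: `==` is not an assignment operator in TOML, so `diskcache == "5.6.3"` fails to parse at all. The replacement uses Poetry's caret constraint, where `^5.6.3` allows any release at or above 5.6.3 but below 6.0.0.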
