-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implemented a cache for already downloaded models #39
Changes from 3 commits
2843ce3
3feb5b7
4f81a5a
6e19f06
3bf5584
1738e47
d5c71c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import Dict, Optional | ||
|
||
import numpy as np | ||
import onnx | ||
import onnxruntime as ort | ||
import requests | ||
from diskcache import Cache | ||
from giza import API_HOST | ||
from giza.client import ApiClient, EndpointsClient, ModelsClient, VersionsClient | ||
from giza.utils.enums import Framework, VersionStatus | ||
|
@@ -79,12 +81,14 @@ def __init__( | |
self._get_credentials() | ||
self.model = self._get_model(id) | ||
self.version = self._get_version(version) | ||
self.session = self._set_session() | ||
self.framework = self.version.framework | ||
self.uri = self._retrieve_uri() | ||
self.endpoint_id = self._get_endpoint_id() | ||
if output_path: | ||
self.session = self._set_session(output_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if output_path: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having to pass
An example of the init could be: def __init__(
self,
model_path: Optional[str] = None,
id: Optional[int] = None,
version: Optional[int] = None,
output_path: Optional[str] = None,
):
...
if model_path:
self.session = ort.InferenceSession(model_path)
elif id and version:
self.model_id = id
self.version_id = version
self.model_client = ModelsClient(API_HOST)
self.version_client = VersionsClient(API_HOST)
self.api_client = ApiClient(API_HOST)
self.endpoints_client = EndpointsClient(API_HOST)
self._get_credentials()
self.model = self._get_model(id)
self.version = self._get_version(version)
self.session = self._set_session()
self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path is not None:
self._output_path = output_path
else:
self._output_path = os.path.join(tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name})
# Now this internally uses self._output_path
# As we are using the cache hitting this function should not be problematic
self._download_model() |
||
self._download_model(output_path) | ||
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cache should be initialized before This is making the previous two existing tests to fail. |
||
|
||
def _get_endpoint_id(self): | ||
""" | ||
|
@@ -149,7 +153,7 @@ def _get_version(self, version_id: int): | |
""" | ||
return self.version_client.get(self.model.id, version_id) | ||
|
||
def _set_session(self): | ||
def _set_session(self, output_path: str): | ||
""" | ||
Set onnxruntime session for the model specified by model id. | ||
|
||
|
@@ -163,9 +167,13 @@ def _set_session(self): | |
) | ||
|
||
try: | ||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove this and use |
||
self._download_model(output_path) | ||
|
||
if cache_str in self._cache: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the proposed changes if self._output_path in self._cache:
file_path = Path(self._cache.get(self._output_path))
with open(file_path, "rb") as f:
onnx_model = f.read() |
||
file_path = Path(self._cache.get(cache_str)) | ||
with open(file_path, "rb") as f: | ||
onnx_model = f.read() | ||
|
||
return ort.InferenceSession(onnx_model) | ||
|
||
|
@@ -189,21 +197,28 @@ def _download_model(self, output_path: str): | |
f"Model version status is not completed {self.version.status}" | ||
) | ||
|
||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As in the previous comment, let's remove this and use self._output_path so cache keys are more consistent |
||
|
||
logger.info("ONNX model is ready, downloading! ✅") | ||
if cache_str not in self._cache: | ||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
|
||
if ".onnx" in output_path: | ||
save_path = Path(output_path) | ||
else: | ||
save_path = Path(f"{output_path}/{self.model.name}.onnx") | ||
logger.info("ONNX model is ready, downloading! ✅") | ||
|
||
if ".onnx" in output_path: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the proposed changes this should be |
||
save_path = Path(output_path) | ||
else: | ||
save_path = Path(f"{output_path}/{self.model.name}.onnx") | ||
|
||
with open(save_path, "wb") as f: | ||
f.write(onnx_model) | ||
|
||
with open(save_path, "wb") as f: | ||
f.write(onnx_model) | ||
self._cache[cache_str] = save_path | ||
|
||
logger.info(f"ONNX model saved at: {save_path} ✅") | ||
logger.info(f"ONNX model saved at: {save_path} ✅") | ||
else: | ||
logger.info(f"ONNX model already downloaded at: {output_path} ✅.") | ||
|
||
def _get_credentials(self): | ||
""" | ||
|
@@ -221,6 +236,7 @@ def predict( | |
custom_output_dtype: Optional[str] = None, | ||
job_size: str = "M", | ||
dry_run: bool = False, | ||
output_path: Optional[str] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this as it is not needed |
||
): | ||
""" | ||
Makes a prediction using either a local ONNX session or a remote deployed model, depending on the | ||
|
@@ -272,7 +288,7 @@ def predict( | |
logger.info("Serialized: %s", serialized_output) | ||
|
||
if custom_output_dtype is None: | ||
output_dtype = self._get_output_dtype() | ||
output_dtype = self._get_output_dtype(output_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be as it was before |
||
else: | ||
output_dtype = custom_output_dtype | ||
|
||
|
@@ -388,17 +404,21 @@ def _parse_cairo_response(self, response, data_type: str): | |
""" | ||
return deserialize(response, data_type) | ||
|
||
def _get_output_dtype(self): | ||
def _get_output_dtype(self, output_path: str): | ||
""" | ||
Retrieve the Cairo output data type base on the operator type of the final node. | ||
|
||
Returns: | ||
The output dtype as a string. | ||
""" | ||
|
||
file = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the proposed changes this should be |
||
self._download_model(output_path) | ||
|
||
if cache_str in self._cache: | ||
file_path = Path(self._cache.get(cache_str)) | ||
with open(file_path, "rb") as f: | ||
file = f.read() | ||
|
||
model = onnx.load_model_from_string(file) | ||
graph = model.graph | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
if
will make the prediction fail whenverifiable=False
if it is not provided, making it mandatory which is not what we are aiming for.