-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implemented a cache for already downloaded models #39
Changes from 1 commit
2843ce3
3feb5b7
4f81a5a
6e19f06
3bf5584
1738e47
d5c71c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import logging | ||
from pathlib import Path | ||
from typing import Dict, Optional | ||
from diskcache import Cache | ||
import os | ||
|
||
import numpy as np | ||
import onnx | ||
|
@@ -54,13 +56,15 @@ def __init__( | |
output_path: Optional[str] = None, | ||
): | ||
if model_path is None and id is None and version is None: | ||
raise ValueError("Either model_path or id and version must be provided.") | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you changing this formatting? This is handled by |
||
"Either model_path or id and version must be provided.") | ||
|
||
if model_path is None and (id is None or version is None): | ||
raise ValueError("Both id and version must be provided.") | ||
|
||
if model_path and (id or version): | ||
raise ValueError("Either model_path or id and version must be provided.") | ||
raise ValueError( | ||
"Either model_path or id and version must be provided.") | ||
|
||
if model_path and id and version: | ||
raise ValueError( | ||
|
@@ -85,6 +89,7 @@ def __init__( | |
self.endpoint_id = self._get_endpoint_id() | ||
if output_path: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This |
||
self._download_model(output_path) | ||
self.cache = Cache(os.getcwd() + '/tmp/cachedir') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would also make it "private" as this cache is not something that we want the user to use directly. self._cache = Cache(...) |
||
|
||
def _get_endpoint_id(self): | ||
""" | ||
|
@@ -163,9 +168,16 @@ def _set_session(self): | |
) | ||
|
||
try: | ||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove this and use |
||
if cache_str in self.cache: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this approach for the cache very much 🚀 |
||
file_path = self.cache.get(cache_str) | ||
file_path = Path(file_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that this is not needed, as file_path = Path(self.cache.get(cache_str)) |
||
with open(file_path, "rb") as f: | ||
onnx_model = f.read() | ||
else: | ||
onnx_model = self.version_client.download_original( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's change this to use |
||
self.model.id, self.version.version | ||
) | ||
|
||
return ort.InferenceSession(onnx_model) | ||
|
||
|
@@ -189,21 +201,28 @@ def _download_model(self, output_path: str): | |
f"Model version status is not completed {self.version.status}" | ||
) | ||
|
||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As in the previous comment, let's remove this and use self._output_path so cache keys are more consistent |
||
|
||
if cache_str not in self.cache: | ||
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
|
||
logger.info("ONNX model is ready, downloading! ✅") | ||
logger.info("ONNX model is ready, downloading! ✅") | ||
|
||
if ".onnx" in output_path: | ||
save_path = Path(output_path) | ||
else: | ||
save_path = Path(f"{output_path}/{self.model.name}.onnx") | ||
if ".onnx" in output_path: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the proposed changes this should be |
||
save_path = Path(output_path) | ||
else: | ||
save_path = Path(f"{output_path}/{self.model.name}.onnx") | ||
|
||
with open(save_path, "wb") as f: | ||
f.write(onnx_model) | ||
|
||
with open(save_path, "wb") as f: | ||
f.write(onnx_model) | ||
self.cache[cache_str] = save_path | ||
|
||
logger.info(f"ONNX model saved at: {save_path} ✅") | ||
logger.info(f"ONNX model saved at: {save_path} ✅") | ||
else: | ||
logger.info(f"ONNX model already downloaded at: {output_path} ✅.") | ||
|
||
def _get_credentials(self): | ||
""" | ||
|
@@ -260,7 +279,8 @@ def predict( | |
response.raise_for_status() | ||
except requests.exceptions.HTTPError as e: | ||
logger.error(f"An error occurred in predict: {e}") | ||
error_message = f"Deployment predict error: {response.text}" | ||
error_message = f"Deployment predict error: { | ||
response.text}" | ||
logger.error(error_message) | ||
raise e | ||
|
||
|
@@ -277,7 +297,8 @@ def predict( | |
output_dtype = custom_output_dtype | ||
|
||
logger.debug("Output dtype: %s", output_dtype) | ||
preds = self._parse_cairo_response(serialized_output, output_dtype) | ||
preds = self._parse_cairo_response( | ||
serialized_output, output_dtype) | ||
elif self.framework == Framework.EZKL: | ||
preds = np.array(serialized_output[0]) | ||
return (preds, request_id) | ||
|
@@ -396,9 +417,16 @@ def _get_output_dtype(self): | |
The output dtype as a string. | ||
""" | ||
|
||
file = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the proposed changes this should be |
||
if cache_str in self.cache: | ||
file_path = self.cache.get(cache_str) | ||
file_path = Path(file_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as the previous comment, could be the string only or make a single line to create the path |
||
with open(file_path, "rb") as f: | ||
file = f.read() | ||
else: | ||
file = self.version_client.download_original( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should change this to use Lets use |
||
self.model.id, self.version.version | ||
) | ||
|
||
model = onnx.load_model_from_string(file) | ||
graph = model.graph | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ license = "MIT" | |
|
||
[tool.poetry.dependencies] | ||
python = ">=3.11,<4.0" | ||
diskcache == "5.6.3" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not directly pin this dependency; make sure to add it by just running poetry add diskcache There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it worked by downgrading the Python version to Python 3.11.9 |
||
numpy = "^1.26.2" | ||
prefect = "2.14.6" | ||
onnx = "^1.15.0" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
import os
shouldn't be placed here. This should have been handled by pre-commit; make sure that you have it installed and run
pre-commit run --files giza_actions/model.py