-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implemented a cache for already downloaded models #39
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty well usage of diskcache
!
There are some things that we need to change and also review the formatting changes that should have been handled prior to the commit by black
.
Also, I think that we could move this logic into a class in giza_actions/utils.py
and just initialise this class on the GizaModel.__init__()
method.
With a specific class that handles the download and the cache we can create tests for it without the complexity of GizaModel
giza_actions/model.py
Outdated
@@ -1,6 +1,8 @@ | |||
import logging | |||
from pathlib import Path | |||
from typing import Dict, Optional | |||
from diskcache import Cache | |||
import os |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import os
shouldn't be placed here.
This should have been handle by pre-commit, make sure that you have it installed and run pre-commit run --files giza_actions/model.py
giza_actions/model.py
Outdated
@@ -54,13 +56,15 @@ def __init__( | |||
output_path: Optional[str] = None, | |||
): | |||
if model_path is None and id is None and version is None: | |||
raise ValueError("Either model_path or id and version must be provided.") | |||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you changing this formatting? This is handled by black
and every commit, should trigger it through pre-commit
giza_actions/model.py
Outdated
@@ -85,6 +89,7 @@ def __init__( | |||
self.endpoint_id = self._get_endpoint_id() | |||
if output_path: | |||
self._download_model(output_path) | |||
self.cache = Cache(os.getcwd() + '/tmp/cachedir') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to use os.path.join
or pathlib
as the separator depends on the OS ( /
unix, \
for windows)
giza_actions/model.py
Outdated
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
if cache_str in self.cache: | ||
file_path = self.cache.get(cache_str) | ||
file_path = Path(file_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel that this is not needed, as open
accepts strings as well or you could just do it directly:
file_path = Path(self.cache.get(cache_str))
giza_actions/model.py
Outdated
@@ -85,6 +89,7 @@ def __init__( | |||
self.endpoint_id = self._get_endpoint_id() | |||
if output_path: | |||
self._download_model(output_path) | |||
self.cache = Cache(os.getcwd() + '/tmp/cachedir') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also make it "private" as this cache is not something that we want the user to use directly.
self._cache = Cache(...)
giza_actions/model.py
Outdated
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
if cache_str in self.cache: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this approach for the cache very much 🚀
giza_actions/model.py
Outdated
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
if cache_str in self.cache: | ||
file_path = self.cache.get(cache_str) | ||
file_path = Path(file_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as the previous comment, could be the string only or make a single line to create the path
pyproject.toml
Outdated
@@ -13,6 +13,7 @@ license = "MIT" | |||
|
|||
[tool.poetry.dependencies] | |||
python = ">=3.11,<4.0" | |||
diskcache == "5.6.3" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not directly pin this dependency, make sure to add it just running
poetry add diskcache
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it worked by degrading the Python version to Python 3.11.9
giza_actions/model.py
Outdated
with open(file_path, "rb") as f: | ||
onnx_model = f.read() | ||
else: | ||
onnx_model = self.version_client.download_original( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's change this to use _download_model
and the read it, so we have the benefits of the cache.
giza_actions/model.py
Outdated
with open(file_path, "rb") as f: | ||
file = f.read() | ||
else: | ||
file = self.version_client.download_original( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should change this to use _download_model
which handles the usage of the cache if it does not exist. Currently if this is used multiple times the model would be download each time as is not reflected on the cache.
Lets use _download_model
and the read.
@Gonmeso @raphaelDkhn requested changes have been made |
giza_actions/model.py
Outdated
self.framework = self.version.framework | ||
self.uri = self._retrieve_uri() | ||
self.endpoint_id = self._get_endpoint_id() | ||
if output_path: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This if
will make the prediction fail when verifiable=False
if it is not provided, making it mandatory which is not what we are aiming for.
giza_actions/model.py
Outdated
self.framework = self.version.framework | ||
self.uri = self._retrieve_uri() | ||
self.endpoint_id = self._get_endpoint_id() | ||
if output_path: | ||
self.session = self._set_session(output_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output_path
shouldn't be a mandatory argument for set_session
giza_actions/model.py
Outdated
self.framework = self.version.framework | ||
self.uri = self._retrieve_uri() | ||
self.endpoint_id = self._get_endpoint_id() | ||
if output_path: | ||
self.session = self._set_session(output_path) | ||
if output_path: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having to pass output_path
to _download_model
and _get_output_dtype
will mostly make output_path
to be mandatory, but this is not the intention so another path where if output_path
is not provided we handle it for the user:
- Lets make the class to have an attribute
self._output_path
- If
output_path
is present when the user creates the instance, we will use that path, if not we will handle that - Remove
output_path
argument from_download_model
and useself._output_path
- Remove
output_path
from_set_session
and_get_output_dtype
as it is not needed any more
An example of the init could be:
def __init__(
self,
model_path: Optional[str] = None,
id: Optional[int] = None,
version: Optional[int] = None,
output_path: Optional[str] = None,
):
...
if model_path:
self.session = ort.InferenceSession(model_path)
elif id and version:
self.model_id = id
self.version_id = version
self.model_client = ModelsClient(API_HOST)
self.version_client = VersionsClient(API_HOST)
self.api_client = ApiClient(API_HOST)
self.endpoints_client = EndpointsClient(API_HOST)
self._get_credentials()
self.model = self._get_model(id)
self.version = self._get_version(version)
self.session = self._set_session()
self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path is not None:
self._output_path = output_path
else:
self._output_path = os.path.join(tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name})
# Now this internally uses self._output_path
# As we are using the cache hitting this function should not be problematic
self._download_model()
giza_actions/model.py
Outdated
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove this and use self._output_path
so cache keys are more consistent
giza_actions/model.py
Outdated
cache_str = f"{self.model.id}_{self.version.version}_model" | ||
self._download_model(output_path) | ||
|
||
if cache_str in self._cache: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the proposed changes
if self._output_path in self._cache:
file_path = Path(self._cache.get(self._output_path))
with open(file_path, "rb") as f:
onnx_model = f.read()
giza_actions/model.py
Outdated
onnx_model = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As in the previous comment, let's remove this and use self._output_path so cache keys are more consistent
giza_actions/model.py
Outdated
save_path = Path(f"{output_path}/{self.model.name}.onnx") | ||
logger.info("ONNX model is ready, downloading! ✅") | ||
|
||
if ".onnx" in output_path: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the proposed changes this should be self._output_path
giza_actions/model.py
Outdated
@@ -221,6 +236,7 @@ def predict( | |||
custom_output_dtype: Optional[str] = None, | |||
job_size: str = "M", | |||
dry_run: bool = False, | |||
output_path: Optional[str] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this as it is not needed
giza_actions/model.py
Outdated
@@ -272,7 +288,7 @@ def predict( | |||
logger.info("Serialized: %s", serialized_output) | |||
|
|||
if custom_output_dtype is None: | |||
output_dtype = self._get_output_dtype() | |||
output_dtype = self._get_output_dtype(output_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be as it was before
giza_actions/model.py
Outdated
file = self.version_client.download_original( | ||
self.model.id, self.version.version | ||
) | ||
cache_str = f"{self.model.id}_{self.version.version}_model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the proposed changes this should be self._output_path
Testing is still missing, that its why I proposed creating a cache class to test it easily. |
@Gonmeso @raphaelDkhn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really close to be done!
We just need to fix the tests and this would be ready, as CI is failing. Try to make sure that everything is working previously by running pytest
giza_actions/model.py
Outdated
f"{self.model_id}_{self.version_id}_{self.model.name}", | ||
) | ||
self._download_model() | ||
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cache should be initialized before _set_session
and download_model
as if the user does not provide output_path
set session will try to hit self._cache
which has not been initialized.
This is making the previous two existing tests to fail.
@patch("giza_actions.model.GizaModel._get_output_dtype") | ||
@patch("giza_actions.model.GizaModel._retrieve_uri") | ||
@patch("giza_actions.model.GizaModel._get_endpoint_id", return_value=1) | ||
def test_cache_implementation(*args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test is failing as some mocks are missing, the current error is due to missing credentials, for that just add:
@patch("giza_actions.model.GizaModel._get_credentials")
def test_cache_implementation(*args):
Also it is possible that some patches are missing as well, for example for _get_model
and _get_version
might be need.
@Gonmeso Added Patch for test and made required changes to pass CI, |
@Gonmeso The CI failed due to "Exception: Token expired" |
Hi! One little note, making the changes and just waiting CI to pass will make this process longer, it is encouraged to run the tests locally and then push the changes when they pass locally. The error is the following: /home/runner/work/actions-sdk/actions-sdk/giza_actions/model.py:206: in _download_model
onnx_model = self.version_client.download_original(
E Exception: Token expired or not set. API Key not available. Log in again.
self = <giza_actions.model.GizaModel object at 0x7fcfd48[65](https://github.com/gizatechxyz/actions-sdk/actions/runs/8814392110/job/24195017336?pr=39#step:6:66)890> Here we can see that With the recent changes To solve this we could start by patching the self.model_client = ModelsClient(API_HOST)
self.version_client = VersionsClient(API_HOST) # < This one
self.api_client = ApiClient(API_HOST)
self.endpoints_client = EndpointsClient(API_HOST) In order to patch this, we need to patch the imported client from the script (docs here: https://docs.python.org/3/library/unittest.mock.html#where-to-patch): @patch("giza_actions.model.VersionsClient", return_value=b"some bytes") We add a returned value of bytes because it is what Hope this helps @shivam6862 |
@Gonmeso Added all test case are passed |
@Gonmeso done for all file run pre-commit run --all-files which was done previously for .gitignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@raphaelDkhn
Implemented a cache for already downloaded models by using the diskcache package.
Closes #37