
implemented a cache for already downloaded models #39

Merged 7 commits on Apr 25, 2024

Conversation

@shivam6862 (Contributor) commented Apr 22, 2024

@raphaelDkhn
Implemented a cache for already downloaded models by using the diskcache package.
Closes #37

@Gonmeso (Contributor) left a comment

Nice usage of diskcache!

There are some things that we need to change, and we should also review the formatting changes that should have been handled by black prior to the commit.

Also, I think that we could move this logic into a class in giza_actions/utils.py and just initialise this class in the GizaModel.__init__() method.

With a specific class that handles the download and the cache, we can create tests for it without the complexity of GizaModel.
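A minimal sketch of what such a class could look like (the ModelCache name and its get/set methods are illustrative, not an existing giza_actions API; a plain dict stands in for diskcache.Cache so the snippet runs without extra dependencies):

```python
from pathlib import Path
from typing import Dict, Optional


class ModelCache:
    """Sketch of a download-cache helper that could live in giza_actions/utils.py.

    A dict stands in for diskcache.Cache here; the real class would wrap
    diskcache.Cache and persist entries to disk.
    """

    def __init__(self) -> None:
        self._cache: Dict[str, str] = {}

    @staticmethod
    def _key(model_id: int, version_id: int) -> str:
        # Same key scheme used in the PR: "<model>_<version>_model"
        return f"{model_id}_{version_id}_model"

    def get(self, model_id: int, version_id: int) -> Optional[Path]:
        path = self._cache.get(self._key(model_id, version_id))
        return Path(path) if path is not None else None

    def set(self, model_id: int, version_id: int, file_path: str) -> None:
        self._cache[self._key(model_id, version_id)] = str(file_path)


cache = ModelCache()
cache.set(7, 2, "/tmp/7_2_model.onnx")
print(cache.get(7, 2))
```

Testing a class like this only needs a temporary directory and a fake file path, with no GizaModel instance involved.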

@@ -1,6 +1,8 @@
import logging
from pathlib import Path
from typing import Dict, Optional
from diskcache import Cache
import os
Contributor

This import os shouldn't be placed here.

This should have been handled by pre-commit; make sure that you have it installed and run pre-commit run --files giza_actions/model.py

@@ -54,13 +56,15 @@ def __init__(
output_path: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError("Either model_path or id and version must be provided.")
raise ValueError(
Contributor

Why are you changing this formatting? This is handled by black, and every commit should trigger it through pre-commit.

@@ -85,6 +89,7 @@ def __init__(
self.endpoint_id = self._get_endpoint_id()
if output_path:
self._download_model(output_path)
self.cache = Cache(os.getcwd() + '/tmp/cachedir')
Contributor

Better to use os.path.join or pathlib, as the separator depends on the OS (/ on Unix, \ on Windows).
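For illustration, either approach yields a portable path (assuming the tmp/cachedir layout under the working directory, as in the PR):

```python
import os
from pathlib import Path

# os.path.join inserts the OS-specific separator for us
cache_dir_join = os.path.join(os.getcwd(), "tmp", "cachedir")

# pathlib builds the same path with the "/" operator
cache_dir_path = Path.cwd() / "tmp" / "cachedir"

print(cache_dir_join)
print(cache_dir_path)
```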

cache_str = f"{self.model.id}_{self.version.version}_model"
if cache_str in self.cache:
file_path = self.cache.get(cache_str)
file_path = Path(file_path)
Contributor

I feel that this is not needed, as open accepts strings as well; or you could just do it directly:

file_path = Path(self.cache.get(cache_str))

@@ -85,6 +89,7 @@ def __init__(
self.endpoint_id = self._get_endpoint_id()
if output_path:
self._download_model(output_path)
self.cache = Cache(os.getcwd() + '/tmp/cachedir')
Contributor

I would also make it "private" as this cache is not something that we want the user to use directly.

self._cache = Cache(...)

self.model.id, self.version.version
)
cache_str = f"{self.model.id}_{self.version.version}_model"
if cache_str in self.cache:
Contributor

I like this approach for the cache very much 🚀

cache_str = f"{self.model.id}_{self.version.version}_model"
if cache_str in self.cache:
file_path = self.cache.get(cache_str)
file_path = Path(file_path)
Contributor

Same as the previous comment: it could be the string only, or a single line to create the path.

pyproject.toml Outdated
@@ -13,6 +13,7 @@ license = "MIT"

[tool.poetry.dependencies]
python = ">=3.11,<4.0"
diskcache == "5.6.3"
Contributor

We should not directly pin this dependency; make sure to add it by just running

poetry add diskcache

@shivam6862 (Contributor Author) commented Apr 22, 2024

It worked after downgrading the Python version to 3.11.9.

with open(file_path, "rb") as f:
onnx_model = f.read()
else:
onnx_model = self.version_client.download_original(
Contributor

Let's change this to use _download_model and then read it, so we have the benefits of the cache.

with open(file_path, "rb") as f:
file = f.read()
else:
file = self.version_client.download_original(
Contributor

We should change this to use _download_model, which handles the usage of the cache. Currently, if this is called multiple times, the model is downloaded each time, as it is not reflected in the cache.

Let's use _download_model and then read the file.

@shivam6862 (Contributor Author) commented Apr 22, 2024

@Gonmeso @raphaelDkhn The requested changes have been made.

self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path:
Contributor

This if will make the prediction fail when verifiable=False if output_path is not provided, making it mandatory, which is not what we are aiming for.

self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path:
self.session = self._set_session(output_path)
Contributor

output_path shouldn't be a mandatory argument for set_session

self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path:
self.session = self._set_session(output_path)
if output_path:
Contributor

Having to pass output_path to _download_model and _get_output_dtype will effectively make output_path mandatory, but this is not the intention, so we need another path where, if output_path is not provided, we handle it for the user:

  • Let's make the class have an attribute self._output_path
  • If output_path is present when the user creates the instance, we will use that path; if not, we will handle it for them
  • Remove the output_path argument from _download_model and use self._output_path instead
  • Remove output_path from _set_session and _get_output_dtype as it is no longer needed

An example of the init could be:

    def __init__(
        self,
        model_path: Optional[str] = None,
        id: Optional[int] = None,
        version: Optional[int] = None,
        output_path: Optional[str] = None,
    ):
       ...

        if model_path:
            self.session = ort.InferenceSession(model_path)
        elif id and version:
            self.model_id = id
            self.version_id = version
            self.model_client = ModelsClient(API_HOST)
            self.version_client = VersionsClient(API_HOST)
            self.api_client = ApiClient(API_HOST)
            self.endpoints_client = EndpointsClient(API_HOST)
            self._get_credentials()
            self.model = self._get_model(id)
            self.version = self._get_version(version)
            self.session = self._set_session()
            self.framework = self.version.framework
            self.uri = self._retrieve_uri()
            self.endpoint_id = self._get_endpoint_id()
            if output_path is not None:
                self._output_path = output_path
            else:
                self._output_path = os.path.join(tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name}")
            # Now this internally uses self._output_path
            # As we are using the cache hitting this function should not be problematic
            self._download_model()

onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)
cache_str = f"{self.model.id}_{self.version.version}_model"
Contributor

Let's remove this and use self._output_path so cache keys are more consistent

cache_str = f"{self.model.id}_{self.version.version}_model"
self._download_model(output_path)

if cache_str in self._cache:
Contributor

With the proposed changes

          if self._output_path in self._cache:
                file_path = Path(self._cache.get(self._output_path))
                with open(file_path, "rb") as f:
                    onnx_model = f.read()

onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)
cache_str = f"{self.model.id}_{self.version.version}_model"
Contributor

As in the previous comment, let's remove this and use self._output_path so cache keys are more consistent

save_path = Path(f"{output_path}/{self.model.name}.onnx")
logger.info("ONNX model is ready, downloading! ✅")

if ".onnx" in output_path:
Contributor

With the proposed changes this should be self._output_path

@@ -221,6 +236,7 @@ def predict(
custom_output_dtype: Optional[str] = None,
job_size: str = "M",
dry_run: bool = False,
output_path: Optional[str] = None,
Contributor

Remove this as it is not needed

@@ -272,7 +288,7 @@ def predict(
logger.info("Serialized: %s", serialized_output)

if custom_output_dtype is None:
output_dtype = self._get_output_dtype()
output_dtype = self._get_output_dtype(output_path)
Contributor

This should be as it was before

file = self.version_client.download_original(
self.model.id, self.version.version
)
cache_str = f"{self.model.id}_{self.version.version}_model"
Contributor

With the proposed changes this should be self._output_path

@Gonmeso (Contributor) commented Apr 23, 2024

Testing is still missing; that is why I proposed creating a cache class, to test it easily.

@shivam6862 (Contributor Author) commented Apr 23, 2024

@Gonmeso @raphaelDkhn
Implemented the cache test, and the requested changes have been made.
I'm looking for some guidance.

@Gonmeso (Contributor) left a comment

Really close to being done!

We just need to fix the tests and this will be ready, as CI is failing. Make sure everything works beforehand by running pytest locally.

f"{self.model_id}_{self.version_id}_{self.model.name}",
)
self._download_model()
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
Contributor

The cache should be initialized before _set_session and _download_model because, if the user does not provide output_path, _set_session will try to hit self._cache, which has not been initialized.

This is making the two previously existing tests fail.

@patch("giza_actions.model.GizaModel._get_output_dtype")
@patch("giza_actions.model.GizaModel._retrieve_uri")
@patch("giza_actions.model.GizaModel._get_endpoint_id", return_value=1)
def test_cache_implementation(*args):
Contributor

The test is failing because some mocks are missing; the current error is due to missing credentials. For that, just add:

@patch("giza_actions.model.GizaModel._get_credentials")
def test_cache_implementation(*args):

Also, some other patches may be missing as well; for example, _get_model and _get_version might be needed.
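For illustration, the same patch-stacking pattern on a toy stand-in class (patch.object is used here so the sketch runs without giza_actions installed; the real test would patch the giza_actions.model paths shown above):

```python
from unittest.mock import MagicMock, patch


class GizaModel:
    """Toy stand-in: the real __init__ needs credentials and API access."""

    def __init__(self):
        self._get_credentials()
        self.model = self._get_model()

    def _get_credentials(self):
        raise RuntimeError("needs a real API key")

    def _get_model(self):
        raise RuntimeError("needs a real API call")


@patch.object(GizaModel, "_get_model", return_value=MagicMock())
@patch.object(GizaModel, "_get_credentials")
def test_cache_implementation(*args):
    # With both methods mocked, construction no longer needs credentials.
    model = GizaModel()
    assert model.model is not None


test_cache_implementation()
print("test passed")
```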

@shivam6862 (Contributor Author)

@Gonmeso Added the patches for the test and made the required changes to pass CI.
I'm looking for some guidance.

@shivam6862 (Contributor Author)

@Gonmeso The CI failed due to "Exception: Token expired".
Please guide me on how to deal with this situation.

@Gonmeso (Contributor) commented Apr 24, 2024

Hi!

One little note: making the changes and just waiting for CI to pass will make this process longer; it is encouraged to run the tests locally and push the changes once they pass.

The error is the following:

/home/runner/work/actions-sdk/actions-sdk/giza_actions/model.py:206: in _download_model
    onnx_model = self.version_client.download_original(
E   Exception: Token expired or not set. API Key not available. Log in again.
        self       = <giza_actions.model.GizaModel object at 0x7fcfd4865890>

Here we can see that _download_model is being executed, and the next call in the stack is self.version_client.download_original, which is trying to get the model from the API and fails to do so without credentials; that means we need to patch this.

With the recent changes, _download_model is executed every single time, as we now always derive an output path to use with the cache.

To solve this we could start by patching the self.version_client.download_original function. Checking this function, the dependency comes from the __init__ method:

            self.model_client = ModelsClient(API_HOST)
            self.version_client = VersionsClient(API_HOST) # < This one 
            self.api_client = ApiClient(API_HOST)
            self.endpoints_client = EndpointsClient(API_HOST)

In order to patch this, we need to patch the imported client from the script (docs here: https://docs.python.org/3/library/unittest.mock.html#where-to-patch):

@patch("giza_actions.model.VersionsClient", return_value=b"some bytes")

We add a return value of bytes because that is what self.version_client.download_original returns.

Hope this helps @shivam6862

@Gonmeso Gonmeso self-assigned this Apr 24, 2024
@shivam6862 (Contributor Author) commented Apr 24, 2024

[Screenshot: test run showing all tests passing]

@Gonmeso All test cases are passing now.
Please check whether these are the required changes.

@shivam6862 (Contributor Author)

@Gonmeso Done: ran pre-commit run --all-files on all files, which was previously done only for .gitignore.

@Gonmeso (Contributor) left a comment

Thanks for your input!

LGTM!!!!

Great work @shivam6862

@Gonmeso Gonmeso merged commit f5b84eb into gizatechxyz:main Apr 25, 2024
1 check passed
Successfully merging this pull request may close these issues.

Add a cache for already downloaded models