Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented a cache for already downloaded models #39

Merged
merged 7 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 23 additions & 31 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
import os
from pathlib import Path
from typing import Dict, Optional
from diskcache import Cache
import os

import numpy as np
import onnx
import onnxruntime as ort
import requests
from diskcache import Cache
from giza import API_HOST
from giza.client import ApiClient, EndpointsClient, ModelsClient, VersionsClient
from giza.utils.enums import Framework, VersionStatus
Expand Down Expand Up @@ -56,15 +56,13 @@ def __init__(
output_path: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path is None and (id is None or version is None):
raise ValueError("Both id and version must be provided.")

if model_path and (id or version):
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path and id and version:
raise ValueError(
Expand All @@ -83,13 +81,14 @@ def __init__(
self._get_credentials()
self.model = self._get_model(id)
self.version = self._get_version(version)
self.session = self._set_session()
self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if will make the prediction fail when verifiable=False if it is not provided, making it mandatory which is not what we are aiming for.

self.session = self._set_session(output_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output_path shouldn't be a mandatory argument for set_session

if output_path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having to pass output_path to _download_model and _get_output_dtype will mostly make output_path to be mandatory, but this is not the intention so another path where if output_path is not provided we handle it for the user:

  • Lets make the class to have an attribute self._output_path
  • If output_path is present when the user creates the instance, we will use that path, if not we will handle that
  • Remove output_path argument from _download_model and use self._output_path
  • Remove output_path from _set_session and _get_output_dtype as it is not needed any more

An example of the init could be:

    def __init__(
        self,
        model_path: Optional[str] = None,
        id: Optional[int] = None,
        version: Optional[int] = None,
        output_path: Optional[str] = None,
    ):
       ...

        if model_path:
            self.session = ort.InferenceSession(model_path)
        elif id and version:
            self.model_id = id
            self.version_id = version
            self.model_client = ModelsClient(API_HOST)
            self.version_client = VersionsClient(API_HOST)
            self.api_client = ApiClient(API_HOST)
            self.endpoints_client = EndpointsClient(API_HOST)
            self._get_credentials()
            self.model = self._get_model(id)
            self.version = self._get_version(version)
            self.session = self._set_session()
            self.framework = self.version.framework
            self.uri = self._retrieve_uri()
            self.endpoint_id = self._get_endpoint_id()
            if output_path is not None:
                self._output_path = output_path
            else:
                self._output_path = os.path.join(tempfile.gettempdir(), f"{self.model_id}_{self.version_id}_{self.model.name})
            # Now this internally uses self._output_path
            # As we are using the cache hitting this function should not be problematic
            self._download_model()

self._download_model(output_path)
self.cache = Cache(os.getcwd() + '/tmp/cachedir')
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cache should be initialized before _set_session and download_model as if the user does not provide output_path set session will try to hit self._cache which has not been initialized.

This is making the previous two existing tests to fail.


def _get_endpoint_id(self):
"""
Expand Down Expand Up @@ -154,7 +153,7 @@ def _get_version(self, version_id: int):
"""
return self.version_client.get(self.model.id, version_id)

def _set_session(self):
def _set_session(self, output_path: str):
"""
Set onnxruntime session for the model specified by model id.

Expand All @@ -169,15 +168,12 @@ def _set_session(self):

try:
cache_str = f"{self.model.id}_{self.version.version}_model"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this and use self._output_path so cache keys are more consistent

if cache_str in self.cache:
file_path = self.cache.get(cache_str)
file_path = Path(file_path)
self._download_model(output_path)

if cache_str in self._cache:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the proposed changes

          if self._output_path in self._cache:
                file_path = Path(self._cache.get(self._output_path))
                with open(file_path, "rb") as f:
                    onnx_model = f.read()

file_path = Path(self._cache.get(cache_str))
with open(file_path, "rb") as f:
onnx_model = f.read()
else:
onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)

return ort.InferenceSession(onnx_model)

Expand All @@ -203,7 +199,7 @@ def _download_model(self, output_path: str):

cache_str = f"{self.model.id}_{self.version.version}_model"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in the previous comment, let's remove this and use self._output_path so cache keys are more consistent


if cache_str not in self.cache:
if cache_str not in self._cache:
onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)
Expand All @@ -218,7 +214,7 @@ def _download_model(self, output_path: str):
with open(save_path, "wb") as f:
f.write(onnx_model)

self.cache[cache_str] = save_path
self._cache[cache_str] = save_path

logger.info(f"ONNX model saved at: {save_path} ✅")
else:
Expand All @@ -240,6 +236,7 @@ def predict(
custom_output_dtype: Optional[str] = None,
job_size: str = "M",
dry_run: bool = False,
output_path: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this as it is not needed

):
"""
Makes a prediction using either a local ONNX session or a remote deployed model, depending on the
Expand Down Expand Up @@ -279,8 +276,7 @@ def predict(
response.raise_for_status()
except requests.exceptions.HTTPError as e:
logger.error(f"An error occurred in predict: {e}")
error_message = f"Deployment predict error: {
response.text}"
error_message = f"Deployment predict error: {response.text}"
logger.error(error_message)
raise e

Expand All @@ -292,13 +288,12 @@ def predict(
logger.info("Serialized: %s", serialized_output)

if custom_output_dtype is None:
output_dtype = self._get_output_dtype()
output_dtype = self._get_output_dtype(output_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be as it was before

else:
output_dtype = custom_output_dtype

logger.debug("Output dtype: %s", output_dtype)
preds = self._parse_cairo_response(
serialized_output, output_dtype)
preds = self._parse_cairo_response(serialized_output, output_dtype)
elif self.framework == Framework.EZKL:
preds = np.array(serialized_output[0])
return (preds, request_id)
Expand Down Expand Up @@ -409,7 +404,7 @@ def _parse_cairo_response(self, response, data_type: str):
"""
return deserialize(response, data_type)

def _get_output_dtype(self):
def _get_output_dtype(self, output_path: str):
"""
Retrieve the Cairo output data type base on the operator type of the final node.

Expand All @@ -418,15 +413,12 @@ def _get_output_dtype(self):
"""

cache_str = f"{self.model.id}_{self.version.version}_model"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the proposed changes this should be self._output_path

if cache_str in self.cache:
file_path = self.cache.get(cache_str)
file_path = Path(file_path)
self._download_model(output_path)

if cache_str in self._cache:
file_path = Path(self._cache.get(cache_str))
with open(file_path, "rb") as f:
file = f.read()
else:
file = self.version_client.download_original(
self.model.id, self.version.version
)

model = onnx.load_model_from_string(file)
graph = model.graph
Expand Down
17 changes: 14 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ license = "MIT"

[tool.poetry.dependencies]
python = ">=3.11,<4.0"
diskcache == "5.6.3"
numpy = "^1.26.2"
prefect = "2.14.6"
onnx = "^1.15.0"
Expand All @@ -28,6 +27,7 @@ giza-osiris = ">=0.2.6,<1.0.0"
loguru = "^0.7.2"
eth-ape = {version = "^0.7.10", optional = true }
ape-etherscan = {version = "^0.7.2", optional = true }
diskcache = "^5.6.3"

[tool.poetry.extras]
agents = ["eth-ape", "ape-etherscan"]
Expand Down
Loading