Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented a cache for already downloaded models #39

Merged
merged 7 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ examples/on-chain_mnist/cairo/lofi_mnst_2
examples/on-chain_mnist/cairo/soft
examples/on-chain_mnist/cairo/mnist_sierra
examples/on-chain_mnist/contracts/out

# cache files
tmp
63 changes: 42 additions & 21 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging
import os
import tempfile
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import onnx
import onnxruntime as ort
import requests
from diskcache import Cache
from giza import API_HOST
from giza.client import ApiClient, EndpointsClient, ModelsClient, VersionsClient
from giza.utils.enums import Framework, VersionStatus
Expand Down Expand Up @@ -79,12 +82,19 @@ def __init__(
self._get_credentials()
self.model = self._get_model(id)
self.version = self._get_version(version)
self.session = self._set_session()
self.framework = self.version.framework
self.uri = self._retrieve_uri()
self.endpoint_id = self._get_endpoint_id()
if output_path:
self._download_model(output_path)
self._cache = Cache(os.path.join(os.getcwd(), "tmp", "cachedir"))
self.session = self._set_session()
if output_path is not None:
self._output_path = output_path
else:
self._output_path = os.path.join(
tempfile.gettempdir(),
f"{self.model_id}_{self.version_id}_{self.model.name}",
)
self._download_model()

def _get_endpoint_id(self):
"""
Expand Down Expand Up @@ -163,17 +173,20 @@ def _set_session(self):
)

try:
onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)
self._download_model()

if self._output_path in self._cache:
file_path = Path(self._cache.get(self._output_path))
with open(file_path, "rb") as f:
onnx_model = f.read()

return ort.InferenceSession(onnx_model)

except Exception as e:
logger.info(f"Could not download model: {e}")
return None

def _download_model(self, output_path: str):
def _download_model(self):
"""
Downloads the model specified by model id and version id to the given output_path.

Expand All @@ -189,21 +202,26 @@ def _download_model(self, output_path: str):
f"Model version status is not completed {self.version.status}"
)

onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)
if self._output_path not in self._cache:
onnx_model = self.version_client.download_original(
self.model.id, self.version.version
)

logger.info("ONNX model is ready, downloading! ✅")
logger.info("ONNX model is ready, downloading! ✅")

if ".onnx" in output_path:
save_path = Path(output_path)
else:
save_path = Path(f"{output_path}/{self.model.name}.onnx")
if ".onnx" in self._output_path:
save_path = Path(self._output_path)
else:
save_path = Path(f"{self._output_path}.onnx")

with open(save_path, "wb") as f:
f.write(onnx_model)
with open(save_path, "wb") as f:
f.write(onnx_model)

logger.info(f"ONNX model saved at: {save_path} ✅")
self._cache[self._output_path] = save_path

logger.info(f"ONNX model saved at: {save_path} ✅")
else:
logger.info(f"ONNX model already downloaded at: {self._output_path} ✅")

def _get_credentials(self):
"""
Expand Down Expand Up @@ -396,9 +414,12 @@ def _get_output_dtype(self):
The output dtype as a string.
"""

file = self.version_client.download_original(
self.model.id, self.version.version
)
self._download_model()

if self._output_path in self._cache:
file_path = Path(self._cache.get(self._output_path))
with open(file_path, "rb") as f:
file = f.read()

model = onnx.load_model_from_string(file)
graph = model.graph
Expand Down
17 changes: 14 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ giza-osiris = ">=0.2.6,<1.0.0"
loguru = "^0.7.2"
eth-ape = {version = "^0.7.10", optional = true }
ape-etherscan = {version = "^0.7.2", optional = true }
diskcache = "^5.6.3"

[tool.poetry.extras]
agents = ["eth-ape", "ape-etherscan"]
Expand Down
44 changes: 44 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def raise_for_status(self):
"giza_actions.model.GizaModel._parse_cairo_response",
return_value=np.array([[1, 2], [3, 4]], dtype=np.uint32),
)
@patch(
"giza_actions.model.VersionsClient.download_original", return_value=b"some bytes"
)
def test_predict_success(*args):
model = GizaModel(id=50, version=2)

Expand Down Expand Up @@ -86,6 +89,9 @@ def test_predict_success(*args):
"giza_actions.model.GizaModel._parse_cairo_response",
return_value=np.array([[1, 2], [3, 4]], dtype=np.uint32),
)
@patch(
"giza_actions.model.VersionsClient.download_original", return_value=b"some bytes"
)
def test_predict_success_with_file(*args):
model = GizaModel(id=50, version=2)

Expand All @@ -102,3 +108,41 @@ def test_predict_success_with_file(*args):

assert np.array_equal(result, expected)
assert req_id == "123"


@patch("giza_actions.model.GizaModel._get_credentials")
@patch("giza_actions.model.GizaModel._get_model", return_value=Model(id=50))
@patch(
"giza_actions.model.GizaModel._get_version",
return_value=Version(
version=2,
framework="CAIRO",
size=1,
status="COMPLETED",
created_date="2022-01-01T00:00:00Z",
last_update="2022-01-01T00:00:00Z",
),
)
@patch("giza_actions.model.GizaModel._set_session")
@patch("giza_actions.model.GizaModel._get_output_dtype")
@patch("giza_actions.model.GizaModel._retrieve_uri")
@patch("giza_actions.model.GizaModel._get_endpoint_id", return_value=1)
@patch(
"giza_actions.model.VersionsClient.download_original", return_value=b"some bytes"
)
def test_cache_implementation(*args):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is failing as some mocks are missing, the current error is due to missing credentials, for that just add:

@patch("giza_actions.model.GizaModel._get_credentials")
def test_cache_implementation(*args):

Also it is possible that some patches are missing as well, for example for _get_model and _get_version might be need.

model = GizaModel(id=50, version=2)

result1 = model._set_session()
cache_size_after_first_call = len(model._cache)
result2 = model._set_session()
cache_size_after_second_call = len(model._cache)
assert result1 == result2
assert cache_size_after_first_call == cache_size_after_second_call

result3 = model._get_output_dtype()
cache_size_after_third_call = len(model._cache)
result4 = model._get_output_dtype()
cache_size_after_fourth_call = len(model._cache)
assert result3 == result4
assert cache_size_after_third_call == cache_size_after_fourth_call
Loading