Skip to content

Commit

Permalink
remove unnecessary model_id param
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Mar 19, 2024
1 parent 4a4b156 commit 47a30aa
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,25 @@ def __init__(
self.api_client = ApiClient(API_HOST)
self._get_credentials()
self.model = self._get_model(id)
self.version = self._get_version(id, version)
self.session = self._set_session(id)
self.version = self._get_version(version)
self.session = self._set_session()
self.framework = self.version.framework
self.uri = self._retrieve_uri(id, version)
self.uri = self._retrieve_uri(version)
if output_path:
self._download_model(id, output_path)
self._download_model(output_path)

def _retrieve_uri(self, model_id: int, version_id: int):
def _retrieve_uri(self, version_id: int):
"""
Retrieves the URI for making prediction requests to a deployed model.
Args:
model_id (int): The unique identifier of the model.
version_id (int): The version number of the model.
Returns:
The URI for making prediction requests to the deployed model.
"""
# Different URI per framework
uri = get_endpoint_uri(model_id, version_id)
uri = get_endpoint_uri(self.model.id, version_id)
if self.framework == Framework.CAIRO:
return f"{uri}/cairo_run"
else:
Expand All @@ -111,25 +110,21 @@ def _get_model(self, model_id: int):
"""
return self.model_client.get(model_id)

def _get_version(self, model_id: int, version_id: int):
def _get_version(self, version_id: int):
"""
Retrieves the version of the model specified by model_id and version_id.
Retrieves the version of the model specified by model id and version id.
Args:
model_id (int): The unique identifier of the model.
version_id (int): The version number of the model.
Returns:
The version of the model.
"""
return self.version_client.get(model_id, version_id)
return self.version_client.get(self.model.id, version_id)

def _set_session(self, model_id: int):
def _set_session(self):
"""
Set onnxruntime session for the model specified by model_id.
Args:
model_id (int): The unique identifier of the model.
Set onnxruntime session for the model specified by model id.
Raises:
ValueError: If the model version status is not completed.
Expand All @@ -142,20 +137,19 @@ def _set_session(self, model_id: int):

try:
onnx_model = self.version_client.download_original(
model_id, self.version.version)
self.model.id, self.version.version)

return ort.InferenceSession(onnx_model)

except Exception as e:
print(f"Could not download model: {e}")
return None

def _download_model(self, model_id: int, output_path: str):
def _download_model(self, output_path: str):
"""
Downloads the model specified by model_id and version_id to the given output_path.
Downloads the model specified by model id and version id to the given output_path.
Args:
model_id (int): The unique identifier of the model.
output_path (str): The file path where the downloaded model should be saved.
Raises:
Expand All @@ -168,7 +162,7 @@ def _download_model(self, model_id: int, output_path: str):
)

onnx_model = self.version_client.download_original(
model_id, self.version.version)
self.model.id, self.version.version)

print("ONNX model is ready, downloading! ✅")

Expand Down

0 comments on commit 47a30aa

Please sign in to comment.