Skip to content

Commit

Permalink
Linting, imports and dependencies fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Gonmeso committed Jan 25, 2024
1 parent 3b5f08b commit 6ab3968
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 52 deletions.
4 changes: 2 additions & 2 deletions giza_actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from prefect import Flow # noqa: E402
from prefect import flow as _flow # noqa: E402
from prefect.settings import ( # noqa: E402
PREFECT_API_URL,
from prefect.settings import PREFECT_API_URL # noqa: E402
from prefect.settings import (
PREFECT_LOGGING_SETTINGS_PATH,
PREFECT_UI_URL,
update_current_profile,
Expand Down
47 changes: 25 additions & 22 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from functools import wraps
import json
from functools import wraps
from pathlib import Path
from typing import Callable, Optional, Dict
import numpy as np
from typing import Callable, Dict, Optional

import requests
import numpy as np
import onnxruntime as ort
import requests
from giza import API_HOST
from giza.client import ApiClient, ModelsClient, VersionsClient
from giza.utils.enums import VersionStatus
from osiris.app import serialize, deserialize, serializer, create_tensor_from_array
from osiris.app import create_tensor_from_array, deserialize, serialize, serializer


class GizaModel:
Expand All @@ -22,15 +22,13 @@ def __init__(
orion_runner_service_url: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path is None and (id is None or version is None):
raise ValueError("Both id and version must be provided.")

if model_path and (id or version):
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

self.orion_runner_service_url = orion_runner_service_url

Expand All @@ -48,12 +46,10 @@ def _download_model(self, model_id: int, version_id: int, output_path: str):
version = self.version_client.get(model_id, version_id)

if version.status != VersionStatus.COMPLETED:
raise ValueError(
f"Model version status is not completed {version.status}")
raise ValueError(f"Model version status is not completed {version.status}")

print("ONNX model is ready, downloading! ✅")
onnx_model = self.api_client.download_original(
model_id, version.version)
onnx_model = self.api_client.download_original(model_id, version.version)

model_name = version.original_model_path.split("/")[-1]
save_path = Path(output_path) / model_name
Expand All @@ -69,26 +65,32 @@ def _get_credentials(self):
self.api_client.retrieve_token()
self.api_client.retrieve_api_key()

def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] = None, verifiable: bool = False, fp_impl='FP16x16', output_dtype: str = 'tensor_fixed_point'):

def predict(
self,
input_file: Optional[str] = None,
input_feed: Optional[Dict] = None,
verifiable: bool = False,
fp_impl="FP16x16",
output_dtype: str = "tensor_fixed_point",
):
if verifiable:
if not self.orion_runner_service_url:
raise ValueError("Orion Runner service URL must be provided")

endpoint = f"{self.orion_runner_service_url}/cairo_run"

cairo_payload = self._format_inputs_for_cairo(
input_file, input_feed, fp_impl)
input_file, input_feed, fp_impl
)

response = requests.post(endpoint, json=cairo_payload)

serialized_output = json.dumps(
response.json()['result'])
serialized_output = json.dumps(response.json()["result"])

if response.status_code == 200:

preds = self._parse_cairo_response(
serialized_output, output_dtype, fp_impl)
serialized_output, output_dtype, fp_impl
)
else:
raise Exception(f"OrionRunner service error: {response.text}")

Expand All @@ -100,11 +102,12 @@ def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] =
preds = self.session.run(None, input_feed)[0]
return preds

def _format_inputs_for_cairo(self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl):
def _format_inputs_for_cairo(
self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl
):
serialized = None

if input_file is not None:

print(input_file)

serialized = serialize(input_file, fp_impl)
Expand Down
43 changes: 16 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pyyaml = "^6.0.1"
prefect-docker = "^0.4.1"
distlib = "^0.3.8"
giza-cli = "^0.7.0"
giza-osiris = "0.2.1"
giza-osiris = "^0.2.1"

[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
Expand All @@ -37,3 +37,6 @@ pre-commit = "^3.5.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"

0 comments on commit 6ab3968

Please sign in to comment.