Skip to content

Commit

Permalink
Merge pull request #8 from gizatechxyz/feature/minor-fixes
Browse files Browse the repository at this point in the history
Linting, imports and dependencies fixes
  • Loading branch information
Gonmeso authored Jan 25, 2024
2 parents 3b5f08b + 1cbe156 commit 9063356
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 69 deletions.
2 changes: 1 addition & 1 deletion giza_actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from prefect import Flow # noqa: E402
from prefect import flow as _flow # noqa: E402
from prefect.settings import PREFECT_API_URL # noqa: E402
from prefect.settings import ( # noqa: E402
PREFECT_API_URL,
PREFECT_LOGGING_SETTINGS_PATH,
PREFECT_UI_URL,
update_current_profile,
Expand Down
54 changes: 24 additions & 30 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from functools import wraps
import json
from pathlib import Path
from typing import Callable, Optional, Dict
import numpy as np
from typing import Dict, Optional

import requests
import numpy as np
import onnxruntime as ort
import requests
from giza import API_HOST
from giza.client import ApiClient, ModelsClient, VersionsClient
from giza.utils.enums import VersionStatus
from osiris.app import serialize, deserialize, serializer, create_tensor_from_array
from osiris.app import create_tensor_from_array, deserialize, serialize, serializer


class GizaModel:
Expand All @@ -22,15 +21,13 @@ def __init__(
orion_runner_service_url: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path is None and (id is None or version is None):
raise ValueError("Both id and version must be provided.")

if model_path and (id or version):
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

self.orion_runner_service_url = orion_runner_service_url

Expand All @@ -48,12 +45,10 @@ def _download_model(self, model_id: int, version_id: int, output_path: str):
version = self.version_client.get(model_id, version_id)

if version.status != VersionStatus.COMPLETED:
raise ValueError(
f"Model version status is not completed {version.status}")
raise ValueError(f"Model version status is not completed {version.status}")

print("ONNX model is ready, downloading! ✅")
onnx_model = self.api_client.download_original(
model_id, version.version)
onnx_model = self.api_client.download_original(model_id, version.version)

model_name = version.original_model_path.split("/")[-1]
save_path = Path(output_path) / model_name
Expand All @@ -69,26 +64,32 @@ def _get_credentials(self):
self.api_client.retrieve_token()
self.api_client.retrieve_api_key()

def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] = None, verifiable: bool = False, fp_impl='FP16x16', output_dtype: str = 'tensor_fixed_point'):

def predict(
self,
input_file: Optional[str] = None,
input_feed: Optional[Dict] = None,
verifiable: bool = False,
fp_impl="FP16x16",
output_dtype: str = "tensor_fixed_point",
):
if verifiable:
if not self.orion_runner_service_url:
raise ValueError("Orion Runner service URL must be provided")

endpoint = f"{self.orion_runner_service_url}/cairo_run"

cairo_payload = self._format_inputs_for_cairo(
input_file, input_feed, fp_impl)
input_file, input_feed, fp_impl
)

response = requests.post(endpoint, json=cairo_payload)

serialized_output = json.dumps(
response.json()['result'])
serialized_output = json.dumps(response.json()["result"])

if response.status_code == 200:

preds = self._parse_cairo_response(
serialized_output, output_dtype, fp_impl)
serialized_output, output_dtype, fp_impl
)
else:
raise Exception(f"OrionRunner service error: {response.text}")

Expand All @@ -100,11 +101,12 @@ def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] =
preds = self.session.run(None, input_feed)[0]
return preds

def _format_inputs_for_cairo(self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl):
def _format_inputs_for_cairo(
self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl
):
serialized = None

if input_file is not None:

print(input_file)

serialized = serialize(input_file, fp_impl)
Expand All @@ -122,11 +124,3 @@ def _format_inputs_for_cairo(self, input_file: Optional[str], input_feed: Option

def _parse_cairo_response(self, response, data_type: str, fp_impl):
return deserialize(response, data_type, fp_impl)


def model(func: Callable, id: int, version: int):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper
43 changes: 16 additions & 27 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pyyaml = "^6.0.1"
prefect-docker = "^0.4.1"
distlib = "^0.3.8"
giza-cli = "^0.7.0"
giza-osiris = "0.2.1"
giza-osiris = "^0.2.1"

[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
Expand All @@ -37,3 +37,6 @@ pre-commit = "^3.5.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"
18 changes: 8 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@


def test_predict_success():

model = GizaModel(
model_path="",
orion_runner_service_url="http://localhost:8080")
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")

arr = np.array([[1, 2], [3, 4]], dtype=np.uint32)

result = model.predict(
input_feed={"arr_1": arr}, verifiable=True, output_dtype='tensor_int')
input_feed={"arr_1": arr}, verifiable=True, output_dtype="tensor_int"
)

assert np.array_equal(result, arr)


def test_predict_success_with_file():

model = GizaModel(
model_path="",
orion_runner_service_url="http://localhost:8080")
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")

expected = np.array([[1, 2], [3, 4]], dtype=np.uint32)

result = model.predict(
input_file='tests/data/simple_tensor.csv', verifiable=True, output_dtype='tensor_int')
input_file="tests/data/simple_tensor.csv",
verifiable=True,
output_dtype="tensor_int",
)

assert np.array_equal(result, expected)

0 comments on commit 9063356

Please sign in to comment.