diff --git a/giza_actions/action.py b/giza_actions/action.py index 58178f8..25caf6b 100644 --- a/giza_actions/action.py +++ b/giza_actions/action.py @@ -9,8 +9,8 @@ from prefect import Flow # noqa: E402 from prefect import flow as _flow # noqa: E402 +from prefect.settings import PREFECT_API_URL # noqa: E402 from prefect.settings import ( # noqa: E402 - PREFECT_API_URL, PREFECT_LOGGING_SETTINGS_PATH, PREFECT_UI_URL, update_current_profile, diff --git a/giza_actions/model.py b/giza_actions/model.py index 8858c42..e9fae08 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -1,15 +1,14 @@ -from functools import wraps import json from pathlib import Path -from typing import Callable, Optional, Dict -import numpy as np +from typing import Dict, Optional -import requests +import numpy as np import onnxruntime as ort +import requests from giza import API_HOST from giza.client import ApiClient, ModelsClient, VersionsClient from giza.utils.enums import VersionStatus -from osiris.app import serialize, deserialize, serializer, create_tensor_from_array +from osiris.app import create_tensor_from_array, deserialize, serialize, serializer class GizaModel: @@ -22,15 +21,13 @@ def __init__( orion_runner_service_url: Optional[str] = None, ): if model_path is None and id is None and version is None: - raise ValueError( - "Either model_path or id and version must be provided.") + raise ValueError("Either model_path or id and version must be provided.") if model_path is None and (id is None or version is None): raise ValueError("Both id and version must be provided.") if model_path and (id or version): - raise ValueError( - "Either model_path or id and version must be provided.") + raise ValueError("Either model_path or id and version must be provided.") self.orion_runner_service_url = orion_runner_service_url @@ -48,12 +45,10 @@ def _download_model(self, model_id: int, version_id: int, output_path: str): version = self.version_client.get(model_id, version_id) if version.status != VersionStatus.COMPLETED: - raise ValueError( - f"Model version status is not completed {version.status}") + raise ValueError(f"Model version status is not completed {version.status}") print("ONNX model is ready, downloading! ✅") - onnx_model = self.api_client.download_original( - model_id, version.version) + onnx_model = self.api_client.download_original(model_id, version.version) model_name = version.original_model_path.split("/")[-1] save_path = Path(output_path) / model_name @@ -69,8 +64,14 @@ def _get_credentials(self): self.api_client.retrieve_token() self.api_client.retrieve_api_key() - def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] = None, verifiable: bool = False, fp_impl='FP16x16', output_dtype: str = 'tensor_fixed_point'): - + def predict( + self, + input_file: Optional[str] = None, + input_feed: Optional[Dict] = None, + verifiable: bool = False, + fp_impl="FP16x16", + output_dtype: str = "tensor_fixed_point", + ): if verifiable: if not self.orion_runner_service_url: raise ValueError("Orion Runner service URL must be provided") @@ -78,17 +79,17 @@ def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] = endpoint = f"{self.orion_runner_service_url}/cairo_run" cairo_payload = self._format_inputs_for_cairo( - input_file, input_feed, fp_impl) + input_file, input_feed, fp_impl + ) response = requests.post(endpoint, json=cairo_payload) - serialized_output = json.dumps( - response.json()['result']) + serialized_output = json.dumps(response.json()["result"]) if response.status_code == 200: - preds = self._parse_cairo_response( - serialized_output, output_dtype, fp_impl) + serialized_output, output_dtype, fp_impl + ) else: raise Exception(f"OrionRunner service error: {response.text}") @@ -100,11 +101,12 @@ def predict(self, input_file: Optional[str] = None, input_feed: Optional[Dict] = preds = self.session.run(None, input_feed)[0] return preds - def _format_inputs_for_cairo(self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl): + def _format_inputs_for_cairo( + self, input_file: Optional[str], input_feed: Optional[Dict], fp_impl + ): serialized = None if input_file is not None: - print(input_file) serialized = serialize(input_file, fp_impl) @@ -122,11 +124,3 @@ def _format_inputs_for_cairo(self, input_file: Optional[str], input_feed: Option def _parse_cairo_response(self, response, data_type: str, fp_impl): return deserialize(response, data_type, fp_impl) - - -def model(func: Callable, id: int, version: int): - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper diff --git a/poetry.lock b/poetry.lock index 00f2332..3d2384f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -927,13 +927,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-api-python-client" -version = "2.114.0" +version = "2.115.0" description = "Google API Client Library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-python-client-2.114.0.tar.gz", hash = "sha256:e041bbbf60e682261281e9d64b4660035f04db1cccba19d1d68eebc24d1465ed"}, - {file = "google_api_python_client-2.114.0-py2.py3-none-any.whl", hash = "sha256:690e0bb67d70ff6dea4e8a5d3738639c105a478ac35da153d3b2a384064e9e1a"}, + {file = "google-api-python-client-2.115.0.tar.gz", hash = "sha256:96af11376535236ba600ebbe23588cfe003ec9b74e66dd6ddb53aa3ec87e1b52"}, + {file = "google_api_python_client-2.115.0-py2.py3-none-any.whl", hash = "sha256:26178e33684763099142e2cad201057bd27d4efefd859a495aac21ab3e6129c2"}, ] [package.dependencies] @@ -945,13 +945,13 @@ uritemplate = ">=3.0.1,<5" [[package]] name = "google-auth" -version = "2.26.2" +version = "2.27.0" description = "Google Authentication Library" optional = false python-versions = ">=3.7" files = [ - {file = "google-auth-2.26.2.tar.gz", hash = "sha256:97327dbbf58cccb58fc5a1712bba403ae76668e64814eb30f7316f7e27126b81"}, - {file = "google_auth-2.26.2-py2.py3-none-any.whl", hash = "sha256:3f445c8ce9b61ed6459aad86d8ccdba4a9afed841b2d1451a11ef4db08957424"}, + {file = "google-auth-2.27.0.tar.gz", hash = "sha256:e863a56ccc2d8efa83df7a80272601e43487fa9a728a376205c86c26aaefa821"}, + {file = "google_auth-2.27.0-py2.py3-none-any.whl", hash = "sha256:8e4bad367015430ff253fe49d500fdc3396c1a434db5740828c728e45bcce245"}, ] [package.dependencies] @@ -2272,13 +2272,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co [[package]] name = "pluggy" -version = "1.3.0" +version = "1.4.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" files = [ - {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, - {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, ] [package.extras] @@ -2407,13 +2407,13 @@ dev = ["cairosvg", "codespell", "flaky", "ipython", "ipython (==8.12.*)", "jinja [[package]] name = "prefect-docker" -version = "0.4.3" +version = "0.4.4" description = "Prefect integrations for working with Docker" optional = false python-versions = ">=3.7" files = [ - {file = "prefect-docker-0.4.3.tar.gz", hash = "sha256:eb6aa1e61299484ba36d572039ad9e6333d1060c5d78d852bb307304889e92df"}, - {file = "prefect_docker-0.4.3-py3-none-any.whl", hash = "sha256:2e5d8cd965719e2a9a0e77722455c6f0bcfa65d81a50d0f46577328fe3225706"}, + {file = "prefect-docker-0.4.4.tar.gz", hash = "sha256:f6295613ca2072044008afdf73c817d1fa47a2a202fcb7c974928e8128b4f219"}, + {file = "prefect_docker-0.4.4-py3-none-any.whl", hash = "sha256:e6b052203bb52bb008ed979808db345b0e621fcaf6048589ac68dbd345664ff4"}, ] [package.dependencies] @@ -2809,7 +2809,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2817,16 +2816,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2843,7 +2834,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2851,7 +2841,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3545,13 +3534,13 @@ types-pyasn1 = "*" [[package]] name = "types-requests" -version = "2.31.0.20240106" +version = "2.31.0.20240125" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.31.0.20240106.tar.gz", hash = "sha256:0e1c731c17f33618ec58e022b614a1a2ecc25f7dc86800b36ef341380402c612"}, - {file = "types_requests-2.31.0.20240106-py3-none-any.whl", hash = "sha256:da997b3b6a72cc08d09f4dba9802fdbabc89104b35fe24ee588e674037689354"}, + {file = "types-requests-2.31.0.20240125.tar.gz", hash = "sha256:03a28ce1d7cd54199148e043b2079cdded22d6795d19a2c2a6791a4b2b5e2eb5"}, + {file = "types_requests-2.31.0.20240125-py3-none-any.whl", hash = "sha256:9592a9a4cb92d6d75d9b491a41477272b710e021011a2a3061157e2fb1f1a5d1"}, ] [package.dependencies] @@ -3835,4 +3824,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<4.0" -content-hash = "33808e15ce52bcfe31b973ce1043f57096b22a9bb05077324bd1351953b09781" +content-hash = "31a5ea14d9f7bce3cee310f106a7e1176f6036f646eba6bb8f8bf19ea373355a" diff --git a/pyproject.toml b/pyproject.toml index 5807e15..818a924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ pyyaml = "^6.0.1" prefect-docker = "^0.4.1" distlib = "^0.3.8" giza-cli = "^0.7.0" -giza-osiris = "0.2.1" +giza-osiris = "^0.2.1" [tool.poetry.dev-dependencies] pytest = "^6.2.5" @@ -37,3 +37,6 @@ pre-commit = "^3.5.0" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.isort] +profile = "black" diff --git a/tests/test_model.py b/tests/test_model.py index ea7e77e..ce29053 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,28 +4,26 @@ def test_predict_success(): - - model = GizaModel( - model_path="", - orion_runner_service_url="http://localhost:8080") + model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080") arr = np.array([[1, 2], [3, 4]], dtype=np.uint32) result = model.predict( - input_feed={"arr_1": arr}, verifiable=True, output_dtype='tensor_int') + input_feed={"arr_1": arr}, verifiable=True, output_dtype="tensor_int" + ) assert np.array_equal(result, arr) def test_predict_success_with_file(): - - model = GizaModel( - model_path="", - orion_runner_service_url="http://localhost:8080") + model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080") expected = np.array([[1, 2], [3, 4]], dtype=np.uint32) result = model.predict( - input_file='tests/data/simple_tensor.csv', verifiable=True, output_dtype='tensor_int') + input_file="tests/data/simple_tensor.csv", + verifiable=True, + output_dtype="tensor_int", + ) assert np.array_equal(result, expected)