Skip to content

Commit

Permalink
More linting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Gonmeso committed Jan 25, 2024
1 parent 326ee6f commit 1cbe156
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
2 changes: 1 addition & 1 deletion giza_actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from prefect import Flow # noqa: E402
from prefect import flow as _flow # noqa: E402
from prefect.settings import PREFECT_API_URL # noqa: E402
from prefect.settings import (
from prefect.settings import ( # noqa: E402
PREFECT_LOGGING_SETTINGS_PATH,
PREFECT_UI_URL,
update_current_profile,
Expand Down
3 changes: 1 addition & 2 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
from functools import wraps
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Dict, Optional

import numpy as np
import onnxruntime as ort
Expand Down
18 changes: 8 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@


def test_predict_success():

model = GizaModel(
model_path="",
orion_runner_service_url="http://localhost:8080")
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")

arr = np.array([[1, 2], [3, 4]], dtype=np.uint32)

result = model.predict(
input_feed={"arr_1": arr}, verifiable=True, output_dtype='tensor_int')
input_feed={"arr_1": arr}, verifiable=True, output_dtype="tensor_int"
)

assert np.array_equal(result, arr)


def test_predict_success_with_file():

model = GizaModel(
model_path="",
orion_runner_service_url="http://localhost:8080")
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")

expected = np.array([[1, 2], [3, 4]], dtype=np.uint32)

result = model.predict(
input_file='tests/data/simple_tensor.csv', verifiable=True, output_dtype='tensor_int')
input_file="tests/data/simple_tensor.csv",
verifiable=True,
output_dtype="tensor_int",
)

assert np.array_equal(result, expected)

0 comments on commit 1cbe156

Please sign in to comment.