From 1cbe15669d229d67dd0551da651343a1b6d3c2c4 Mon Sep 17 00:00:00 2001 From: Gonzalo Mellizo-Soto Date: Thu, 25 Jan 2024 12:54:28 +0100 Subject: [PATCH] More linting fixes --- giza_actions/action.py | 2 +- giza_actions/model.py | 3 +-- tests/test_model.py | 18 ++++++++---------- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/giza_actions/action.py b/giza_actions/action.py index 3ddf8c8..25caf6b 100644 --- a/giza_actions/action.py +++ b/giza_actions/action.py @@ -10,7 +10,7 @@ from prefect import Flow # noqa: E402 from prefect import flow as _flow # noqa: E402 from prefect.settings import PREFECT_API_URL # noqa: E402 -from prefect.settings import ( +from prefect.settings import ( # noqa: E402 PREFECT_LOGGING_SETTINGS_PATH, PREFECT_UI_URL, update_current_profile, diff --git a/giza_actions/model.py b/giza_actions/model.py index e6bae27..e9fae08 100644 --- a/giza_actions/model.py +++ b/giza_actions/model.py @@ -1,7 +1,6 @@ import json -from functools import wraps from pathlib import Path -from typing import Callable, Dict, Optional +from typing import Dict, Optional import numpy as np import onnxruntime as ort diff --git a/tests/test_model.py b/tests/test_model.py index ea7e77e..ce29053 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,28 +4,26 @@ def test_predict_success(): - - model = GizaModel( - model_path="", - orion_runner_service_url="http://localhost:8080") + model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080") arr = np.array([[1, 2], [3, 4]], dtype=np.uint32) result = model.predict( - input_feed={"arr_1": arr}, verifiable=True, output_dtype='tensor_int') + input_feed={"arr_1": arr}, verifiable=True, output_dtype="tensor_int" + ) assert np.array_equal(result, arr) def test_predict_success_with_file(): - - model = GizaModel( - model_path="", - orion_runner_service_url="http://localhost:8080") + model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080") expected = np.array([[1, 2], [3, 4]], dtype=np.uint32) result = model.predict( - input_file='tests/data/simple_tensor.csv', verifiable=True, output_dtype='tensor_int') + input_file="tests/data/simple_tensor.csv", + verifiable=True, + output_dtype="tensor_int", + ) assert np.array_equal(result, expected)