Skip to content

Commit

Permalink
Apply black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed May 10, 2024
1 parent 83aafcf commit 3850178
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ def __init__(
output_path: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path is None and (id is None or version is None):
raise ValueError("Both id and version must be provided.")

if model_path and (id or version):
raise ValueError(
"Either model_path or id and version must be provided.")
raise ValueError("Either model_path or id and version must be provided.")

if model_path and id and version:
raise ValueError(
Expand Down Expand Up @@ -230,8 +228,7 @@ def _download_model(self) -> None:

logger.info(f"Model saved at: {save_path} ✅")
else:
logger.info(
f"Model already downloaded at: {self._output_path} ✅")
logger.info(f"Model already downloaded at: {self._output_path} ✅")

def _get_credentials(self) -> None:
"""
Expand Down Expand Up @@ -278,7 +275,11 @@ def predict(

# Non common arguments should be named parameters
payload = self._format_inputs_for_framework(
input_file, input_feed, fp_impl=fp_impl, model_category=model_category, job_size=job_size
input_file,
input_feed,
fp_impl=fp_impl,
model_category=model_category,
job_size=job_size,
)

if dry_run:
Expand Down Expand Up @@ -311,7 +312,8 @@ def predict(

logger.debug("Output dtype: %s", output_dtype)
preds = self._parse_cairo_response(
serialized_output, output_dtype, model_category)
serialized_output, output_dtype, model_category
)

elif self.framework == Framework.EZKL:
preds = np.array(serialized_output[0])
Expand Down Expand Up @@ -352,7 +354,7 @@ def _format_inputs_for_cairo(
input_feed: Optional[Dict],
fp_impl: str,
model_category: str,
job_size: str
job_size: str,
) -> Dict[str, str]:
"""
Formats the inputs for a prediction request using OrionRunner.
Expand All @@ -375,13 +377,13 @@ def _format_inputs_for_cairo(
if input_feed:
for name, value in input_feed.items():
if isinstance(value, np.ndarray):
if model_category == 'ONNX_ORION':
if model_category == "ONNX_ORION":
tensor = create_tensor_from_array(value, fp_impl)
elif model_category in ['XGB', 'LGBM']:
elif model_category in ["XGB", "LGBM"]:
tensor = value * 100000
tensor = tensor.astype(np.int64)
else:
tensor = create_tensor_from_array(value, 'FP16x16')
tensor = create_tensor_from_array(value, "FP16x16")
formatted_args.append(serializer(tensor))

return {"job_size": job_size, "args": " ".join(formatted_args)}
Expand Down Expand Up @@ -420,7 +422,9 @@ def _format_inputs_for_ezkl(
)
return {"input_data": [data], "job_size": job_size}

def _parse_cairo_response(self, response: str, data_type: str, model_category: str) -> str:
def _parse_cairo_response(
self, response: str, data_type: str, model_category: str
) -> str:
"""
Parses the response from a OrionRunner prediction request.
Expand Down

0 comments on commit 3850178

Please sign in to comment.