Skip to content

Commit

Permalink
Fix linting and add tests for deployments
Browse files Browse the repository at this point in the history
  • Loading branch information
Gonmeso committed Jan 17, 2024
1 parent 9370bc6 commit f04a83b
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 5 deletions.
7 changes: 3 additions & 4 deletions giza/frameworks/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,8 @@ def deploy(
try:
client = DeploymentsClient(API_HOST)

deployments: DeploymentsList = client.list(model_id, version_id)
deployments = deployments.json()
deployments = json.loads(deployments)
deployments_list: DeploymentsList = client.list(model_id, version_id)
deployments: dict = json.loads(deployments_list.json())

if len(deployments) > 0:
echo.info(
Expand All @@ -146,7 +145,7 @@ def deploy(

spinner = Spinner(name="aesthetic", text="Creating deployment!")

with Live(renderable=spinner) as live:
with Live(renderable=spinner):
with open(data) as casm:
deployment = client.create(
model_id,
Expand Down
1 change: 0 additions & 1 deletion giza/schemas/deployments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
from typing import Optional

from pydantic import BaseModel
Expand Down
166 changes: 166 additions & 0 deletions tests/commands/test_deployments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from unittest.mock import patch

from requests import HTTPError

from giza.commands.deployments import DeploymentsClient, cairo
from giza.schemas.deployments import Deployment, DeploymentsList
from tests.conftest import invoke_cli_runner


def test_deploy_with_cairo_framework():
with patch.object(cairo, "deploy") as mock_deploy:
result = invoke_cli_runner(
[
"deployments",
"deploy",
"--model-id",
"1",
"--version-id",
"1",
"--framework",
"CAIRO",
"--size",
"S",
"--debug",
"data_path",
]
)
mock_deploy.assert_called_once()
assert result.exit_code == 0


def test_deploy_with_ezkl_framework():
result = invoke_cli_runner(
[
"deployments",
"deploy",
"--model-id",
"1",
"--version-id",
"1",
"--framework",
"EZKL",
"--size",
"S",
"data_path",
],
expected_error=True,
)
assert "EZKL deployment is not yet supported" in str(result.exception)


def test_deploy_with_unsupported_framework():
result = invoke_cli_runner(
[
"deployments",
"deploy",
"--model-id",
"1",
"--version-id",
"1",
"--framework",
"NONEXISTING",
"--size",
"S",
"data_path",
],
expected_error=True,
)
assert result.exit_code == 2


def test_list_deployments():
deployments_list = DeploymentsList(
__root__=[
Deployment(
id=1,
status="COMPLETED",
uri="https://giza-api.com/deployments/1",
size="S",
service_name="giza-deployment-1",
model_id=1,
version_id=1,
),
Deployment(
id=2,
status="COMPLETED",
uri="https://giza-api.com/deployments/2",
size="S",
service_name="giza-deployment-2",
model_id=1,
version_id=1,
),
]
)
with patch.object(
DeploymentsClient, "list", return_value=deployments_list
) as mock_list:
result = invoke_cli_runner(
["deployments", "list", "--model-id", "1", "--version-id", "1"],
)
mock_list.assert_called_once()
assert result.exit_code == 0
assert "giza-deployment-1" in result.stdout
assert "giza-deployment-2" in result.stdout


def test_list_deployments_http_error():
with patch.object(DeploymentsClient, "list", side_effect=HTTPError):
result = invoke_cli_runner(
["deployments", "list", "--model-id", "1", "--version-id", "1"],
expected_error=True,
)
assert result.exit_code == 1
assert "Could not list deployments" in result.stdout


def test_get_deployment():
deployment = Deployment(
id=1,
status="COMPLETED",
uri="https://giza-api.com/deployments/1",
size="S",
service_name="giza-deployment-1",
model_id=1,
version_id=1,
)
with patch.object(
DeploymentsClient, "get", return_value=deployment
) as mock_deployment:
result = invoke_cli_runner(
[
"deployments",
"get",
"--model-id",
"1",
"--version-id",
"1",
"--deployment-id",
"1",
],
)
mock_deployment.assert_called_once()
assert result.exit_code == 0
assert "giza-deployment-1" in result.stdout


def test_get_deployment_http_error():
with patch.object(
DeploymentsClient, "get", side_effect=HTTPError
) as mock_deployment:
result = invoke_cli_runner(
[
"deployments",
"get",
"--model-id",
"1",
"--version-id",
"1",
"--deployment-id",
"1",
],
expected_error=True,
)
mock_deployment.assert_called_once()
assert result.exit_code == 1
assert "Could not get deployment" in result.stdout

0 comments on commit f04a83b

Please sign in to comment.