Skip to content

Commit

Permalink
Merge pull request #42 from gizatechxyz/feature/add-download-flags
Browse files Browse the repository at this point in the history
Add flags to download Sierra files
  • Loading branch information
Gonmeso authored Feb 12, 2024
2 parents 2fbe708 + 918dc4e commit 3f59bec
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 44 deletions.
32 changes: 24 additions & 8 deletions giza/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,13 +1222,16 @@ def create(
return Version(**response.json()), upload_url

@auth
def download(self, model_id: int, version_id: int) -> bytes:
def download(
self, model_id: int, version_id: int, params: Dict
) -> Dict[str, bytes]:
"""
Download a version.
Args:
model_id: Model identifier
version_id: Version identifier
params: Additional parameters to pass to the request
Returns:
The version binary file
Expand All @@ -1239,21 +1242,34 @@ def download(self, model_id: int, version_id: int) -> bytes:
response = self.session.get(
f"{self._get_version_url(model_id)}/{version_id}:download",
headers=headers,
params=params,
)

self._echo_debug(str(response))
response.raise_for_status()

url = response.json()["download_url"]
urls = response.json()
downloads = {}

download_response = self.session.get(
url, headers={"Content-Type": "application/octet-stream"}
)
if params["download_model"] and "download_url" in urls:
model_url = urls["download_url"]
download_response = self.session.get(
model_url,
)

self._echo_debug(str(download_response))
download_response.raise_for_status()
self._echo_debug(str(download_response))
download_response.raise_for_status()
downloads["model"] = download_response.content

return download_response.content
if params["download_sierra"] and "sierra_url" in urls:
sierra_url = urls["sierra_url"]
sierra_response = self.session.get(sierra_url)

sierra_response.raise_for_status()
self._echo_debug(str(sierra_response))
downloads["inference.sierra"] = sierra_response.content

return downloads

@auth
def download_original(self, model_id: int, version_id: int) -> bytes:
Expand Down
53 changes: 39 additions & 14 deletions giza/commands/versions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
import zipfile
from io import BytesIO
from typing import Optional
from typing import Dict, Optional

import typer
from pydantic import ValidationError
Expand All @@ -15,6 +14,7 @@
from giza.schemas.versions import Version, VersionList
from giza.utils import echo, get_response_info
from giza.utils.enums import Framework, VersionStatus
from giza.utils.misc import download_model_or_sierra

app = typer.Typer()

Expand Down Expand Up @@ -84,6 +84,16 @@ def transpile(
"-i",
help="The input data to use for the transpilation",
),
download_model: bool = typer.Option(
True,
"--download-model",
help="Download the transpiled model after the transpilation is completed. CAIRO only.",
),
download_sierra: bool = typer.Option(
False,
"--download-sierra",
help="Download the siera file is the modle is fully compatible. CAIRO only.",
),
debug: Optional[bool] = DEBUG_OPTION,
) -> None:
if framework == Framework.CAIRO:
Expand All @@ -93,6 +103,8 @@ def transpile(
desc=desc,
model_desc=model_desc,
output_path=output_path,
download_model=download_model,
download_sierra=download_sierra,
debug=debug,
)
elif framework == Framework.EZKL:
Expand Down Expand Up @@ -233,6 +245,16 @@ def download(
output_path: str = typer.Option(
"cairo_model", "--output-path", "-o", help="Path to output the cairo model"
),
download_model: bool = typer.Option(
False,
"--download-model",
help="Download the transpiled model after the transpilation is completed. CAIRO only.",
),
download_sierra: bool = typer.Option(
False,
"--download-sierra",
help="Download the siera file is the modle is fully compatible. CAIRO only.",
),
debug: Optional[bool] = DEBUG_OPTION,
) -> None:
"""
Expand All @@ -252,19 +274,22 @@ def download(
if version.status != VersionStatus.COMPLETED:
raise ValueError(f"Model version status is not completed {version.status}")

echo("Transpilation is ready, downloading! ✅")
cairo_model = client.download(model_id, version.version)

try:
zip_file = zipfile.ZipFile(BytesIO(cairo_model))
except zipfile.BadZipFile as zip_error:
raise ValueError(
"Something went wrong with the transpiled file", zip_error.args[0]
) from None

zip_file.extractall(output_path)
echo(f"Transpilation saved at: {output_path}")
echo("Data is ready, downloading! ✅")
downloads: Dict[str, bytes] = client.download(
model_id,
version.version,
{"download_model": download_model, "download_sierra": download_sierra},
)

for name, content in downloads.items():
try:
echo(f"Downloading {name} ✅")
download_model_or_sierra(content, output_path, name)
except zipfile.BadZipFile as zip_error:
raise ValueError(
"Something went wrong with the download", zip_error.args[0]
) from None
echo(f"{name} saved at: {output_path}")
except ValueError as e:
echo.error(e.args[0])
if debug:
Expand Down
22 changes: 15 additions & 7 deletions giza/frameworks/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
import time
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -37,6 +36,7 @@
ServiceSize,
VersionStatus,
)
from giza.utils.misc import download_model_or_sierra

app = typer.Typer()

Expand Down Expand Up @@ -212,6 +212,8 @@ def transpile(
desc: str,
model_desc: str,
output_path: str,
download_model: bool,
download_sierra: bool,
debug: Optional[bool],
) -> None:
"""
Expand All @@ -232,6 +234,8 @@ def transpile(
desc (int, optional): Description of the version. Defaults to None.
model_desc (int, optional): Description of the Model to create if model_id is not provided. Defaults to None.
output_path (str, optional): The path where the cairo model will be saved. Defaults to "cairo_model".
download_model (bool): A flag used to determine whether to download the model or not.
download_sierra (bool): A flag used to determine whether to download the sierra or not.
debug (bool, optional): A flag used to determine whether to raise exceptions or not. Defaults to DEBUG_OPTION.
Raises:
Expand Down Expand Up @@ -351,20 +355,24 @@ def transpile(
raise e
sys.exit(1)

echo("Transpilation recieved! ✅")
try:
cairo_model = client.download(model.id, version.version)
zip_file = zipfile.ZipFile(BytesIO(cairo_model))
if download_model or download_sierra:
params = {
"download_model": download_model,
"download_sierra": download_sierra,
}
downloads = client.download(model.id, version.version, params)
for name, content in downloads.items():
echo(f"Downloading {name} ✅")
download_model_or_sierra(content, output_path, name)
echo(f"{name} saved at: {output_path}")
except zipfile.BadZipFile as zip_error:
echo.error("Something went wrong with the transpiled file")
echo.error(f"Error -> {zip_error.args[0]}")
if debug:
raise zip_error
sys.exit(1)

zip_file.extractall(output_path)
echo(f"Transpilation saved at: {output_path}")


def verify(
proof_id: Optional[int],
Expand Down
27 changes: 27 additions & 0 deletions giza/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import os
import re
import zipfile
from io import BytesIO
from typing import Optional

from giza.exceptions import PasswordError

Expand All @@ -18,3 +22,26 @@ def _check_password_strength(password: str) -> None:
raise PasswordError(
"Password must be at least 8 characters long, contain at least one uppercase letter, one lowercase letter and one number."
)


def download_model_or_sierra(
content: bytes, output_path: str, name: Optional[str] = None
):
"""
Download the model or sierra file.
Args:
content (bytes): file content
output_path (str): path to save the file
name (str): file name. Defaults to None.
"""
f = BytesIO(content)
is_zip = zipfile.is_zipfile(f)
if not is_zip and name is not None:
if not os.path.exists(output_path):
os.makedirs(output_path)
with open(os.path.join(output_path, name), "wb") as file_:
file_.write(content)
else:
zip_file = zipfile.ZipFile(f)
zip_file.extractall(output_path)
31 changes: 16 additions & 15 deletions tests/commands/test_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def return_content():
tmp = BytesIO()
with zipfile.ZipFile(tmp, mode="w", compression=zipfile.ZIP_DEFLATED) as f:
f.writestr("file1.txt", "hi")
return tmp.getvalue()
return {"model": tmp.getvalue()}

models = ModelList(
__root__=[
Expand Down Expand Up @@ -168,7 +168,8 @@ def return_content():
# Called twice, once to open the model and second to write the zip
mock_open.assert_called()
assert "Reading model from path" in result.stdout
assert "Transpilation recieved" in result.stdout
assert "Downloading model" in result.stdout
assert "model saved at" in result.stdout
assert result.exit_code == 0


Expand All @@ -190,10 +191,10 @@ def test_versions_transpile_http_error(tmpdir):
assert result.exit_code == 1


# Test version transpilation with bad zip file
def test_versions_transpile_bad_zip(tmpdir):
# Test version transpilation with file
def test_versions_transpile_file(tmpdir):
def return_content():
return b"some bytes"
return {"model": b"some bytes"}

models = ModelList(
__root__=[
Expand Down Expand Up @@ -240,9 +241,9 @@ def return_content():

# Called twice, once to open the model and second to write the zip
mock_open.assert_called()
assert "Something went wrong" in result.stdout
assert "Error ->" in result.stdout
assert result.exit_code == 1
assert "Transpilation is fully compatible" in result.stdout
assert "Downloading model" in result.stdout
assert result.exit_code == 0


# Test successful version download
Expand All @@ -260,7 +261,7 @@ def return_content():
tmp = BytesIO()
with zipfile.ZipFile(tmp, mode="w", compression=zipfile.ZIP_DEFLATED) as f:
f.writestr("file1.txt", "hi")
return tmp.getvalue()
return {"model": tmp.getvalue()}

with patch.object(VersionsClient, "get", return_value=version), patch.object(
VersionsClient, "download", return_value=return_content()
Expand All @@ -281,7 +282,7 @@ def return_content():
)

mock_open.assert_called()
assert "Transpilation is ready, downloading!" in result.stdout
assert "Data is ready, downloading!" in result.stdout
assert result.exit_code == 0


Expand All @@ -308,8 +309,8 @@ def test_versions_download_server_error():
assert "Error at download" in result.stdout


# Test version download with bad zip file
def test_versions_download_bad_zip(tmpdir):
# Test version download but a file, imitating a sierra file
def test_versions_download_file(tmpdir):
version = Version(
version=1,
size=1,
Expand All @@ -320,7 +321,7 @@ def test_versions_download_bad_zip(tmpdir):
)

def return_content():
return b"some bytes"
return {"model": b"some bytes"}

with patch.object(VersionsClient, "get", return_value=version), patch.object(
VersionsClient, "download", return_value=return_content()
Expand All @@ -339,8 +340,8 @@ def return_content():
expected_error=True,
)

assert "Something went wrong with the transpiled file" in result.stdout
assert result.exit_code == 1
assert "saved at" in result.stdout
assert result.exit_code == 0


# Test version download with missing model_id and version_id
Expand Down

0 comments on commit 3f59bec

Please sign in to comment.