Skip to content

Commit

Permalink
Merge pull request #28 from gizatechxyz/feature/download-original-model
Browse files Browse the repository at this point in the history
Add download original model command. Add version check. Bump to 0.6.0
  • Loading branch information
Gonmeso authored Dec 12, 2023
2 parents 3decc58 + 06874c6 commit 624b583
Show file tree
Hide file tree
Showing 11 changed files with 717 additions and 27 deletions.
2 changes: 1 addition & 1 deletion giza/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os

__version__ = "0.4.0"
__version__ = "0.6.0"
# Until DNS is fixed
API_HOST = os.environ.get("GIZA_API_HOST", "https://api.gizatech.xyz")
8 changes: 5 additions & 3 deletions giza/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from giza.commands.reset_password import request_reset_password_token, reset_password
from giza.commands.users import app as users_app
from giza.commands.verify import verify
from giza.commands.version import version_entrypoint
from giza.commands.version import check_version
from giza.commands.versions import app as versions_app
from giza.commands.versions import transpile

Expand All @@ -33,8 +33,10 @@
name="giza",
help="""
🔶 Giza-CLI to manage the resources at Giza 🔶.
""",
)(version_entrypoint)
""",
invoke_without_command=True,
)(check_version)


app.add_typer(
versions_app,
Expand Down
34 changes: 34 additions & 0 deletions giza/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,40 @@ def download(self, model_id: int, version_id: int) -> bytes:

return download_response.content

@auth
def download_original(self, model_id: int, version_id: int) -> bytes:
"""
Download the original version.
Args:
model_id: Model identifier
version_id: Version identifier
Returns:
The version binary file
"""
headers = copy.deepcopy(self.default_headers)
headers.update(self._get_auth_header())

response = self.session.get(
f"{self._get_version_url(model_id)}/{version_id}:download_original",
headers=headers,
)

self._echo_debug(str(response))
response.raise_for_status()

url = response.json()["download_url"]

download_response = self.session.get(
url, headers={"Content-Type": "application/octet-stream"}
)

self._echo_debug(str(download_response))
download_response.raise_for_status()

return download_response.content

def _upload(self, upload_url: str, f: BufferedReader) -> None:
"""
Upload the file to the specified url.
Expand Down
4 changes: 2 additions & 2 deletions giza/commands/prove.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

def prove(
data: str = typer.Argument(None),
model_id: Optional[str] = typer.Option(None, "--model-id", "-m"),
version_id: Optional[str] = typer.Option(None, "--version-id", "-v"),
model_id: Optional[int] = typer.Option(None, "--model-id", "-m"),
version_id: Optional[int] = typer.Option(None, "--version-id", "-v"),
size: JobSize = typer.Option(JobSize.S, "--size", "-s"),
framework: Framework = typer.Option(Framework.CAIRO, "--framework", "-f"),
output_path: str = typer.Option("zk.proof", "--output-path", "-o"),
Expand Down
29 changes: 15 additions & 14 deletions giza/commands/version.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import requests
import typer

from giza.callbacks import version_callback
from giza import __version__
from giza.utils.echo import Echo


def version_entrypoint(
version: bool = typer.Option(
None,
"--version",
callback=version_callback,
is_eager=True,
), # noqa
) -> None:
def check_version(ctx: typer.Context):
"""
Prints the current CLI version.
Args:
version (bool): Tper callback to retrieve the version.
Check if there is a new version available of the cli in pypi to suggest upgrade
"""
pass
current_version = __version__
response = requests.get("https://pypi.org/pypi/giza/json")
latest_version = response.json()["info"]["version"]

if latest_version > current_version:
echo = Echo()
echo.warning(f"Current version of Giza CLI: {current_version}")
echo.warning(
f"A new version ({latest_version}) is available. Please upgrade :bell:"
)
63 changes: 61 additions & 2 deletions giza/commands/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def transpile(
None, help="The ID of the model where a new version will be created"
),
desc: str = typer.Option(None, help="Description of the version"),
model_desc: int = typer.Option(
model_desc: str = typer.Option(
None, help="Description of the Model to create if model_id is not provided"
),
framework: Framework = typer.Option(Framework.CAIRO, "--framework", "-f"),
Expand Down Expand Up @@ -219,7 +219,7 @@ def list(


@app.command(
short_help="⚡️ Download the transpiled cairo model if available",
short_help="⚡️ Download the transpiled cairo model if available.",
help="""⚡️ Download the transpiled cairo model if available.
Download an unzip a transpiled model.
Expand Down Expand Up @@ -282,3 +282,62 @@ def download(
if debug:
raise e
sys.exit(1)


@app.command(
short_help="⚡️ Download the original ONNX model.",
help="""⚡️ Download the original ONNX model.
Verification and an active token is needed.
""",
)
def download_original(
model_id: int = typer.Option(None, help="The ID of the model"),
version_id: int = typer.Option(None, help="The ID of the version"),
output_path: str = typer.Option(
"model.onnx", "--output-path", "-o", help="Path to output the ONNX model"
),
debug: Optional[bool] = DEBUG_OPTION,
) -> None:
"""
Retrieve information about the current user and print it as json to stdout.
Args:
debug (Optional[bool], optional): Whether to add debug information, will show requests, extra logs and traceback if there is an Exception. Defaults to DEBUG_OPTION (False)
"""

try:
if any([model_id is None, version_id is None]):
raise ValueError("⛔️Model ID and version ID are required⛔️")

client = VersionsClient(API_HOST, debug=debug)
version = client.get(model_id, version_id)

if version.status != VersionStatus.COMPLETED:
raise ValueError(f"Model version status is not completed {version.status}")

echo("ONNX model is ready, downloading! ✅")
onnx_model = client.download_original(model_id, version.version)

with open(output_path, "wb") as f:
f.write(onnx_model)

echo(f"ONNX model saved at: {output_path}")

except ValueError as e:
echo.error(e.args[0])
if debug:
raise e
sys.exit(1)
except HTTPError as e:
info = get_response_info(e.response)
echo.error("⛔️Error at download")
echo.error(f"⛔️Detail -> {info.get('detail')}⛔️")
echo.error(f"⛔️Status code -> {info.get('status_code')}⛔️")
echo.error(f"⛔️Error message -> {info.get('content')}⛔️")
echo.error(
f"⛔️Request ID: Give this to an administrator to trace the error -> {info.get('request_id')}⛔️"
) if info.get("request_id") else None
if debug:
raise e
sys.exit(1)
3 changes: 3 additions & 0 deletions giza/frameworks/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def transpile(
HTTPError: If there is an HTTP error while communicating with the server.
"""
echo = Echo(debug=debug)
if model_path is None:
echo.error("No model name provided, please provide a model path ⛔️")
sys.exit(1)
if model_id is None:
model_name = model_path.split("/")[-1].split(".")[0]
echo("No model id provided, checking if model exists ✅ ")
Expand Down
6 changes: 3 additions & 3 deletions giza/frameworks/ezkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup(
model_path: str,
model_id: int,
desc: str,
model_desc: int,
model_desc: str,
input_data: str,
debug: Optional[bool],
size: JobSize = JobSize.S,
Expand Down Expand Up @@ -229,8 +229,8 @@ def verify(
proof_id: Optional[int],
model_id: Optional[int],
version_id: Optional[int],
proof: str = None,
debug: bool = False,
proof: Optional[str] = None,
debug: Optional[bool] = False,
size: JobSize = JobSize.S,
):
"""
Expand Down
25 changes: 25 additions & 0 deletions giza/utils/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ def format_error(self, message: str) -> str:
"""
return self.format_message(rf"[red]{message}[/red]", "ERROR", "red")

def format_warning(self, message: str) -> str:
"""
Specific format for warning purposes
Args:
message (str): message to format
Returns:
str: error formatted message
"""
yellow = typer.colors.YELLOW
return self.format_message(
rf"[{yellow}]{message}[/{yellow}]", "WARNING", f"{yellow}"
)

def echo(self, message: str, formatted: str) -> None:
"""
Main function to print information of a message, original message is provided as well as the formatted one.
Expand Down Expand Up @@ -111,6 +126,16 @@ def info(self, message: str) -> None:
formatted_message = self.format_message(message)
self.echo(message, formatted_message)

def warning(self, message: str) -> None:
"""
Format and echo a warning message
Args:
message (str): message to format and echo
"""
formatted_message = self.format_warning(message)
self.echo(message, formatted_message)

def __call__(self, message: str) -> None:
"""
Provided as facility to echo through the class instance
Expand Down
Loading

0 comments on commit 624b583

Please sign in to comment.