Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add download original model command. Add version check. Bump to 0.6.0 #28

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion giza/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os

__version__ = "0.4.0"
__version__ = "0.6.0"
# Until DNS is fixed
API_HOST = os.environ.get("GIZA_API_HOST", "https://api.gizatech.xyz")
8 changes: 5 additions & 3 deletions giza/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from giza.commands.reset_password import request_reset_password_token, reset_password
from giza.commands.users import app as users_app
from giza.commands.verify import verify
from giza.commands.version import version_entrypoint
from giza.commands.version import check_version
from giza.commands.versions import app as versions_app
from giza.commands.versions import transpile

Expand All @@ -33,8 +33,10 @@
name="giza",
help="""
🔶 Giza-CLI to manage the resources at Giza 🔶.
""",
)(version_entrypoint)
""",
invoke_without_command=True,
)(check_version)


app.add_typer(
versions_app,
Expand Down
34 changes: 34 additions & 0 deletions giza/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,40 @@ def download(self, model_id: int, version_id: int) -> bytes:

return download_response.content

@auth
def download_original(self, model_id: int, version_id: int) -> bytes:
"""
Download the original version.

Args:
model_id: Model identifier
version_id: Version identifier

Returns:
The version binary file
"""
headers = copy.deepcopy(self.default_headers)
headers.update(self._get_auth_header())

response = self.session.get(
f"{self._get_version_url(model_id)}/{version_id}:download_original",
headers=headers,
)

self._echo_debug(str(response))
response.raise_for_status()

url = response.json()["download_url"]

download_response = self.session.get(
url, headers={"Content-Type": "application/octet-stream"}
)

self._echo_debug(str(download_response))
download_response.raise_for_status()

return download_response.content

def _upload(self, upload_url: str, f: BufferedReader) -> None:
"""
Upload the file to the specified url.
Expand Down
4 changes: 2 additions & 2 deletions giza/commands/prove.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

def prove(
data: str = typer.Argument(None),
model_id: Optional[str] = typer.Option(None, "--model-id", "-m"),
version_id: Optional[str] = typer.Option(None, "--version-id", "-v"),
model_id: Optional[int] = typer.Option(None, "--model-id", "-m"),
version_id: Optional[int] = typer.Option(None, "--version-id", "-v"),
size: JobSize = typer.Option(JobSize.S, "--size", "-s"),
framework: Framework = typer.Option(Framework.CAIRO, "--framework", "-f"),
output_path: str = typer.Option("zk.proof", "--output-path", "-o"),
Expand Down
29 changes: 15 additions & 14 deletions giza/commands/version.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import requests
import typer

from giza.callbacks import version_callback
from giza import __version__
from giza.utils.echo import Echo


def version_entrypoint(
version: bool = typer.Option(
None,
"--version",
callback=version_callback,
is_eager=True,
), # noqa
) -> None:
def check_version(ctx: typer.Context):
"""
Prints the current CLI version.

Args:
version (bool): Tper callback to retrieve the version.
Check if there is a new version available of the cli in pypi to suggest upgrade
"""
pass
current_version = __version__
response = requests.get("https://pypi.org/pypi/giza/json")
latest_version = response.json()["info"]["version"]

if latest_version > current_version:
echo = Echo()
echo.warning(f"Current version of Giza CLI: {current_version}")
echo.warning(
f"A new version ({latest_version}) is available. Please upgrade :bell:"
)
63 changes: 61 additions & 2 deletions giza/commands/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def transpile(
None, help="The ID of the model where a new version will be created"
),
desc: str = typer.Option(None, help="Description of the version"),
model_desc: int = typer.Option(
model_desc: str = typer.Option(
None, help="Description of the Model to create if model_id is not provided"
),
framework: Framework = typer.Option(Framework.CAIRO, "--framework", "-f"),
Expand Down Expand Up @@ -219,7 +219,7 @@ def list(


@app.command(
short_help="⚡️ Download the transpiled cairo model if available",
short_help="⚡️ Download the transpiled cairo model if available.",
help="""⚡️ Download the transpiled cairo model if available.

Download an unzip a transpiled model.
Expand Down Expand Up @@ -282,3 +282,62 @@ def download(
if debug:
raise e
sys.exit(1)


@app.command(
short_help="⚡️ Download the original ONNX model.",
help="""⚡️ Download the original ONNX model.

Verification and an active token is needed.
""",
)
def download_original(
model_id: int = typer.Option(None, help="The ID of the model"),
version_id: int = typer.Option(None, help="The ID of the version"),
output_path: str = typer.Option(
"model.onnx", "--output-path", "-o", help="Path to output the ONNX model"
),
debug: Optional[bool] = DEBUG_OPTION,
) -> None:
"""
Retrieve information about the current user and print it as json to stdout.

Args:
debug (Optional[bool], optional): Whether to add debug information, will show requests, extra logs and traceback if there is an Exception. Defaults to DEBUG_OPTION (False)
"""

try:
if any([model_id is None, version_id is None]):
raise ValueError("⛔️Model ID and version ID are required⛔️")

client = VersionsClient(API_HOST, debug=debug)
version = client.get(model_id, version_id)

if version.status != VersionStatus.COMPLETED:
raise ValueError(f"Model version status is not completed {version.status}")

echo("ONNX model is ready, downloading! ✅")
onnx_model = client.download_original(model_id, version.version)

with open(output_path, "wb") as f:
f.write(onnx_model)

echo(f"ONNX model saved at: {output_path}")

except ValueError as e:
echo.error(e.args[0])
if debug:
raise e
sys.exit(1)
except HTTPError as e:
info = get_response_info(e.response)
echo.error("⛔️Error at download")
echo.error(f"⛔️Detail -> {info.get('detail')}⛔️")
echo.error(f"⛔️Status code -> {info.get('status_code')}⛔️")
echo.error(f"⛔️Error message -> {info.get('content')}⛔️")
echo.error(
f"⛔️Request ID: Give this to an administrator to trace the error -> {info.get('request_id')}⛔️"
) if info.get("request_id") else None
if debug:
raise e
sys.exit(1)
3 changes: 3 additions & 0 deletions giza/frameworks/cairo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def transpile(
HTTPError: If there is an HTTP error while communicating with the server.
"""
echo = Echo(debug=debug)
if model_path is None:
echo.error("No model name provided, please provide a model path ⛔️")
sys.exit(1)
if model_id is None:
model_name = model_path.split("/")[-1].split(".")[0]
echo("No model id provided, checking if model exists ✅ ")
Expand Down
6 changes: 3 additions & 3 deletions giza/frameworks/ezkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup(
model_path: str,
model_id: int,
desc: str,
model_desc: int,
model_desc: str,
input_data: str,
debug: Optional[bool],
size: JobSize = JobSize.S,
Expand Down Expand Up @@ -229,8 +229,8 @@ def verify(
proof_id: Optional[int],
model_id: Optional[int],
version_id: Optional[int],
proof: str = None,
debug: bool = False,
proof: Optional[str] = None,
debug: Optional[bool] = False,
size: JobSize = JobSize.S,
):
"""
Expand Down
25 changes: 25 additions & 0 deletions giza/utils/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ def format_error(self, message: str) -> str:
"""
return self.format_message(rf"[red]{message}[/red]", "ERROR", "red")

def format_warning(self, message: str) -> str:
"""
Specific format for warning purposes

Args:
message (str): message to format

Returns:
str: error formatted message
"""
yellow = typer.colors.YELLOW
return self.format_message(
rf"[{yellow}]{message}[/{yellow}]", "WARNING", f"{yellow}"
)

def echo(self, message: str, formatted: str) -> None:
"""
Main function to print information of a message, original message is provided as well as the formatted one.
Expand Down Expand Up @@ -111,6 +126,16 @@ def info(self, message: str) -> None:
formatted_message = self.format_message(message)
self.echo(message, formatted_message)

def warning(self, message: str) -> None:
"""
Format and echo a warning message

Args:
message (str): message to format and echo
"""
formatted_message = self.format_warning(message)
self.echo(message, formatted_message)

def __call__(self, message: str) -> None:
"""
Provided as facility to echo through the class instance
Expand Down
Loading