-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Async register API support #1083
base: aqua_apiserver
Are you sure you want to change the base?
Changes from all commits
53305b8
ee8dc5e
8be1a97
8180950
2d4e990
de82312
f6c8c65
ebbd175
b6cf7ba
a244dc5
99871a1
7ca41c5
722595c
fd3051e
876f826
16658b5
4813195
8de4ad8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/usr/bin/env python | ||
|
||
# Copyright (c) 2025 Oracle and/or its affiliates. | ||
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
|
||
from dataclasses import dataclass | ||
|
||
from ads.common.extended_enum import ExtendedEnum | ||
from ads.common.serializer import DataClassSerializable | ||
|
||
|
||
class TaskStatusEnum(ExtendedEnum): | ||
MODEL_VALIDATION_SUCCESSFUL = "MODEL_VALIDATION_SUCCESSFUL" | ||
MODEL_DOWNLOAD_STARTED = "MODEL_DOWNLOAD_STARTED" | ||
MODEL_DOWNLOAD_SUCCESSFUL = "MODEL_DOWNLOAD_SUCCESSFUL" | ||
MODEL_UPLOAD_STARTED = "MODEL_UPLOAD_STARTED" | ||
MODEL_UPLOAD_SUCCESSFUL = "MODEL_UPLOAD_SUCCESSFUL" | ||
DATASCIENCE_MODEL_CREATED = "DATASCIENCE_MODEL_CREATED" | ||
MODEL_REGISTRATION_SUCCESSFUL = "MODEL_REGISTRATION_SUCCESSFUL" | ||
REGISTRATION_FAILED = "REGISTRATION_FAILED" | ||
MODEL_DOWNLOAD_INPROGRESS = "MODEL_DOWNLOAD_INPROGRESS" | ||
|
||
|
||
@dataclass | ||
class TaskStatus(DataClassSerializable): | ||
state: TaskStatusEnum | ||
message: str |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,11 @@ | |
# Copyright (c) 2024, 2025 Oracle and/or its affiliates. | ||
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
|
||
import threading | ||
from logging import getLogger | ||
from typing import Optional | ||
from urllib.parse import urlparse | ||
from uuid import uuid4 | ||
|
||
from tornado.web import HTTPError | ||
|
||
|
@@ -12,17 +15,29 @@ | |
CustomInferenceContainerTypeFamily, | ||
) | ||
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError | ||
from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum | ||
from ads.aqua.common.utils import ( | ||
get_hf_model_info, | ||
is_valid_ocid, | ||
list_hf_models, | ||
) | ||
from ads.aqua.extension.base_handler import AquaAPIhandler | ||
from ads.aqua.extension.errors import Errors | ||
from ads.aqua.extension.status_manager import ( | ||
RegistrationStatus, | ||
StatusTracker, | ||
TaskNameEnum, | ||
) | ||
from ads.aqua.model import AquaModelApp | ||
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary | ||
from ads.aqua.model.entities import ( | ||
AquaModel, | ||
AquaModelSummary, | ||
HFModelSummary, | ||
) | ||
from ads.aqua.ui import ModelFormat | ||
|
||
logger = getLogger(__name__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think in the other places we use - |
||
|
||
|
||
class AquaModelHandler(AquaAPIhandler): | ||
"""Handler for Aqua Model REST APIs.""" | ||
|
@@ -108,6 +123,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 | |
HTTPError | ||
Raises HTTPError if inputs are missing or are invalid | ||
""" | ||
task_id = str(uuid4()) | ||
try: | ||
input_data = self.get_json_body() | ||
except Exception as ex: | ||
|
@@ -145,27 +161,64 @@ def post(self, *args, **kwargs): # noqa: ARG002 | |
str(input_data.get("ignore_model_artifact_check", "false")).lower() | ||
== "true" | ||
) | ||
async_mode = input_data.get("async_mode", False) | ||
|
||
return self.finish( | ||
AquaModelApp().register( | ||
model=model, | ||
os_path=os_path, | ||
download_from_hf=download_from_hf, | ||
local_dir=local_dir, | ||
cleanup_model_cache=cleanup_model_cache, | ||
inference_container=inference_container, | ||
finetuning_container=finetuning_container, | ||
compartment_id=compartment_id, | ||
project_id=project_id, | ||
model_file=model_file, | ||
inference_container_uri=inference_container_uri, | ||
allow_patterns=allow_patterns, | ||
ignore_patterns=ignore_patterns, | ||
freeform_tags=freeform_tags, | ||
defined_tags=defined_tags, | ||
ignore_model_artifact_check=ignore_model_artifact_check, | ||
def register_model(callback=None) -> AquaModel: | ||
"""Wrapper method to help initialize callback in case of async mode""" | ||
try: | ||
registered_model = AquaModelApp().register( | ||
model=model, | ||
os_path=os_path, | ||
download_from_hf=download_from_hf, | ||
local_dir=local_dir, | ||
cleanup_model_cache=cleanup_model_cache, | ||
inference_container=inference_container, | ||
finetuning_container=finetuning_container, | ||
compartment_id=compartment_id, | ||
project_id=project_id, | ||
model_file=model_file, | ||
inference_container_uri=inference_container_uri, | ||
allow_patterns=allow_patterns, | ||
ignore_patterns=ignore_patterns, | ||
freeform_tags=freeform_tags, | ||
defined_tags=defined_tags, | ||
ignore_model_artifact_check=ignore_model_artifact_check, | ||
callback=callback, | ||
) | ||
except Exception as e: | ||
if async_mode: | ||
StatusTracker.add_status( | ||
RegistrationStatus( | ||
task_id=task_id, | ||
task_status=TaskStatus( | ||
state=TaskStatusEnum.REGISTRATION_FAILED, message=str(e) | ||
), | ||
) | ||
) | ||
raise | ||
else: | ||
raise | ||
return registered_model | ||
|
||
if async_mode: | ||
t = threading.Thread( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be better to use a ThreadPool instead to control the number of the potential threads? Something like:
Maybe we can introduce some global ThreadPoolExecutor for this?
I think the decorator could also take care of the In this case we just mark any desired function with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mrDzurb Does threadpool allow for daemon threads? I need daemon threads here. |
||
target=register_model, | ||
args=( | ||
StatusTracker.prepare_status_callback( | ||
TaskNameEnum.REGISTRATION_STATUS, task_id=task_id | ||
), | ||
), | ||
daemon=True, | ||
) | ||
) | ||
t.start() | ||
output = { | ||
"state": "ACCEPTED", | ||
"task_id": task_id, | ||
"progress_url": f"ws://host:port/aqua/ws/{task_id}", | ||
} | ||
else: | ||
output = register_model() | ||
return self.finish(output) | ||
|
||
@handle_exceptions | ||
def put(self, id): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ class RequestResponseType(ExtendedEnum): | |
AdsVersion = "AdsVersion" | ||
CompatibilityCheck = "CompatibilityCheck" | ||
Error = "Error" | ||
RegisterModelStatus = "RegisterModelStatus" | ||
|
||
|
||
@dataclass | ||
|
@@ -141,3 +142,16 @@ class AquaWsError(DataClassSerializable): | |
class ErrorResponse(BaseResponse): | ||
data: AquaWsError | ||
kind = RequestResponseType.Error | ||
|
||
|
||
@dataclass | ||
class RequestStatus(DataClassSerializable): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: For the new code, pydantic would be better to use? |
||
status: str | ||
message: str | ||
|
||
|
||
@dataclass | ||
class ModelRegisterRequest(DataClassSerializable): | ||
status: str | ||
task_id: str | ||
message: str = "" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not pydantic?