Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async register API support #1083

Draft
wants to merge 18 commits into
base: aqua_apiserver
Choose a base branch
from
Draft
27 changes: 27 additions & 0 deletions ads/aqua/common/task_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python

# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from dataclasses import dataclass

from ads.common.extended_enum import ExtendedEnum
from ads.common.serializer import DataClassSerializable


class TaskStatusEnum(ExtendedEnum):
    """States that a task status record can carry.

    Every member's value is the same string as its name.

    NOTE(review): judging by the member names, these follow the stages of a
    model registration (validation, download, upload, model creation, final
    outcome) -- confirm against the code that publishes these states.
    """

    MODEL_VALIDATION_SUCCESSFUL = "MODEL_VALIDATION_SUCCESSFUL"
    MODEL_DOWNLOAD_STARTED = "MODEL_DOWNLOAD_STARTED"
    MODEL_DOWNLOAD_SUCCESSFUL = "MODEL_DOWNLOAD_SUCCESSFUL"
    MODEL_UPLOAD_STARTED = "MODEL_UPLOAD_STARTED"
    MODEL_UPLOAD_SUCCESSFUL = "MODEL_UPLOAD_SUCCESSFUL"
    DATASCIENCE_MODEL_CREATED = "DATASCIENCE_MODEL_CREATED"
    MODEL_REGISTRATION_SUCCESSFUL = "MODEL_REGISTRATION_SUCCESSFUL"
    REGISTRATION_FAILED = "REGISTRATION_FAILED"
    MODEL_DOWNLOAD_INPROGRESS = "MODEL_DOWNLOAD_INPROGRESS"


@dataclass
class TaskStatus(DataClassSerializable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not pydantic?

state: TaskStatusEnum
message: str
4 changes: 4 additions & 0 deletions ads/aqua/extension/aqua_ws_msg_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List

from tornado.web import HTTPError
from tornado.websocket import WebSocketHandler

from ads.aqua import logger
from ads.aqua.common.decorator import handle_exceptions
Expand Down Expand Up @@ -53,6 +54,9 @@ def process(self) -> BaseResponse:
"""
pass

def set_ws_connection(self, con: WebSocketHandler):
    """Store the websocket connection on this handler as ``ws_connection``.

    Parameters
    ----------
    con: WebSocketHandler
        The connection object to keep on the handler.
    """
    self.ws_connection = con

def write_error(self, status_code, **kwargs):
"""AquaWSMSGhandler errors are JSON, not human pages."""
reason = kwargs.get("reason")
Expand Down
93 changes: 73 additions & 20 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import threading
from logging import getLogger
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4

from tornado.web import HTTPError

Expand All @@ -12,17 +15,29 @@
CustomInferenceContainerTypeFamily,
)
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum
from ads.aqua.common.utils import (
get_hf_model_info,
is_valid_ocid,
list_hf_models,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.extension.status_manager import (
RegistrationStatus,
StatusTracker,
TaskNameEnum,
)
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.aqua.model.entities import (
AquaModel,
AquaModelSummary,
HFModelSummary,
)
from ads.aqua.ui import ModelFormat

logger = getLogger(__name__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think everywhere else we use `from ads.aqua import logger` instead of creating a module-level logger with `getLogger(__name__)`.



class AquaModelHandler(AquaAPIhandler):
"""Handler for Aqua Model REST APIs."""
Expand Down Expand Up @@ -108,6 +123,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
HTTPError
Raises HTTPError if inputs are missing or are invalid
"""
task_id = str(uuid4())
try:
input_data = self.get_json_body()
except Exception as ex:
Expand Down Expand Up @@ -145,27 +161,64 @@ def post(self, *args, **kwargs): # noqa: ARG002
str(input_data.get("ignore_model_artifact_check", "false")).lower()
== "true"
)
async_mode = input_data.get("async_mode", False)

return self.finish(
AquaModelApp().register(
model=model,
os_path=os_path,
download_from_hf=download_from_hf,
local_dir=local_dir,
cleanup_model_cache=cleanup_model_cache,
inference_container=inference_container,
finetuning_container=finetuning_container,
compartment_id=compartment_id,
project_id=project_id,
model_file=model_file,
inference_container_uri=inference_container_uri,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
ignore_model_artifact_check=ignore_model_artifact_check,
def register_model(callback=None) -> AquaModel:
    """Run the model registration with the inputs parsed from the request.

    Wrapper that lets the caller inject a status callback when the
    registration runs in async mode.

    Parameters
    ----------
    callback: optional
        Forwarded unchanged to ``AquaModelApp().register`` as ``callback``.
        Defaults to None.

    Returns
    -------
    AquaModel
        The value returned by ``AquaModelApp().register``.

    Raises
    ------
    Exception
        Any exception raised by the registration is re-raised unchanged.
        In async mode a ``REGISTRATION_FAILED`` status carrying ``str(e)``
        is first recorded in ``StatusTracker`` under the current task id.
    """
    try:
        return AquaModelApp().register(
            model=model,
            os_path=os_path,
            download_from_hf=download_from_hf,
            local_dir=local_dir,
            cleanup_model_cache=cleanup_model_cache,
            inference_container=inference_container,
            finetuning_container=finetuning_container,
            compartment_id=compartment_id,
            project_id=project_id,
            model_file=model_file,
            inference_container_uri=inference_container_uri,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            freeform_tags=freeform_tags,
            defined_tags=defined_tags,
            ignore_model_artifact_check=ignore_model_artifact_check,
            callback=callback,
        )
    except Exception as e:
        # Only async callers poll the tracker, so the failure is recorded
        # there just for them; the exception itself propagates in both modes.
        if async_mode:
            StatusTracker.add_status(
                RegistrationStatus(
                    task_id=task_id,
                    task_status=TaskStatus(
                        state=TaskStatusEnum.REGISTRATION_FAILED, message=str(e)
                    ),
                )
            )
        raise

if async_mode:
t = threading.Thread(
Copy link
Member

@mrDzurb mrDzurb Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to use a thread pool here, so that we can cap the number of threads that may be spawned?

Something like:

THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=10)
if async_mode:
            # Submit the registration task to a thread pool.
            THREAD_POOL_EXECUTOR.submit(self._register_model, task_id, input_data, async_mode)
            output = {
                "state": "ACCEPTED",
                "task_id": task_id,
                "progress_url": f"ws://host:port/aqua/ws/{task_id}",
            }
        else:
            output = self._register_model(task_id, input_data, async_mode)

Maybe we can introduce some global ThreadPoolExecutor for this?
I'm wondering if we can use a decorator for this, similar to what we already do with the @threaded decorator.

# Shared pool that bounds how many decorated calls run in the background at once.
THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=10)


def run_in_thread_if_async(func):
    """Decorator to run the function in a thread if async_mode is True.

    The decorated method must accept ``task_id`` as its first argument after
    ``self``. Callers invoke the wrapped method as
    ``method(async_mode, *args, **kwargs)``.

    When ``async_mode`` is truthy, a fresh task id is generated, the call is
    submitted to ``THREAD_POOL_EXECUTOR`` and an "ACCEPTED" payload with the
    task id and progress URL is returned immediately. The result of the
    background call (including any exception it raises) is not returned to
    the caller.

    When ``async_mode`` is falsy, the function runs inline with
    ``task_id=None`` and its return value is passed through.
    """

    @wraps(func)
    def wrapper(self, async_mode, *args, **kwargs):
        if not async_mode:
            return func(self, None, *args, **kwargs)

        task_id = str(uuid4())
        # Fire and forget: the returned future is intentionally not kept.
        THREAD_POOL_EXECUTOR.submit(func, self, task_id, *args, **kwargs)
        return {
            "state": "ACCEPTED",
            "task_id": task_id,
            "progress_url": f"ws://host:port/aqua/ws/{task_id}",
        }

    return wrapper

I think the decorator could also take care of the StatusTracker.

In this case we just mark any desired function with the @run_in_thread_if_async decorator which will do all the related work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mrDzurb Does threadpool allow for daemon threads? I need daemon threads here.

target=register_model,
args=(
StatusTracker.prepare_status_callback(
TaskNameEnum.REGISTRATION_STATUS, task_id=task_id
),
),
daemon=True,
)
)
t.start()
output = {
"state": "ACCEPTED",
"task_id": task_id,
"progress_url": f"ws://host:port/aqua/ws/{task_id}",
}
else:
output = register_model()
return self.finish(output)

@handle_exceptions
def put(self, id):
Expand Down
14 changes: 14 additions & 0 deletions ads/aqua/extension/models/ws_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RequestResponseType(ExtendedEnum):
AdsVersion = "AdsVersion"
CompatibilityCheck = "CompatibilityCheck"
Error = "Error"
RegisterModelStatus = "RegisterModelStatus"


@dataclass
Expand Down Expand Up @@ -141,3 +142,16 @@ class AquaWsError(DataClassSerializable):
class ErrorResponse(BaseResponse):
data: AquaWsError
kind = RequestResponseType.Error


@dataclass
class RequestStatus(DataClassSerializable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: For new code, wouldn't pydantic be a better choice than dataclasses?

status: str
message: str


@dataclass
class ModelRegisterRequest(DataClassSerializable):
    """Serializable payload describing the status of a model registration task."""

    # Status string for the task.
    # NOTE(review): presumably a task state name or "SUBSCRIBED" -- confirm with callers.
    status: str
    # Identifier of the registration task this payload refers to.
    task_id: str
    # Optional free-text detail accompanying the status; empty by default.
    message: str = ""
45 changes: 43 additions & 2 deletions ads/aqua/extension/models_ws_msg_handler.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
from logging import getLogger
from typing import List, Union

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.models.ws_models import (
ListModelsResponse,
ModelDetailsResponse,
ModelRegisterRequest,
RequestResponseType,
)
from ads.aqua.extension.status_manager import (
RegistrationSubscriber,
StatusTracker,
TaskNameEnum,
)
from ads.aqua.model import AquaModelApp

logger = getLogger(__name__)

REGISTRATION_STATUS = "registration_status"


class AquaModelWSMsgHandler(AquaWSMsgHandler):
status_subscriber = {}
register_status = {} # Not threadsafe

def __init__(self, message: Union[str, bytes]):
    """Initialize the handler by passing ``message`` to the parent initializer.

    Parameters
    ----------
    message: Union[str, bytes]
        The incoming message, forwarded unchanged to ``super().__init__``.
    """
    super().__init__(message)

@staticmethod
def get_message_types() -> List[RequestResponseType]:
    """Return the message types handled by this handler.

    Returns
    -------
    List[RequestResponseType]
        ``ListModels``, ``ModelDetails`` and ``RegisterModelStatus``.
    """
    # A stale two-item return previously sat above this one and made the
    # RegisterModelStatus entry unreachable; only the full list is kept.
    return [
        RequestResponseType.ListModels,
        RequestResponseType.ModelDetails,
        RequestResponseType.RegisterModelStatus,
    ]

@handle_exceptions
def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
Expand All @@ -47,3 +65,26 @@ def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]:
kind=RequestResponseType.ModelDetails,
data=response,
)
elif request.get("kind") == "RegisterModelStatus":
task_id = request.get("task_id")
StatusTracker.add_subscriber(
subscriber=RegistrationSubscriber(
task_id=task_id, subscriber=self.ws_connection
),
notify_latest_status=False,
)

latest_status = StatusTracker.get_latest_status(
TaskNameEnum.REGISTRATION_STATUS, task_id=task_id
)
logger.info(latest_status)
if latest_status:
return ModelRegisterRequest(
status=latest_status.state,
message=latest_status.message,
task_id=task_id,
)
else:
return ModelRegisterRequest(
status="SUBSCRIBED", task_id=task_id, message=""
)
Loading