diff --git a/ads/aqua/common/task_status.py b/ads/aqua/common/task_status.py new file mode 100644 index 000000000..6427844c5 --- /dev/null +++ b/ads/aqua/common/task_status.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from dataclasses import dataclass + +from ads.common.extended_enum import ExtendedEnum +from ads.common.serializer import DataClassSerializable + + +class TaskStatusEnum(ExtendedEnum): + MODEL_VALIDATION_SUCCESSFUL = "MODEL_VALIDATION_SUCCESSFUL" + MODEL_DOWNLOAD_STARTED = "MODEL_DOWNLOAD_STARTED" + MODEL_DOWNLOAD_SUCCESSFUL = "MODEL_DOWNLOAD_SUCCESSFUL" + MODEL_UPLOAD_STARTED = "MODEL_UPLOAD_STARTED" + MODEL_UPLOAD_SUCCESSFUL = "MODEL_UPLOAD_SUCCESSFUL" + DATASCIENCE_MODEL_CREATED = "DATASCIENCE_MODEL_CREATED" + MODEL_REGISTRATION_SUCCESSFUL = "MODEL_REGISTRATION_SUCCESSFUL" + REGISTRATION_FAILED = "REGISTRATION_FAILED" + MODEL_DOWNLOAD_INPROGRESS = "MODEL_DOWNLOAD_INPROGRESS" + + +@dataclass +class TaskStatus(DataClassSerializable): + state: TaskStatusEnum + message: str diff --git a/ads/aqua/extension/aqua_ws_msg_handler.py b/ads/aqua/extension/aqua_ws_msg_handler.py index 1fcbbf946..373fdd154 100644 --- a/ads/aqua/extension/aqua_ws_msg_handler.py +++ b/ads/aqua/extension/aqua_ws_msg_handler.py @@ -10,6 +10,7 @@ from typing import List from tornado.web import HTTPError +from tornado.websocket import WebSocketHandler from ads.aqua import logger from ads.aqua.common.decorator import handle_exceptions @@ -53,6 +54,9 @@ def process(self) -> BaseResponse: """ pass + def set_ws_connection(self, con: WebSocketHandler): + self.ws_connection = con + def write_error(self, status_code, **kwargs): """AquaWSMSGhandler errors are JSON, not human pages.""" reason = kwargs.get("reason") diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index a5b89f8d1..b14ea0d00 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -2,8 +2,11 @@ # Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import threading +from logging import getLogger from typing import Optional from urllib.parse import urlparse +from uuid import uuid4 from tornado.web import HTTPError @@ -12,6 +15,7 @@ CustomInferenceContainerTypeFamily, ) from ads.aqua.common.errors import AquaRuntimeError, AquaValueError +from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum from ads.aqua.common.utils import ( get_hf_model_info, is_valid_ocid, @@ -19,10 +23,21 @@ ) from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.errors import Errors +from ads.aqua.extension.status_manager import ( + RegistrationStatus, + StatusTracker, + TaskNameEnum, +) from ads.aqua.model import AquaModelApp -from ads.aqua.model.entities import AquaModelSummary, HFModelSummary +from ads.aqua.model.entities import ( + AquaModel, + AquaModelSummary, + HFModelSummary, +) from ads.aqua.ui import ModelFormat +logger = getLogger(__name__) + class AquaModelHandler(AquaAPIhandler): """Handler for Aqua Model REST APIs.""" @@ -108,6 +123,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 HTTPError Raises HTTPError if inputs are missing or are invalid """ + task_id = str(uuid4()) try: input_data = self.get_json_body() except Exception as ex: @@ -145,27 +161,64 @@ def post(self, *args, **kwargs): # noqa: ARG002 str(input_data.get("ignore_model_artifact_check", "false")).lower() == "true" ) + async_mode = input_data.get("async_mode", False) - return self.finish( - AquaModelApp().register( - model=model, - os_path=os_path, - download_from_hf=download_from_hf, - local_dir=local_dir, - cleanup_model_cache=cleanup_model_cache, - inference_container=inference_container, - finetuning_container=finetuning_container, - compartment_id=compartment_id, - project_id=project_id, - model_file=model_file, - inference_container_uri=inference_container_uri, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - freeform_tags=freeform_tags, - defined_tags=defined_tags, - ignore_model_artifact_check=ignore_model_artifact_check, + def register_model(callback=None) -> AquaModel: + """Wrapper method to help initialize callback in case of async mode""" + try: + registered_model = AquaModelApp().register( + model=model, + os_path=os_path, + download_from_hf=download_from_hf, + local_dir=local_dir, + cleanup_model_cache=cleanup_model_cache, + inference_container=inference_container, + finetuning_container=finetuning_container, + compartment_id=compartment_id, + project_id=project_id, + model_file=model_file, + inference_container_uri=inference_container_uri, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + freeform_tags=freeform_tags, + defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, + callback=callback, + ) + except Exception as e: + if async_mode: + StatusTracker.add_status( + RegistrationStatus( + task_id=task_id, + task_status=TaskStatus( + state=TaskStatusEnum.REGISTRATION_FAILED, message=str(e) + ), + ) + ) + raise + else: + raise + return registered_model + + if async_mode: + t = threading.Thread( + target=register_model, + args=( + StatusTracker.prepare_status_callback( + TaskNameEnum.REGISTRATION_STATUS, task_id=task_id + ), + ), + daemon=True, ) - ) + t.start() + output = { + "state": "ACCEPTED", + "task_id": task_id, + "progress_url": f"ws://host:port/aqua/ws/{task_id}", + } + else: + output = register_model() + return self.finish(output) @handle_exceptions def put(self, id): diff --git a/ads/aqua/extension/models/ws_models.py b/ads/aqua/extension/models/ws_models.py index 38432e22b..ef6420e81 100644 --- a/ads/aqua/extension/models/ws_models.py +++ b/ads/aqua/extension/models/ws_models.py @@ -23,6 +23,7 @@ class RequestResponseType(ExtendedEnum): AdsVersion = "AdsVersion" CompatibilityCheck = "CompatibilityCheck" Error = "Error" + RegisterModelStatus = "RegisterModelStatus" @dataclass @@ -141,3 +142,16 @@ class AquaWsError(DataClassSerializable): class ErrorResponse(BaseResponse): data: AquaWsError kind = RequestResponseType.Error + + +@dataclass +class RequestStatus(DataClassSerializable): + status: str + message: str + + +@dataclass +class ModelRegisterRequest(DataClassSerializable): + status: str + task_id: str + message: str = "" diff --git a/ads/aqua/extension/models_ws_msg_handler.py b/ads/aqua/extension/models_ws_msg_handler.py index 8df4a0232..e18fee2ed 100644 --- a/ads/aqua/extension/models_ws_msg_handler.py +++ b/ads/aqua/extension/models_ws_msg_handler.py @@ -1,9 +1,10 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import json +from logging import getLogger from typing import List, Union from ads.aqua.common.decorator import handle_exceptions @@ -11,18 +12,35 @@ from ads.aqua.extension.models.ws_models import ( ListModelsResponse, ModelDetailsResponse, + ModelRegisterRequest, RequestResponseType, ) +from ads.aqua.extension.status_manager import ( + RegistrationSubscriber, + StatusTracker, + TaskNameEnum, +) from ads.aqua.model import AquaModelApp +logger = getLogger(__name__) + +REGISTRATION_STATUS = "registration_status" + class AquaModelWSMsgHandler(AquaWSMsgHandler): + status_subscriber = {} + register_status = {} # Not threadsafe + def __init__(self, message: Union[str, bytes]): super().__init__(message) @staticmethod def get_message_types() -> List[RequestResponseType]: - return [RequestResponseType.ListModels, RequestResponseType.ModelDetails] + return [ + RequestResponseType.ListModels, + RequestResponseType.ModelDetails, + RequestResponseType.RegisterModelStatus, + ] @handle_exceptions def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]: @@ -47,3 +65,26 @@ def process(self) -> Union[ListModelsResponse, ModelDetailsResponse]: kind=RequestResponseType.ModelDetails, data=response, ) + elif request.get("kind") == "RegisterModelStatus": + task_id = request.get("task_id") + StatusTracker.add_subscriber( + subscriber=RegistrationSubscriber( + task_id=task_id, subscriber=self.ws_connection + ), + notify_latest_status=False, + ) + + latest_status = StatusTracker.get_latest_status( + TaskNameEnum.REGISTRATION_STATUS, task_id=task_id + ) + logger.info(latest_status) + if latest_status: + return ModelRegisterRequest( + status=latest_status.state, + message=latest_status.message, + task_id=task_id, + ) + else: + return ModelRegisterRequest( + status="SUBSCRIBED", task_id=task_id, message="" + ) diff --git a/ads/aqua/extension/status_manager.py b/ads/aqua/extension/status_manager.py new file mode 100644 index 000000000..b989745f2 --- /dev/null +++ b/ads/aqua/extension/status_manager.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import threading +from dataclasses import dataclass, field +from functools import partial +from logging import getLogger +from typing import Callable, Dict, List, Union + +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler + +from ads.aqua.common.task_status import TaskStatus +from ads.common.extended_enum import ExtendedEnum + +logger = getLogger(__name__) + + +class TaskNameEnum(ExtendedEnum): + REGISTRATION_STATUS = "REGISTRATION_STATUS" + + +@dataclass +class Task: + task_name: TaskNameEnum = None + task_id: str = None + + +@dataclass +class Status(Task): + task_status: TaskStatus = None + + +@dataclass +class Subscriber(Task): + subscriber: WebSocketHandler = None + + +@dataclass +class RegistrationStatus(Status): + task_name: TaskNameEnum = TaskNameEnum.REGISTRATION_STATUS + + +@dataclass +class RegistrationSubscriber(Subscriber): + task_name: TaskNameEnum = TaskNameEnum.REGISTRATION_STATUS + + +@dataclass +class StatusSubscription: + task_status_list: List[TaskStatus] = field(default_factory=list) + subscribers: List[Subscriber] = field(default_factory=list) + + +class StatusTracker: + lock = threading.RLock() + """ + Maintains a mapping of task statuses and subscribers for notifications. + Example: + { + "REGISTRATION_STATUS": { + "sample-task-id": StatusSubscription( + task_status_list=[TaskStatus(state="MODEL_DOWNLOAD_INPROGRESS", message="1 out of 10 files downloaded")], + subscribers=[Subscriber(subscriber=websocket123)] + ) + } + } + """ + status: Dict[TaskNameEnum, Dict[str, StatusSubscription]] = {} + + @staticmethod + def get_latest_status( + task_name: TaskNameEnum, task_id: str + ) -> Union[TaskStatus, None]: + """Returns latest task status if availble, else returns None""" + task_list = [] + logger.info(f"Status dump: {StatusTracker.status}") + with StatusTracker.lock: + task_list = ( + StatusTracker.status.get(task_name, {}) + .get(task_id, StatusSubscription()) + .task_status_list + ) + return task_list[-1] if task_list else None + + @staticmethod + def get_statuses(task_name: TaskNameEnum, task_id: str) -> Union[TaskStatus, None]: + """Returns latest task status if availble, else returns None""" + with StatusTracker.lock: + return ( + StatusTracker.status.get(task_name, {}) + .get(task_id, StatusSubscription()) + .task_status_list + ) + + @staticmethod + def add_status(status: Status, notify=True): + """Appends to the status list. Notifies the status to all the subcribers""" + logger.info(f"status: {status}") + with StatusTracker.lock: + if status.task_name not in StatusTracker.status: + StatusTracker.status[status.task_name] = { + status.task_id: StatusSubscription( + task_status_list=[status.task_status] + ) + } + elif status.task_id in StatusTracker.status[status.task_name]: + StatusTracker.status[status.task_name][ + status.task_id + ].task_status_list.append(status.task_status) + else: + StatusTracker.status[status.task_name][status.task_id] = ( + StatusSubscription(task_status_list=[status.task_status]) + ) + # Since there is a task id, Notify subscribers if any + if notify: + StatusTracker.notify_latest_to_all( + task_name=status.task_name, task_id=status.task_id + ) + + @staticmethod + def notify_latest_to_all(task_name: TaskNameEnum, task_id: str): + """Notify the latest task status to all the subscribers""" + task_status = StatusTracker.get_latest_status( + task_name=task_name, task_id=task_id + ) + logger.info(f"status: {task_status}") + subscribers = [] + with StatusTracker.lock: + subscribers = ( + StatusTracker.status.get(task_name, {}) + .get(task_id, StatusSubscription()) + .subscribers + ) + for subscriber in subscribers: + StatusTracker.send_message(status=task_status, subscriber=subscriber) + + @staticmethod + def notify(task_name: TaskNameEnum, subscriber: Subscriber, latest_only=True): + """Notify the subscriber of all the status""" + if latest_only: + task_status = StatusTracker.get_latest_status( + task_name=task_name, task_id=subscriber.task_id + ) + else: + task_status = StatusTracker.get_statuses( + task_name=task_name, task_id=subscriber.task_id + ) + logger.info(task_status) + StatusTracker.send_message(status=task_status, subscriber=subscriber) + + @staticmethod + def send_message(status: TaskStatus, subscriber: Subscriber): + if ( + subscriber + and subscriber.ws_connection + and subscriber.ws_connection.stream.socket + ): + try: + subscriber.write_message(status.to_json()) + except Exception as e: + print(e) + IOLoop.current().add_callback( + lambda: subscriber.write_message(status.to_json()) + ) + + @staticmethod + def add_subscriber(subscriber: Subscriber, notify_latest_status=True): + """Appends to the subscriber list""" + with StatusTracker.lock: + if subscriber.task_name not in StatusTracker.status: + StatusTracker.status[subscriber.task_name] = { + subscriber.task_id: StatusSubscription( + subscribers=[subscriber.subscriber] + ) + } + elif subscriber.task_id in StatusTracker.status[subscriber.task_name]: + StatusTracker.status[subscriber.task_name][ + subscriber.task_id + ].subscribers.append(subscriber.subscriber) + else: + StatusTracker.status[subscriber.task_name][subscriber.task_id] = ( + StatusSubscription(subscribers=[subscriber.subscriber]) + ) + if notify_latest_status: + StatusTracker.notify( + task_name=subscriber.task_name, task_id=subscriber.task_id + ) + + @staticmethod + def prepare_status_callback( + task_name: TaskNameEnum, task_id: str + ) -> Callable[[TaskStatus], None]: + def callback(task_name: TaskNameEnum, task_id: str, status: TaskStatus): + st = Status(task_name=task_name, task_id=task_id, task_status=status) + StatusTracker.add_status(st) + + return partial(callback, task_name=task_name, task_id=task_id) diff --git a/ads/aqua/extension/ui_websocket_handler.py b/ads/aqua/extension/ui_websocket_handler.py index 77dfc301d..6b95bf866 100644 --- a/ads/aqua/extension/ui_websocket_handler.py +++ b/ads/aqua/extension/ui_websocket_handler.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import concurrent.futures from asyncio.futures import Future @@ -46,10 +45,12 @@ def get_aqua_internal_error_response(message_id: str) -> ErrorResponse: class AquaUIWebSocketHandler(WebSocketHandler): """Handler for Aqua Websocket.""" - _handlers_: List[Type[AquaWSMsgHandler]] = [AquaEvaluationWSMsgHandler, - AquaDeploymentWSMsgHandler, - AquaModelWSMsgHandler, - AquaCommonWsMsgHandler] + _handlers_: List[Type[AquaWSMsgHandler]] = [ + AquaEvaluationWSMsgHandler, + AquaDeploymentWSMsgHandler, + AquaModelWSMsgHandler, + AquaCommonWsMsgHandler, + ] thread_pool: ThreadPoolExecutor @@ -98,10 +99,17 @@ def on_message(self, message: Union[str, bytes]): raise ValueError(f"No handler found for message type {request.kind}") else: message_handler = handler(message) + message_handler.set_ws_connection(self) future: Future = self.thread_pool.submit(message_handler.process) self.future_message_map[future] = request future.add_done_callback(self.on_message_processed) + def on_close(self) -> None: + self.thread_pool.shutdown() + logger.info("AQUA WebSocket closed") + + +class AquaAsyncRequestProgressWebSocketHandler(AquaUIWebSocketHandler): def on_message_processed(self, future: concurrent.futures.Future): """Callback function to handle the response from the various AquaWSMsgHandlers.""" try: @@ -120,11 +128,12 @@ def on_message_processed(self, future: concurrent.futures.Future): finally: self.future_message_map.pop(future) # Send the response back to the client on the event thread - IOLoop.current().run_sync(lambda: self.write_message(response.to_json())) - - def on_close(self) -> None: - self.thread_pool.shutdown() - logger.info("AQUA WebSocket closed") + IOLoop.current().add_callback( + lambda: self.write_message(response.to_json()) + ) -__handlers__ = [("ws?([^/]*)", AquaUIWebSocketHandler)] +__handlers__ = [ + ("ws?([^/]*)", AquaUIWebSocketHandler), + ("ws?/progress([^/]*)", AquaAsyncRequestProgressWebSocketHandler), +] diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index 991c67b54..5528c85bf 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -11,7 +11,7 @@ import re from dataclasses import InitVar, dataclass, field -from typing import List, Optional +from typing import Callable, List, Optional import oci from huggingface_hub import hf_api @@ -295,6 +295,7 @@ class ImportModelDetails(CLIBuilderMixin): freeform_tags: Optional[dict] = None defined_tags: Optional[dict] = None ignore_model_artifact_check: Optional[bool] = None + callback: Optional[Callable] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index fff23578f..d1a6c3370 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -5,7 +5,7 @@ import pathlib from datetime import datetime, timedelta from threading import Lock -from typing import Dict, List, Optional, Set, Union +from typing import Callable, Dict, List, Optional, Set, Union import oci from cachetools import TTLCache @@ -26,6 +26,7 @@ AquaRuntimeError, AquaValueError, ) +from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum from ads.aqua.common.utils import ( LifecycleStatus, _build_resource_identifier, @@ -76,6 +77,7 @@ ModelFormat, ModelValidationResult, ) +from ads.aqua.model.utils import prepare_progress_tracker_with_callback from ads.aqua.ui import AquaContainerConfig, AquaContainerConfigItem from ads.common.auth import default_signer from ads.common.oci_resource import SEARCH_TYPE, OCIResource @@ -1403,6 +1405,7 @@ def _download_model_from_hf( local_dir: str = None, allow_patterns: List[str] = None, ignore_patterns: List[str] = None, + callback: Callable = None, ) -> str: """This helper function downloads the model artifact from Hugging Face to a local folder, then uploads to object storage location. @@ -1428,24 +1431,54 @@ def _download_model_from_hf( local_dir = os.path.join(local_dir, model_name) os.makedirs(local_dir, exist_ok=True) + def tqdm_callback(status: TaskStatus): # noqa: ARG001 + callback(status) + + def publish_status(status: TaskStatus): + """wrapper to avoid repeated null check""" + if callback: + callback(status) + # if local_dir is not set, the return value points to the cached data folder + tqdm = None + if callback: + tqdm = prepare_progress_tracker_with_callback(tqdm_callback) local_dir = snapshot_download( repo_id=model_name, local_dir=local_dir, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, + tqdm_class=tqdm, + ) + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_DOWNLOAD_SUCCESSFUL, + message="Model download complete", + ) ) # Upload to object storage and skip .cache/huggingface/ folder logger.debug( f"Uploading local artifacts from local directory {local_dir} to {os_path}." ) # Upload to object storage + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_UPLOAD_STARTED, + message=f"Uploading model to Object Storage: {os_path}", + ) + ) model_artifact_path = upload_folder( os_path=os_path, local_dir=local_dir, model_name=model_name, exclude_pattern=f"{HF_METADATA_FOLDER}*", ) + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_UPLOAD_SUCCESSFUL, + message=f"Model uploaded successfully to {os_path}", + ) + ) return model_artifact_path @@ -1479,6 +1512,16 @@ def register( if not import_model_details: import_model_details = ImportModelDetails(**kwargs) + def publish_status(status: TaskStatus): + """Invoke callback with the status""" + logger.info( + f"Publishing status using callback: {import_model_details.callback}" + ) + if import_model_details.callback: + import_model_details.callback(status=status) + else: + logger.info("No callback registered") + # If OCID of a model is passed, we need to copy the defaults for Tags and metadata from the service model. verified_model: Optional[DataScienceModel] = None if ( @@ -1497,6 +1540,7 @@ def register( f"Found service model for {import_model_details.model}: {model_service_id}" ) verified_model = DataScienceModel.from_id(model_service_id) + logger.info("fetched model from service catalog") # Copy the model name from the service model if `model` is ocid model_name = ( @@ -1511,15 +1555,28 @@ def register( model_name=model_name, verified_model=verified_model, ) + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_VALIDATION_SUCCESSFUL, + message="Model information validated", + ) + ) # download model from hugginface if indicates if import_model_details.download_from_hf: + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_DOWNLOAD_STARTED, + message=f"Downloading {model_name} from Hugging Face", + ) + ) artifact_path = self._download_model_from_hf( model_name=model_name, os_path=import_model_details.os_path, local_dir=import_model_details.local_dir, allow_patterns=import_model_details.allow_patterns, ignore_patterns=import_model_details.ignore_patterns, + callback=publish_status, ).rstrip("/") else: artifact_path = import_model_details.os_path.rstrip("/") @@ -1537,6 +1594,12 @@ def register( freeform_tags=import_model_details.freeform_tags, defined_tags=import_model_details.defined_tags, ) + publish_status( + TaskStatus( + TaskStatusEnum.DATASCIENCE_MODEL_CREATED, + message=f"DataScience model created. Model id is: {ds_model.id}", + ) + ) # registered model will always have inference and evaluation container, but # fine-tuning container may be not set inference_container = ds_model.custom_metadata_list.get( @@ -1587,7 +1650,12 @@ def register( cleanup_local_hf_model_artifact( model_name=model_name, local_dir=import_model_details.local_dir ) - + publish_status( + TaskStatus( + state=TaskStatusEnum.MODEL_REGISTRATION_SUCCESSFUL, + message=f"Model {model_name} successfully registered. Model id is: {ds_model.id}", + ) + ) return AquaModel(**aqua_model_attributes) def _if_show(self, model: DataScienceModel) -> bool: @@ -1688,6 +1756,7 @@ def _find_matching_aqua_model(self, model_id: str) -> Optional[str]: aqua_model_list = self.list() for aqua_model_summary in aqua_model_list: + print(aqua_model_summary.name.lower()) if aqua_model_summary.name.lower() == model_id_lower: logger.info( f"Found matching verified model id {aqua_model_summary.id} for the model {model_id}" diff --git a/ads/aqua/model/utils.py b/ads/aqua/model/utils.py new file mode 100644 index 000000000..af7564d32 --- /dev/null +++ b/ads/aqua/model/utils.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from typing import Callable, List, Union + +from tqdm import tqdm + +from ads.aqua.common.task_status import TaskStatus, TaskStatusEnum + + +class HFModelProgressTracker(tqdm): + """snapshot_download method from huggingface_hub library is used to download the models. This class provides a way to register for callbacks as the downloads of different files are complete.""" + + hooks = [] + + def __init__(self, *args, **kwargs): + """ + A custom tqdm class that calls `callback` each time progress is updated. + + """ + super().__init__(*args, **kwargs) + + def update(self, n=1): + # Perform the standard progress update + super().update(n) + # Invoke the callback with the current progress value (self.n) + for hook in self.hooks: + hook( + TaskStatus( + state=TaskStatusEnum.MODEL_DOWNLOAD_INPROGRESS, + message=f"{self.n} of {self.total} files downloaded", + ) + ) + + def close(self): + for hook in self.hooks: + hook( + TaskStatus( + state=TaskStatusEnum.MODEL_DOWNLOAD_INPROGRESS, + message=f"{self.n} of {self.total} files downloaded", + ) + ) + super().close() + + +def prepare_progress_tracker_with_callback( + hook: Union[Callable, List[Callable]], +) -> "HFModelProgressTrackerWithHook": # type: ignore # noqa: F821 + """Provide a list of callables or single callable to be invoked upon download progress. snapshot_download only allows to pass in class, does not allow for tqdm_kwargs supported by thread_map. + This class provides a thread safe way to use hooks""" + + class HFModelProgressTrackerWithHook(HFModelProgressTracker): + hooks = hook if isinstance(hook, list) else [hook] + + return HFModelProgressTrackerWithHook diff --git a/ads/aqua/server/app.py b/ads/aqua/server/app.py index 8d2a0eab6..61db24a48 100644 --- a/ads/aqua/server/app.py +++ b/ads/aqua/server/app.py @@ -4,7 +4,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os -from logging import getLogger +from logging import DEBUG, getLogger import tornado.ioloop import tornado.web @@ -35,7 +35,20 @@ def make_app(): def start_server(): + access_log = getLogger("tornado.access") + # Set the logging level to DEBUG + access_log.setLevel(DEBUG) app = make_app() + logger.info("Endpoints:") + for rule in app.wildcard_router.rules: + # Depending on the rule type, the route may be stored in different properties. + # If the rule has a regex matcher, you can get its pattern. + regex = ( + rule.matcher.regex.pattern + if hasattr(rule.matcher, "regex") + else str(rule.matcher) + ) + print(f"\t\t{regex}") server = tornado.httpserver.HTTPServer(app) port = int(os.environ.get(AQUA_PORT, 8080)) host = os.environ.get(AQUA_HOST, "0.0.0.0") diff --git a/pyproject.toml b/pyproject.toml index 15fad7980..49686bb8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,6 +205,8 @@ pii = [ ] llm = ["langchain>=0.2", "langchain-community", "langchain_openai", "pydantic>=2,<3", "evaluate>=0.4.0"] aqua = [ + "oci-cli", + "jupyter_server", "tornado", "notebook~=6.5", "fire", diff --git a/tests/unitary/with_extras/aqua/test_tqdm.py b/tests/unitary/with_extras/aqua/test_tqdm.py new file mode 100644 index 000000000..95a29bb99 --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_tqdm.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from unittest.mock import MagicMock + +from tqdm.contrib.concurrent import thread_map + +from ads.aqua.model.utils import prepare_progress_tracker_with_callback + + +def test_custom_tqdm_thread_map(): + def process(item): + import time + + time.sleep(0.1) + return item + + items = list(range(0, 10)) + callback = MagicMock() + + clz = prepare_progress_tracker_with_callback(callback) + thread_map( + process, + items, + desc=f"Fetching {len(items)} items", + tqdm_class=clz, + max_workers=3, + ) + callback.assert_called() + + +def test_custom_tqdm(): + callback = MagicMock() + clz = prepare_progress_tracker_with_callback(callback) + with clz(range(10), desc="Processing") as bar: + for _ in bar: + # Simulate work + import time + + time.sleep(0.01) + callback.assert_called() + + +if __name__ == "__main__": + test_custom_tqdm_thread_map()