Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Testing] Extension Compatibility Check #1043

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""AQUA utils and constants."""

Expand Down Expand Up @@ -226,24 +226,33 @@ def get_artifact_path(custom_metadata_list: List) -> str:

def read_file(file_path: str, **kwargs) -> str:
try:
with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
with fsspec.open(
file_path, "r", **kwargs.get("auth", {}), **kwargs.get("config_kwargs", {})
) as f:
return f.read()
except Exception as e:
logger.debug(f"Failed to read file {file_path}. {e}")
except Exception as ex:
logger.debug(f"Failed to read file {file_path}.\n Error:{ex}")
if kwargs.get("raise_error", False):
raise ex
return UNKNOWN


@threaded()
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
signer = default_signer() if artifact_path.startswith("oci://") else {}

signer = (
kwargs.get("auth", default_signer())
if artifact_path.startswith("oci://")
else {}
)
config = json.loads(
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR
)
if not config:
raise AquaFileNotFoundError(
f"Config file `{config_file_name}` is either empty or missing at {artifact_path}",
500,
404,
)
return config

Expand Down Expand Up @@ -600,7 +609,7 @@ def get_container_image(
return container_image


def fetch_service_compartment() -> Union[str, None]:
def fetch_service_compartment(**kwargs) -> Union[str, None]:
"""
Loads the compartment mapping json from service bucket.
This json file has a service-model-compartment key which contains a dictionary of namespaces
Expand All @@ -614,13 +623,19 @@ def fetch_service_compartment() -> Union[str, None]:
config = load_config(
file_path=config_file_name,
config_file_name=CONTAINER_INDEX,
**kwargs,
)
except Exception as e:
logger.debug(
message = (
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID "
f"could not be found. \n{str(e)}."
)
return
logger.debug(message)
if kwargs.get("raise_error", False):
raise e
else:
return UNKNOWN

compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)
if compartment_mapping:
return compartment_mapping.get(CONDA_BUCKET_NS)
Expand Down Expand Up @@ -788,7 +803,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""


def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
def upload_folder(
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
) -> str:
"""Upload the local folder to the object storage

Args:
Expand Down
4 changes: 3 additions & 1 deletion ads/aqua/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
"""This module defines constants used in ads.aqua module."""

Expand Down Expand Up @@ -46,6 +46,8 @@
SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
SUPPORTED_FILE_FORMATS = ["jsonl"]
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
AQUA_EXTENSION_LOAD_DEFAULT_TIMEOUT = 10
AQUA_EXTENSION_LOAD_MAX_ATTEMPTS = 1

CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
"datasciencemodel": "models",
Expand Down
54 changes: 47 additions & 7 deletions ads/aqua/extension/common_handler.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import concurrent.futures
import sys
import traceback
from importlib import metadata

import huggingface_hub
import oci
import requests
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from tornado.web import HTTPError

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
from ads.aqua.common.errors import (
AquaResourceAccessError,
AquaRuntimeError,
)
from ads.aqua.common.utils import (
get_huggingface_login_timeout,
known_realm,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.extension.models.ws_models import CompatibilityCheckResponseData
from ads.aqua.extension.utils import ui_compatability_check


Expand Down Expand Up @@ -50,10 +58,42 @@ def get(self):
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.

"""
if ui_compatability_check():
return self.finish({"status": "ok"})
elif known_realm():
return self.finish({"status": "compatible"})
service_compartment = None
response = None
extension_status = "compatible" if known_realm() else "incompatible"
try:
service_compartment = ui_compatability_check()
except (concurrent.futures.TimeoutError, oci.exceptions.ConnectTimeout) as ex:
response = CompatibilityCheckResponseData(
status=extension_status,
msg="If you are using custom networking in your notebook session, "
"please check if the subnet has service gateway configured.",
payload={
"status_code": 408,
"reason": f"{type(ex).__name__}: {str(ex)}",
"exc_info": "".join(traceback.format_exception(*sys.exc_info())),
},
).to_dict()
except Exception as ex:
response = CompatibilityCheckResponseData(
status=extension_status,
msg="Unable to load AI Quick Actions configuration. "
"Please check if you have set up the policies to enable the extension.",
payload={
"status_code": 404,
"reason": f"{type(ex).__name__}: {str(ex)}",
"exc_info": "".join(traceback.format_exception(*sys.exc_info())),
},
).to_dict()
if service_compartment:
response = CompatibilityCheckResponseData(
status="ok",
msg="Successfully retrieved service compartment id.",
payload={"ODSC_MODEL_COMPARTMENT_OCID": service_compartment},
).to_dict()
return self.finish(response)
elif extension_status == "compatible" and response is not None:
return self.finish(response)
else:
raise AquaResourceAccessError(
"The AI Quick actions extension is not compatible in the given region."
Expand All @@ -73,7 +113,7 @@ class HFLoginHandler(AquaAPIhandler):
"""Handler to login to HF."""

@handle_exceptions
def post(self, *args, **kwargs):
def post(self, *args, **kwargs): # noqa: ARG002
"""Handles post request for the HF login.

Raises
Expand Down
55 changes: 50 additions & 5 deletions ads/aqua/extension/common_ws_msg_handler.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import concurrent.futures
import json
import sys
import traceback
from importlib import metadata
from typing import List, Union

import oci

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError
from ads.aqua.common.utils import known_realm
from ads.aqua.extension.aqua_ws_msg_handler import AquaWSMsgHandler
from ads.aqua.extension.models.ws_models import (
AdsVersionResponse,
CompatibilityCheckResponse,
CompatibilityCheckResponseData,
RequestResponseType,
)
from ads.aqua.extension.utils import ui_compatability_check
Expand All @@ -39,17 +45,56 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
)
return response
if request.get("kind") == "CompatibilityCheck":
if ui_compatability_check():
service_compartment = None
response = None
extension_status = "compatible" if known_realm() else "incompatible"
try:
service_compartment = ui_compatability_check()
except (
concurrent.futures.TimeoutError,
oci.exceptions.ConnectTimeout,
) as ex:
response = CompatibilityCheckResponseData(
status=extension_status,
msg="If you are using custom networking in your notebook session, "
"please check if the subnet has service gateway configured.",
payload={
"status_code": 408,
"reason": f"{type(ex).__name__}: {str(ex)}",
"exc_info": "".join(
traceback.format_exception(*sys.exc_info())
),
},
).to_dict()
except Exception as ex:
response = CompatibilityCheckResponseData(
status=extension_status,
msg="Unable to load AI Quick Actions configuration. "
"Please check if you have set up the policies to enable the extension.",
payload={
"status_code": 404,
"reason": f"{type(ex).__name__}: {str(ex)}",
"exc_info": "".join(
traceback.format_exception(*sys.exc_info())
),
},
).to_dict()
if service_compartment:
response = CompatibilityCheckResponseData(
status="ok",
msg="Successfully retrieved service compartment id.",
payload={"ODSC_MODEL_COMPARTMENT_OCID": service_compartment},
).to_dict()
return CompatibilityCheckResponse(
message_id=request.get("message_id"),
kind=RequestResponseType.CompatibilityCheck,
data={"status": "ok"},
data=response,
)
elif known_realm():
elif extension_status == "compatible" and response is not None:
return CompatibilityCheckResponse(
message_id=request.get("message_id"),
kind=RequestResponseType.CompatibilityCheck,
data={"status": "compatible"},
data=response,
)
else:
raise AquaResourceAccessError(
Expand Down
19 changes: 15 additions & 4 deletions ads/aqua/extension/models/ws_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from dataclasses import dataclass
from typing import List, Optional

from ads.aqua.evaluation.entities import AquaEvaluationSummary, AquaEvaluationDetail
from ads.aqua.model.entities import AquaModelSummary, AquaModel
from pydantic import Field

from ads.aqua.config.utils.serializer import Serializable
from ads.aqua.evaluation.entities import AquaEvaluationDetail, AquaEvaluationSummary
from ads.aqua.model.entities import AquaModel, AquaModelSummary
from ads.aqua.modeldeployment.entities import AquaDeployment, AquaDeploymentDetail
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.serializer import DataClassSerializable
Expand Down Expand Up @@ -142,3 +144,12 @@ class AquaWsError(DataClassSerializable):
class ErrorResponse(BaseResponse):
data: AquaWsError
kind = RequestResponseType.Error


class CompatibilityCheckResponseData(Serializable):
status: str
payload: dict = Field(default_factory=dict)
msg: Optional[str] = None

class Config:
extra = "ignore"
48 changes: 46 additions & 2 deletions ads/aqua/extension/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
from contextlib import contextmanager
from dataclasses import fields
from datetime import datetime, timedelta
from typing import Dict, Optional

import oci
from cachetools import TTLCache, cached
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.utils import fetch_service_compartment
from ads.aqua.constants import (
AQUA_EXTENSION_LOAD_DEFAULT_TIMEOUT,
AQUA_EXTENSION_LOAD_MAX_ATTEMPTS,
)
from ads.aqua.extension.errors import Errors
from ads.config import THREADED_DEFAULT_TIMEOUT


def validate_function_parameters(data_class, input_data: Dict):
Expand All @@ -26,9 +35,44 @@ def validate_function_parameters(data_class, input_data: Dict):
)


@contextmanager
def use_temporary_envs(overrides: dict):
existing_vars: dict = {}
for key, new_value in overrides.items():
existing_vars[key] = os.getenv(key)
os.environ[key] = new_value
try:
yield
finally:
for key, old_value in existing_vars.items():
if old_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = old_value


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
def ui_compatability_check():
"""This method caches the service compartment OCID details that is set by either the environment variable or if
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
from the UI to avoid multiple config file loads."""
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()
if ODSC_MODEL_COMPARTMENT_OCID:
return ODSC_MODEL_COMPARTMENT_OCID

# set threaded default to 2x the extension load timeout value
env_overrides = {
"THREADED_DEFAULT_TIMEOUT": max(
THREADED_DEFAULT_TIMEOUT, AQUA_EXTENSION_LOAD_DEFAULT_TIMEOUT * 2
)
}
with use_temporary_envs(env_overrides):
retry_strategy = oci.retry.RetryStrategyBuilder(
max_attempts=AQUA_EXTENSION_LOAD_MAX_ATTEMPTS
)
return fetch_service_compartment(
config_kwargs={
"timeout": AQUA_EXTENSION_LOAD_DEFAULT_TIMEOUT,
"retry_strategy": retry_strategy.get_retry_strategy(),
},
raise_error=True,
)
Loading
Loading