diff --git a/clarifai/runners/models/model_upload.py b/clarifai/runners/models/model_upload.py index 0f6bcb39..eba43000 100644 --- a/clarifai/runners/models/model_upload.py +++ b/clarifai/runners/models/model_upload.py @@ -78,8 +78,7 @@ def _validate_config_checkpoints(self): assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file" repo_id = self.config.get("checkpoints").get("repo_id") - # get from config.yaml otherwise fall back to HF_TOKEN env var. - hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None)) + hf_token = self.config.get("checkpoints").get("hf_token", None) return repo_id, hf_token def _check_app_exists(self): diff --git a/clarifai/runners/utils/loader.py b/clarifai/runners/utils/loader.py index 9795b3d5..1fc63aef 100644 --- a/clarifai/runners/utils/loader.py +++ b/clarifai/runners/utils/loader.py @@ -1,8 +1,6 @@ -import fnmatch import importlib.util import json import os -import shutil import subprocess from clarifai.utils.logging import logger @@ -41,7 +39,7 @@ def validate_hftoken(cls, hf_token: str): def download_checkpoints(self, checkpoint_path: str): # throw error if huggingface_hub wasn't installed try: - from huggingface_hub import snapshot_download + from huggingface_hub import list_repo_files, snapshot_download except ImportError: raise ImportError(self.HF_DOWNLOAD_TEXT) if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path): @@ -55,17 +53,16 @@ def download_checkpoints(self, checkpoint_path: str): logger.error("Model %s not found on Hugging Face" % (self.repo_id)) return False - self.ignore_patterns = self._get_ignore_patterns() + ignore_patterns = None # Download everything. + repo_files = list_repo_files(repo_id=self.repo_id, token=self.token) + if any(f.endswith(".safetensors") for f in repo_files): + logger.info(f"SafeTensors found in {self.repo_id}, downloading only .safetensors files.") + ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin"] snapshot_download( repo_id=self.repo_id, local_dir=checkpoint_path, local_dir_use_symlinks=False, - ignore_patterns=self.ignore_patterns) - # Remove the `.cache` folder if it exists - cache_path = os.path.join(checkpoint_path, ".cache") - if os.path.exists(cache_path) and os.path.isdir(cache_path): - shutil.rmtree(cache_path) - + ignore_patterns=ignore_patterns) except Exception as e: logger.error(f"Error downloading model checkpoints {e}") return False @@ -112,41 +109,11 @@ def validate_download(self, checkpoint_path: str): from huggingface_hub import list_repo_files except ImportError: raise ImportError(self.HF_DOWNLOAD_TEXT) - # Get the list of files on the repo - repo_files = list_repo_files(self.repo_id, token=self.token) - - self.ignore_patterns = self._get_ignore_patterns() - # Get the list of files on the repo that are not ignored - if getattr(self, "ignore_patterns", None): - patterns = self.ignore_patterns - - def should_ignore(file_path): - return any(fnmatch.fnmatch(file_path, pattern) for pattern in patterns) - - repo_files = [f for f in repo_files if not should_ignore(f)] - - # Check if downloaded files match the files we expect (ignoring ignored patterns) checkpoint_dir_files = [ f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn ] - - # Validate by comparing file lists - return len(checkpoint_dir_files) >= len(repo_files) and not ( - len(set(repo_files) - set(checkpoint_dir_files)) > 0) and len(repo_files) > 0 - - def _get_ignore_patterns(self): - # check if model exists on HF - try: - from huggingface_hub import list_repo_files - except ImportError: - raise ImportError(self.HF_DOWNLOAD_TEXT) - - # Get the list of files on the repo that are not ignored - repo_files = list_repo_files(self.repo_id, token=self.token) - self.ignore_patterns = None - if any(f.endswith(".safetensors") for f in repo_files): - self.ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin", "*.pth", "*.bin"] - return self.ignore_patterns + return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len( + list_repo_files(self.repo_id)) > 0 @staticmethod def validate_config(checkpoint_path: str): diff --git a/tests/requirements.txt b/tests/requirements.txt index 0944c15c..199337a4 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,7 +2,6 @@ pytest==7.1.2 pytest-cov==5.0.0 pytest-xdist==2.5.0 llama-index-core==0.11.17 -huggingface_hub[hf_transfer]==0.27.1 pypdf==3.17.4 seaborn==0.13.2 pycocotools==2.0.6 diff --git a/tests/runners/dummy_runner_models/config.yaml b/tests/runners/dummy_runner_models/config.yaml deleted file mode 100644 index 93358673..00000000 --- a/tests/runners/dummy_runner_models/config.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# This is the sample config file for the GOT OCR2.O model. - -model: - id: "dummy-runner-model" - user_id: "user_id" - app_id: "app_id" - model_type_id: "multimodal-to-text" - -build_info: - python_version: "3.11" - -inference_compute_info: - cpu_limit: "1" - cpu_memory: "1Gi" - num_accelerators: 0 - - -checkpoints: - type: "huggingface" - repo_id: "timm/mobilenetv3_small_100.lamb_in1k" diff --git a/tests/runners/dummy_runner_models/1/model.py b/tests/runners/dummy_runner_models/model.py similarity index 100% rename from tests/runners/dummy_runner_models/1/model.py rename to tests/runners/dummy_runner_models/model.py diff --git a/tests/runners/dummy_runner_models/1/model_wrapper.py b/tests/runners/dummy_runner_models/model_wrapper.py similarity index 100% rename from tests/runners/dummy_runner_models/1/model_wrapper.py rename to tests/runners/dummy_runner_models/model_wrapper.py diff --git a/tests/runners/dummy_runner_models/requirements.txt b/tests/runners/dummy_runner_models/requirements.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/runners/test_download_checkpoints.py b/tests/runners/test_download_checkpoints.py deleted file mode 100644 index 77d7f720..00000000 --- a/tests/runners/test_download_checkpoints.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import shutil -import tempfile - -import pytest - -from clarifai.runners.models.model_upload import ModelUploader -from clarifai.runners.utils.loader import HuggingFaceLoader - -MODEL_ID = "timm/mobilenetv3_small_100.lamb_in1k" - - -@pytest.fixture(scope="module") -def checkpoint_dir(): - # Create a temporary directory for the test checkpoints - temp_dir = os.path.join(tempfile.gettempdir(), MODEL_ID[5:]) - if not os.path.exists(temp_dir): - os.makedirs(temp_dir) - yield temp_dir # Provide the directory to the tests - # Cleanup: remove the directory after all tests are complete - shutil.rmtree(temp_dir, ignore_errors=True) - - -# Pytest fixture to delete the checkpoints in dummy runner models folder after tests complete -@pytest.fixture(scope="function") -def dummy_runner_models_dir(): - model_folder_path = os.path.join(os.path.dirname(__file__), "dummy_runner_models") - checkpoints_path = os.path.join(model_folder_path, "1", "checkpoints") - yield checkpoints_path - # Cleanup the checkpoints folder after the test - if os.path.exists(checkpoints_path): - shutil.rmtree(checkpoints_path) - - -@pytest.fixture(scope="function", autouse=True) -def override_environment_variables(): - # Backup the existing environment variable value - original_clarifai_pat = os.environ.get("CLARIFAI_PAT") - if "CLARIFAI_PAT" in os.environ: - del os.environ["CLARIFAI_PAT"] # Temporarily unset the variable for the tests - yield - # Restore the original environment variable value after tests - if original_clarifai_pat: - os.environ["CLARIFAI_PAT"] = original_clarifai_pat - - -def test_loader_download_checkpoints(checkpoint_dir): - loader = HuggingFaceLoader(repo_id=MODEL_ID) - loader.download_checkpoints(checkpoint_path=checkpoint_dir) - assert len(os.listdir(checkpoint_dir)) == 4 - - -def test_validate_download(checkpoint_dir): - loader = HuggingFaceLoader(repo_id=MODEL_ID) - assert loader.validate_download(checkpoint_path=checkpoint_dir) is True - - -def test_download_checkpoints(dummy_runner_models_dir): - model_folder_path = os.path.join(os.path.dirname(__file__), "dummy_runner_models") - model_upload = ModelUploader(model_folder_path, validate_api_ids=False) - isdownloaded = model_upload.download_checkpoints() - assert isdownloaded is True diff --git a/tests/runners/test_runners.py b/tests/runners/test_runners.py index ceb2a90e..a0167c26 100644 --- a/tests/runners/test_runners.py +++ b/tests/runners/test_runners.py @@ -1,52 +1,21 @@ # This test will create dummy runner and start the runner at first # Testing outputs received by client and programmed outputs of runner server # -import importlib.util -import inspect import os -import sys import threading import uuid import pytest from clarifai_grpc.grpc.api import resources_pb2, service_pb2 from clarifai_grpc.grpc.api.status import status_code_pb2 -from clarifai_protocol import BaseRunner from clarifai_protocol.utils.logging import logger +from dummy_runner_models.model import MyRunner +from dummy_runner_models.model_wrapper import MyRunner as MyWrapperRunner from google.protobuf import json_format from clarifai.client import BaseClient, Model, User from clarifai.client.auth.helper import ClarifaiAuthHelper - -def runner_class(runner_path): - - # arbitrary name given to the module to be imported - module = "runner_module" - - spec = importlib.util.spec_from_file_location(module, runner_path) - runner_module = importlib.util.module_from_spec(spec) - sys.modules[module] = runner_module - spec.loader.exec_module(runner_module) - - # Find all classes in the model.py file that are subclasses of BaseRunner - classes = [ - cls for _, cls in inspect.getmembers(runner_module, inspect.isclass) - if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__ - ] - - # Ensure there is exactly one subclass of BaseRunner in the model.py file - if len(classes) != 1: - raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(len(classes))) - - return classes[0] - - -MyRunner = runner_class( - runner_path=os.path.join(os.path.dirname(__file__), "dummy_runner_models", "1", "model.py")) -MyWrapperRunner = runner_class(runner_path=os.path.join( - os.path.dirname(__file__), "dummy_runner_models", "1", "model_wrapper.py")) - # logger.disabled = True TEXT_FILE_PATH = os.path.dirname(os.path.dirname(__file__)) + "/assets/sample.txt" @@ -71,8 +40,7 @@ def init_components( # except Exception as _: # model = Model.from_auth_helper(auth=auth, model_id=model_id) - new_model = model.create_version( - pretrained_model_config=resources_pb2.PretrainedModelConfig(local_dev=True,)) + new_model = model.create_version(resources_pb2.PretrainedModelConfig(local_dev=True,)) new_model_version = new_model.model_version.id compute_cluster = resources_pb2.ComputeCluster( @@ -143,6 +111,7 @@ def init_components( return new_model_version, res.runners[0].id +@pytest.mark.skip(reason="Skipping Runners tests for now.") @pytest.mark.requires_secrets class TestRunnerServer: @@ -470,6 +439,7 @@ def test_client_stream_by_filepath(self): self._validate_response(res, text + out.format(i=i)) +@pytest.mark.skip(reason="Skipping Runners tests for now.") @pytest.mark.requires_secrets class TestWrapperRunnerServer(TestRunnerServer):