Revert "[EAGLE-5416] Added Tests for Download checkpoints and Fix dow…
Browse files Browse the repository at this point in the history
…nload me…" (#487)

This reverts commit 9296979.
zeiler authored Jan 17, 2025
1 parent 9296979 commit f0d909a
Showing 9 changed files with 15 additions and 162 deletions.
3 changes: 1 addition & 2 deletions clarifai/runners/models/model_upload.py
@@ -78,8 +78,7 @@ def _validate_config_checkpoints(self):
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
repo_id = self.config.get("checkpoints").get("repo_id")

# get from config.yaml otherwise fall back to HF_TOKEN env var.
hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
hf_token = self.config.get("checkpoints").get("hf_token", None)
return repo_id, hf_token

def _check_app_exists(self):
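For context, a minimal sketch of the token-resolution behavior this hunk toggles: the revert drops the HF_TOKEN environment-variable fallback and reads the token from config.yaml only. The helper name is illustrative, not part of the repository:

import os


def resolve_hf_token(checkpoints_config: dict):
  """Illustrative only: how the reverted code resolved the Hugging Face token."""
  # Reverted behavior: prefer the value from config.yaml, otherwise fall back
  # to the HF_TOKEN environment variable.
  return checkpoints_config.get("hf_token", os.environ.get("HF_TOKEN", None))
  # Restored behavior reads config.yaml only:
  # return checkpoints_config.get("hf_token", None)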
51 changes: 9 additions & 42 deletions clarifai/runners/utils/loader.py
@@ -1,8 +1,6 @@
import fnmatch
import importlib.util
import json
import os
import shutil
import subprocess

from clarifai.utils.logging import logger
@@ -41,7 +39,7 @@ def validate_hftoken(cls, hf_token: str):
def download_checkpoints(self, checkpoint_path: str):
# throw error if huggingface_hub wasn't installed
try:
from huggingface_hub import snapshot_download
from huggingface_hub import list_repo_files, snapshot_download
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)
if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
@@ -55,17 +53,16 @@ def download_checkpoints(self, checkpoint_path: str):
logger.error("Model %s not found on Hugging Face" % (self.repo_id))
return False

self.ignore_patterns = self._get_ignore_patterns()
ignore_patterns = None # Download everything.
repo_files = list_repo_files(repo_id=self.repo_id, token=self.token)
if any(f.endswith(".safetensors") for f in repo_files):
logger.info(f"SafeTensors found in {self.repo_id}, downloading only .safetensors files.")
ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin"]
snapshot_download(
repo_id=self.repo_id,
local_dir=checkpoint_path,
local_dir_use_symlinks=False,
ignore_patterns=self.ignore_patterns)
# Remove the `.cache` folder if it exists
cache_path = os.path.join(checkpoint_path, ".cache")
if os.path.exists(cache_path) and os.path.isdir(cache_path):
shutil.rmtree(cache_path)

ignore_patterns=ignore_patterns)
except Exception as e:
logger.error(f"Error downloading model checkpoints {e}")
return False
@@ -112,41 +109,11 @@ def validate_download(self, checkpoint_path: str):
from huggingface_hub import list_repo_files
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)
# Get the list of files on the repo
repo_files = list_repo_files(self.repo_id, token=self.token)

self.ignore_patterns = self._get_ignore_patterns()
# Get the list of files on the repo that are not ignored
if getattr(self, "ignore_patterns", None):
patterns = self.ignore_patterns

def should_ignore(file_path):
return any(fnmatch.fnmatch(file_path, pattern) for pattern in patterns)

repo_files = [f for f in repo_files if not should_ignore(f)]

# Check if downloaded files match the files we expect (ignoring ignored patterns)
checkpoint_dir_files = [
f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn
]

# Validate by comparing file lists
return len(checkpoint_dir_files) >= len(repo_files) and not (
len(set(repo_files) - set(checkpoint_dir_files)) > 0) and len(repo_files) > 0

def _get_ignore_patterns(self):
# check if model exists on HF
try:
from huggingface_hub import list_repo_files
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)

# Get the list of files on the repo that are not ignored
repo_files = list_repo_files(self.repo_id, token=self.token)
self.ignore_patterns = None
if any(f.endswith(".safetensors") for f in repo_files):
self.ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin", "*.pth", "*.bin"]
return self.ignore_patterns
return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
list_repo_files(self.repo_id)) > 0

@staticmethod
def validate_config(checkpoint_path: str):
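For context, a minimal standalone sketch of the download path this revert restores in download_checkpoints: list the repo files, skip the redundant PyTorch weights when safetensors are present, and snapshot the rest. The function name and the explicit token argument are assumptions for illustration only:

from huggingface_hub import list_repo_files, snapshot_download


def download_checkpoints_sketch(repo_id: str, checkpoint_path: str, token: str = None):
  ignore_patterns = None  # Download everything by default.
  repo_files = list_repo_files(repo_id=repo_id, token=token)
  if any(f.endswith(".safetensors") for f in repo_files):
    # SafeTensors present: skip the duplicate .pth/.bin weights.
    ignore_patterns = ["**/original/*", "**/*.pth", "**/*.bin"]
  snapshot_download(
      repo_id=repo_id,
      local_dir=checkpoint_path,
      token=token,
      ignore_patterns=ignore_patterns)
  return checkpoint_path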
1 change: 0 additions & 1 deletion tests/requirements.txt
@@ -2,7 +2,6 @@ pytest==7.1.2
pytest-cov==5.0.0
pytest-xdist==2.5.0
llama-index-core==0.11.17
huggingface_hub[hf_transfer]==0.27.1
pypdf==3.17.4
seaborn==0.13.2
pycocotools==2.0.6
20 changes: 0 additions & 20 deletions tests/runners/dummy_runner_models/config.yaml

This file was deleted.

File renamed without changes.
Empty file.
62 changes: 0 additions & 62 deletions tests/runners/test_download_checkpoints.py

This file was deleted.

40 changes: 5 additions & 35 deletions tests/runners/test_runners.py
@@ -1,52 +1,21 @@
# This test will create dummy runner and start the runner at first
# Testing outputs received by client and programmed outputs of runner server
#
import importlib.util
import inspect
import os
import sys
import threading
import uuid

import pytest
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from clarifai_protocol import BaseRunner
from clarifai_protocol.utils.logging import logger
from dummy_runner_models.model import MyRunner
from dummy_runner_models.model_wrapper import MyRunner as MyWrapperRunner
from google.protobuf import json_format

from clarifai.client import BaseClient, Model, User
from clarifai.client.auth.helper import ClarifaiAuthHelper


def runner_class(runner_path):

# arbitrary name given to the module to be imported
module = "runner_module"

spec = importlib.util.spec_from_file_location(module, runner_path)
runner_module = importlib.util.module_from_spec(spec)
sys.modules[module] = runner_module
spec.loader.exec_module(runner_module)

# Find all classes in the model.py file that are subclasses of BaseRunner
classes = [
cls for _, cls in inspect.getmembers(runner_module, inspect.isclass)
if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__
]

# Ensure there is exactly one subclass of BaseRunner in the model.py file
if len(classes) != 1:
raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(len(classes)))

return classes[0]


MyRunner = runner_class(
runner_path=os.path.join(os.path.dirname(__file__), "dummy_runner_models", "1", "model.py"))
MyWrapperRunner = runner_class(runner_path=os.path.join(
os.path.dirname(__file__), "dummy_runner_models", "1", "model_wrapper.py"))

# logger.disabled = True

TEXT_FILE_PATH = os.path.dirname(os.path.dirname(__file__)) + "/assets/sample.txt"
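The runner_class helper deleted above loads a runner module from an arbitrary file path. A minimal sketch of that importlib pattern, kept here for reference; the module name and example path are illustrative:

import importlib.util
import sys


def load_module_from_path(module_name: str, file_path: str):
  # Build a module spec from the file path, register it, then execute it.
  spec = importlib.util.spec_from_file_location(module_name, file_path)
  module = importlib.util.module_from_spec(spec)
  sys.modules[module_name] = module
  spec.loader.exec_module(module)
  return module


# Example (hypothetical path): load the dummy runner module and grab a class from it.
# runner_module = load_module_from_path("runner_module", "dummy_runner_models/1/model.py")
# MyRunner = runner_module.MyRunner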
@@ -71,8 +40,7 @@ def init_components(
# except Exception as _:
# model = Model.from_auth_helper(auth=auth, model_id=model_id)

new_model = model.create_version(
pretrained_model_config=resources_pb2.PretrainedModelConfig(local_dev=True,))
new_model = model.create_version(resources_pb2.PretrainedModelConfig(local_dev=True,))

new_model_version = new_model.model_version.id
compute_cluster = resources_pb2.ComputeCluster(
@@ -143,6 +111,7 @@ def init_components(
return new_model_version, res.runners[0].id


@pytest.mark.skip(reason="Skipping Runners tests for now.")
@pytest.mark.requires_secrets
class TestRunnerServer:

@@ -470,6 +439,7 @@ def test_client_stream_by_filepath(self):
self._validate_response(res, text + out.format(i=i))


@pytest.mark.skip(reason="Skipping Runners tests for now.")
@pytest.mark.requires_secrets
class TestWrapperRunnerServer(TestRunnerServer):

