Skip to content

Commit

Permalink
[EAGLE-5408]: Add storage request inferred from tar and checkpoint si…
Browse files Browse the repository at this point in the history
…ze (#479)

* Add storage request inferred from tar and checkpoint size to model upload
  • Loading branch information
patricklundquist authored Feb 5, 2025
1 parent 324491e commit fc99a27
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
40 changes: 37 additions & 3 deletions clarifai/runners/models/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _clear_line(n: int = 1) -> None:


class ModelBuilder:
DEFAULT_CHECKPOINT_SIZE = 50 * 1024**3 # 50 GiB

def __init__(self, folder: str, validate_api_ids: bool = True, download_validation_only=False):
"""
Expand Down Expand Up @@ -154,6 +155,9 @@ def _check_app_exists(self):
resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
if resp.status.code == status_code_pb2.SUCCESS:
return True
logger.error(
f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
)
return False

def _validate_config_model(self):
Expand Down Expand Up @@ -200,6 +204,24 @@ def _validate_config(self):
)
logger.info("Continuing without Hugging Face token")

@staticmethod
def _get_tar_file_content_size(tar_file_path):
"""
Calculates the total size of the contents of a tar file.
Args:
tar_file_path (str): The path to the tar file.
Returns:
int: The total size of the contents in bytes.
"""
total_size = 0
with tarfile.open(tar_file_path, 'r') as tar:
for member in tar:
if member.isfile():
total_size += member.size
return total_size

@property
def client(self):
if self._client is None:
Expand All @@ -211,9 +233,8 @@ def client(self):
user_id = model.get('user_id')
app_id = model.get('app_id')

base = os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')

self._client = BaseClient(user_id=user_id, app_id=app_id, base=base)
self._base_api = os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')
self._client = BaseClient(user_id=user_id, app_id=app_id, base=self._base_api)

return self._client

Expand Down Expand Up @@ -520,6 +541,18 @@ def filter_func(tarinfo):
file_size = os.path.getsize(self.tar_file)
logger.info(f"Size of the tar is: {file_size} bytes")

self.storage_request_size = self._get_tar_file_content_size(file_path)
if not download_checkpoints and self.config.get("checkpoints"):
# Get the checkpoint size to add to the storage request.
# First check for the env variable, then try querying huggingface. If all else fails, use the default.
checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
if not checkpoint_size:
_, repo_id, _ = self._validate_config_checkpoints()
checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
if not checkpoint_size:
checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
self.storage_request_size += checkpoint_size

self.maybe_create_model()
if not self.check_model_exists():
logger.error(f"Failed to create model: {self.model_proto.id}")
Expand Down Expand Up @@ -594,6 +627,7 @@ def init_upload_model_version(self, model_version_proto, file_path):
model_id=self.model_proto.id,
model_version=model_version_proto,
total_size=file_size,
storage_request_size=self.storage_request_size,
is_v3=self.is_v3,
))
return result
Expand Down
33 changes: 33 additions & 0 deletions clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
import shutil

import requests

from clarifai.utils.logging import logger


Expand Down Expand Up @@ -169,3 +171,34 @@ def fetch_labels(checkpoint_path: str):

labels = config['id2label']
return labels

@staticmethod
def get_huggingface_checkpoint_total_size(repo_name):
"""
Fetches the JSON data for a Hugging Face model using the API with `?blobs=true`.
Calculates the total size from the JSON output.
Args:
repo_name (str): The name of the model on Hugging Face Hub. e.g. "casperhansen/llama-3-8b-instruct-awq"
Returns:
int: The total size in bytes.
"""
try:
url = f"https://huggingface.co/api/models/{repo_name}?blobs=true"
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad status codes
json_data = response.json()

if isinstance(json_data, str):
data = json.loads(json_data)
else:
data = json_data

total_size = 0
for file in data['siblings']:
total_size += file['size']
return total_size
except Exception as e:
logger.error(f"Error fetching checkpoint size from huggingface.co: {e}")
return 0

0 comments on commit fc99a27

Please sign in to comment.