Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gcs hf #891

Merged
merged 3 commits into from
Feb 25, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/levanter/compat/hf_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import safetensors
import safetensors.numpy
import transformers.utils.hub
from fsspec import AbstractFileSystem
from huggingface_hub import HfApi, hf_hub_download, repo_exists, snapshot_download
from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError, RepositoryNotFoundError
from jax.experimental.multihost_utils import sync_global_devices
Expand Down Expand Up @@ -435,11 +436,20 @@ def _get_ref(self, ref) -> Tuple[str, Optional[str]]:
return ref.model_name_or_path, ref.revision

def load_state_dict(self, ref: Optional[Union[str, RepoRef]] = None, dtype: Optional[jnp.dtype] = None) -> dict:
"""Load a state dict from either HF Hub or a GCS path"""
if ref is None:
ref = self.reference_checkpoint
if ref is None:
raise ValueError("Must provide a checkpoint to load from")

# Handle GCS paths directly
if isinstance(ref, RepoRef) and ref.model_name_or_path.startswith("gs://"):
logger.info("\n\n loading hf from GCS! \n\n")
return self._load_from_gcs(ref.model_name_or_path, dtype)
elif isinstance(ref, str) and ref.startswith("gs://"):
logger.info("\n\n loading hf from GCS! \n\n")
return self._load_from_gcs(ref, dtype)

id, rev = self._get_ref(ref)

for index_file in [SAFE_TENSORS_INDEX_NAME, PYTORCH_WEIGHTS_INDEX_NAME]:
Expand Down Expand Up @@ -498,6 +508,52 @@ def _load_shards(self, id: str, index_file: str, rev: Optional[str], dtype) -> d

return final_state_dict

def _load_from_gcs(self, gcs_path: str, dtype: Optional[jnp.dtype] = None) -> dict:
"""Load a state dict from a GCS path"""
fs: AbstractFileSystem
fs, path = fsspec.core.url_to_fs(gcs_path)

# First try to load sharded checkpoint
for index_file in [SAFE_TENSORS_INDEX_NAME, PYTORCH_WEIGHTS_INDEX_NAME]:
index_path = os.path.join(path, index_file)
if fs.exists(index_path):
with fs.open(index_path, "r") as f:
index = json.load(f)

shard_files = list(set(index["weight_map"].values()))
final_state_dict = {}

if "safetensors" in index_file:
loader = _load_safe_tensors
else:
loader = _load_torch

for shard_file in shard_files:
shard_path = os.path.join(path, shard_file)
if not fs.exists(shard_path):
raise FileNotFoundError(f"Shard file {shard_path} not found")

# Download shard to temporary file
with tempfile.NamedTemporaryFile() as tmp:
fs.get(shard_path, tmp.name)
shard_state_dict = loader(tmp.name, dtype)
final_state_dict.update(shard_state_dict)

return final_state_dict

# If no index file found, try loading single file checkpoint
for model_file in [SAFE_TENSORS_MODEL, PYTORCH_MODEL]:
model_path = os.path.join(path, model_file)
if fs.exists(model_path):
with tempfile.NamedTemporaryFile() as tmp:
fs.get(model_path, tmp.name)
if model_file == SAFE_TENSORS_MODEL:
return _load_safe_tensors(tmp.name, dtype)
else:
return _load_torch(tmp.name, dtype)

raise FileNotFoundError(f"No checkpoint files found in {gcs_path}")

def load_pretrained(
self,
lm_model_cls: Type[ModelWithHfSerializationMixin],
Expand Down
Loading