Minimal working example of safetensors support for hezar #156

Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion hezar/constants.py
@@ -5,7 +5,6 @@
import os
from enum import Enum


HEZAR_HUB_ID = "hezarai"
HEZAR_CACHE_DIR = os.getenv("HEZAR_CACHE_DIR", f'{os.path.expanduser("~")}/.cache/hezar')

@@ -59,6 +58,7 @@ class Backends(ExplicitEnum):
SCIKIT = "sklearn"
SEQEVAL = "seqeval"
ROUGE = "rouge_score"
SAFETENSORS = "safetensors"


class TaskType(ExplicitEnum):
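The new `Backends.SAFETENSORS` member lets hezar's dependency helpers guard the optional `safetensors` package the same way other optional backends are guarded. A minimal sketch of the pattern used below in `model.py` (names as imported there):

from hezar.constants import Backends
from hezar.utils import is_backend_available

# Only import the safetensors helpers when the optional backend is installed
if is_backend_available(Backends.SAFETENSORS):
    from safetensors.torch import load_model, save_model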
62 changes: 53 additions & 9 deletions hezar/models/model.py
@@ -7,6 +7,7 @@
>>> from hezar.models import Model
>>> model = Model.load("hezarai/bert-base-fa")
"""

from __future__ import annotations

import os
@@ -30,8 +31,13 @@
RegistryType,
)
from ..preprocessors import Preprocessor, PreprocessorsContainer
from ..utils import (
Logger,
get_module_class,
is_backend_available,
sanitize_function_parameters,
verify_dependencies,
)

logger = Logger(__name__)

@@ -46,7 +52,7 @@
LossType.BCE_WITH_LOGITS: nn.BCEWithLogitsLoss,
LossType.CROSS_ENTROPY: nn.CrossEntropyLoss,
LossType.TRIPLE_MARGIN: nn.TripletMarginLoss,
LossType.CTC: nn.CTCLoss,
}


@@ -83,7 +89,7 @@ def __init__(self, config: ModelConfig, *args, **kwargs):
def __repr__(self):
representation = super().__repr__()
pattern = r"\('?_criterion'?\): [^\)]+\)\s*"
representation = re.sub(pattern, "", representation)
return representation

@staticmethod
@@ -103,6 +109,7 @@ def load(
config_filename: Optional[str] = None,
save_path: Optional[str | os.PathLike] = None,
cache_dir: Optional[str | os.PathLike] = None,
load_safetensors: bool = False,
**kwargs,
) -> "Model":
"""
@@ -120,6 +127,7 @@ def load(
config_filename: Optional config filename
save_path: Save model to this path after loading
cache_dir: Path to cache directory, defaults to `~/.cache/hezar`
load_safetensors: Whether to load a model saved in the `safetensors` format. Defaults to `False` to preserve backward compatibility.

Returns:
The fully loaded Hezar model
@@ -146,6 +154,12 @@
model = model_cls(config, **kwargs)

model_filename = model_filename or model_cls.model_filename or cls.model_filename
if load_safetensors:
    # Conditionally choosing between `pickle` and `safetensors` when both formats
    # exist would, by my estimation, require a fair bit of alteration to the codebase.
    # Moreover, it's redundant if the end goal is to convert everything to
    # `safetensors` later down the line, so the filename is simply rewritten here.
    model_filename = model_filename.replace(".pt", ".safetensors")
# does the path exist locally?
is_local = load_locally or os.path.isdir(hub_or_local_path)
if not is_local:
@@ -158,8 +172,19 @@
else:
model_path = os.path.join(hub_or_local_path, model_filename)
# Get state dict from the model
if load_safetensors:
if not is_backend_available(Backends.SAFETENSORS):
raise ModuleNotFoundError(
f"`load_safetensors=True` requires `{Backends.SAFETENSORS}` to be installed."
f"Please install with `pip install {Backends.SAFETENSORS}` or set `load_safetensors=False`."
)
else:
from safetensors.torch import load_model

load_model(model, model_path)
else:
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
if device:
model.to(device)
if save_path:
@@ -179,13 +204,16 @@ def load_state_dict(self, state_dict: Mapping[str, Any], **kwargs):

Args:
state_dict: Model state dict

Returns:
NamedTuple: with ``missing_keys`` and ``unexpected_keys`` fields
"""
if len(self.skip_keys_on_load):
for key in self.skip_keys_on_load:
if key in state_dict:
state_dict.pop(key, None) # noqa
try:
tup = super().load_state_dict(state_dict, strict=True)
except RuntimeError:
compatible_state_dict = OrderedDict()
src_state_dict = self.state_dict()
@@ -200,21 +228,24 @@ def load_state_dict(self, state_dict: Mapping[str, Any], **kwargs):
compatible_state_dict[src_key] = src_weight
incompatible_keys.append(src_key)

tup = super().load_state_dict(compatible_state_dict, strict=False)
missing_keys, _ = tup
if len(missing_keys) or len(incompatible_keys):
logger.warning(
"Partially loading the weights as the model architecture and the given state dict are "
"incompatible! \nIgnore this warning in case you plan on fine-tuning this model\n"
f"Incompatible keys: {incompatible_keys}\n"
f"Missing keys: {missing_keys}\n"
)
return tup

def save(
self,
path: str | os.PathLike,
filename: Optional[str] = None,
save_preprocessor: Optional[bool] = True,
config_filename: Optional[str] = None,
safe_serialization: bool = True,
):
"""
Save model weights and config to a local path
@@ -224,6 +255,7 @@ def save(
save_preprocessor: Whether to save preprocessor(s) along with the model or not
config_filename: Model config filename
filename: Model weights filename
safe_serialization: Whether to save the model with `safetensors` or with PyTorch's default `pickle`-based serialization. Defaults to `True`.

Returns:
Path to the saved model
@@ -236,7 +268,19 @@
self.config.save(save_dir=path, filename=config_filename)

model_save_path = os.path.join(path, filename)
if not safe_serialization:
torch.save(self.state_dict(), model_save_path)
else:
if not is_backend_available(Backends.SAFETENSORS):
raise ModuleNotFoundError(
f"`safe_serialization=True` requires `{Backends.SAFETENSORS}` to be installed."
f"Please install with `pip install {Backends.SAFETENSORS}` or set `safe_serialization=False`."
)
else:
from safetensors.torch import save_model

model_save_path = model_save_path.replace(".pt", ".safetensors")
save_model(self, model_save_path)

if save_preprocessor:
if self.preprocessor is not None:
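Taken together, the new `safe_serialization` and `load_safetensors` flags give a simple round trip between the two formats. A sketch, reusing the Hub model from the test below:

from hezar.models import Model

# Load pickle-based weights as before
model = Model.load("hezarai/bert-fa-sentiment-dksf")

# Save as safetensors; `save` rewrites the `.pt` suffix to `.safetensors`
model.save("safetensors_model", safe_serialization=True)

# Reload, opting in to the safetensors file
model = Model.load("safetensors_model", load_safetensors=True)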
2 changes: 1 addition & 1 deletion hezar/utils/hub_utils.py
@@ -1,11 +1,11 @@
import os
from pathlib import Path

from huggingface_hub import HfApi, Repository

from ..constants import HEZAR_CACHE_DIR, HEZAR_HUB_ID, RepoType
from ..utils.logging import Logger


__all__ = [
"resolve_pretrained_path",
"get_local_cache_path",
15 changes: 15 additions & 0 deletions tests/test_safetensors.py
@@ -0,0 +1,15 @@
from hezar.models import Model


def test_safetensors():
example = ["هزار، کتابخانه‌ای کامل برای به کارگیری آسان هوش مصنوعی"]  # "Hezar, a complete library for easy use of artificial intelligence"

pickled_model = Model.load("hezarai/bert-fa-sentiment-dksf")
pickled_outputs = pickled_model.predict(example)

pickled_model.save("safetensors_model", safe_serialization=True)

safe_model = Model.load("safetensors_model", load_safetensors=True)
safe_outputs = safe_model.predict(example)

assert pickled_outputs == safe_outputs
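
One property worth noting for review: unlike `torch.load`, a `safetensors` file contains only named tensors, so it can be inspected without executing pickled code. A quick check (assuming the default `model.pt` weights name, which the PR rewrites to `model.safetensors`):

from safetensors import safe_open

with safe_open("safetensors_model/model.safetensors", framework="pt") as f:
    print(list(f.keys()))  # tensor names only, no arbitrary Python objects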