fix when to save pretrained config
IlyasMoutawwakil committed Feb 22, 2024
1 parent ed43a9e · commit 984a689
Showing 6 changed files with 34 additions and 10 deletions.
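
Context for the change: each backend's `create_no_weights_model` previously saved `self.pretrained_config` unconditionally, which only makes sense for transformers models; other supported libraries (e.g. timm or diffusers) do not go through a transformers `PretrainedConfig`. A minimal standalone sketch of the guarded pattern, with illustrative names (`library`, `no_weights_dir`) standing in for the backend's attributes:

```python
import os

from transformers import AutoConfig

library = "transformers"  # illustrative; the backend reads this from its config
no_weights_dir = "/tmp/no_weights_model"  # illustrative path
os.makedirs(no_weights_dir, exist_ok=True)

# Save a config only when the model actually is a transformers model;
# other libraries have no transformers-style config to serialize.
if library == "transformers":
    pretrained_config = AutoConfig.from_pretrained("gpt2")
    pretrained_config.save_pretrained(save_directory=no_weights_dir)
```
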
optimum_benchmark/backends/neural_compressor/backend.py (6 changes: 4 additions & 2 deletions)

@@ -73,8 +73,10 @@ def create_no_weights_model(self) -> None:
         state_dict = torch.nn.Linear(1, 1).state_dict()
         LOGGER.info("\t+ Saving no weights model pytorch_model.bin")
         torch.save(state_dict, os.path.join(self.no_weights_model, "pytorch_model.bin"))
-        LOGGER.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_automodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
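
The surrounding method pair (`create_no_weights_model` / `load_automodel_with_no_weights`) implements the benchmark's no-weights trick: write a checkpoint containing only a dummy tensor next to the model's config, so `from_pretrained` can instantiate the full architecture without downloading real weights. A self-contained sketch of the idea, assuming a transformers model (the `bert-base-uncased` ID and the path are illustrative):

```python
import os

import torch
from transformers import AutoConfig, AutoModel

no_weights_dir = "/tmp/no_weights_model"  # illustrative path
os.makedirs(no_weights_dir, exist_ok=True)

# A checkpoint whose only content is a dummy linear layer: every real
# weight is "missing" from the loader's point of view.
state_dict = torch.nn.Linear(1, 1).state_dict()
torch.save(state_dict, os.path.join(no_weights_dir, "pytorch_model.bin"))

# The config is what defines the architecture to instantiate.
AutoConfig.from_pretrained("bert-base-uncased").save_pretrained(no_weights_dir)

# from_pretrained builds the model from the config, warns about the
# missing weights, and randomly initializes them.
model = AutoModel.from_pretrained(no_weights_dir)
```
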
optimum_benchmark/backends/onnxruntime/backend.py (6 changes: 4 additions & 2 deletions)

@@ -99,8 +99,10 @@ def create_no_weights_model(self) -> None:
         LOGGER.info("\t+ Saving no weights model safetensors")
         safetensors = os.path.join(self.no_weights_model, "model.safetensors")
         save_file(tensors=state_dict, filename=safetensors, metadata={"format": "pt"})
-        LOGGER.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_ortmodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
optimum_benchmark/backends/openvino/backend.py (6 changes: 4 additions & 2 deletions)

@@ -83,8 +83,10 @@ def create_no_weights_model(self) -> None:
         LOGGER.info("\t+ Saving no weights model safetensors")
         safetensors = os.path.join(self.no_weights_model, "model.safetensors")
         save_file(tensors=state_dict, filename=safetensors, metadata={"format": "pt"})
-        LOGGER.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_automodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
optimum_benchmark/backends/pytorch/backend.py (5 changes: 3 additions & 2 deletions)

@@ -202,8 +202,9 @@ def create_no_weights_model(self) -> None:
             self.pretrained_config.quantization_config = self.quantization_config.to_dict()
             # tricking from_pretrained to load the model as if it was quantized
 
-        LOGGER.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_model_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
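
The context lines above show that the PyTorch backend also injects `self.quantization_config.to_dict()` into the config before saving, so that `from_pretrained` later follows its quantized-loading path even though the checkpoint holds no real weights. A hedged sketch of that trick (the choice of `BitsAndBytesConfig` and the model ID are illustrative):

```python
from transformers import AutoConfig, BitsAndBytesConfig

config = AutoConfig.from_pretrained("gpt2")  # illustrative model

# Serializing a quantization config into the model config "tricks"
# from_pretrained into loading the checkpoint as if it were quantized,
# mirroring the comment in the diff above.
config.quantization_config = BitsAndBytesConfig(load_in_8bit=True).to_dict()
config.save_pretrained(save_directory="/tmp/no_weights_model")
```
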
optimum_benchmark/backends/tensorrt_llm/backend.py (15 changes: 15 additions & 0 deletions)

@@ -1,7 +1,10 @@
 import os
 from logging import getLogger
 from typing import Any, Dict
 
+import torch
 from hydra.utils import get_class
+from safetensors.torch import save_file
+from transformers.utils import ModelOutput
 
 from ..base import Backend
@@ -28,6 +31,18 @@ def validate_model_type(self) -> None:
         self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.model_type])
         LOGGER.info(f"\t+ Using TRTLLMModel class {self.trtmodel_class.__name__}")
 
+    def create_no_weights_model(self) -> None:
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
+        LOGGER.info("\t+ Creating no weights model state dict")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+        LOGGER.info("\t+ Saving no weights model safetensors")
+        safetensors = os.path.join(self.no_weights_model, "model.safetensors")
+        save_file(tensors=state_dict, filename=safetensors, metadata={"format": "pt"})
+
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
     def load_trtmodel_from_pretrained(self) -> None:
         self.pretrained_model = self.trtmodel_class.from_pretrained(
             self.config.model,
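
Unlike the neural-compressor backend, the new TRT-LLM method writes the dummy checkpoint with `safetensors.torch.save_file`; the `metadata={"format": "pt"}` field marks the tensors as PyTorch-native, which transformers checks when deserializing safetensors checkpoints. A small round-trip sketch:

```python
import torch
from safetensors.torch import load_file, save_file

state_dict = torch.nn.Linear(1, 1).state_dict()

# The "format": "pt" metadata tags the file as PyTorch tensors; loaders
# such as transformers inspect it before accepting the checkpoint.
save_file(tensors=state_dict, filename="model.safetensors", metadata={"format": "pt"})

reloaded = load_file("model.safetensors")
assert set(reloaded) == {"weight", "bias"}
```
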
optimum_benchmark/backends/torch_ort/backend.py (6 changes: 4 additions & 2 deletions)

@@ -54,8 +54,10 @@ def create_no_weights_model(self) -> None:
         LOGGER.info("\t+ Saving no weights model safetensors")
         safetensors = os.path.join(self.no_weights_model, "model.safetensors")
         save_file(tensors=state_dict, filename=safetensors, metadata={"format": "pt"})
-        LOGGER.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+
+        if self.config.library == "transformers":
+            LOGGER.info("\t+ Saving no weights model pretrained config")
+            self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
 
     def load_automodel_with_no_weights(self) -> None:
         LOGGER.info("\t+ Creating no weights model")
