Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 17, 2023
1 parent 0532243 commit 8259e83
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
5 changes: 3 additions & 2 deletions optimum/intel/neural_compressor/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _from_pretrained(
file_name: str = WEIGHTS_NAME,
local_files_only: bool = False,
subfolder: str = "",
trust_remote_code: bool = False,
**kwargs,
):
model_name_or_path = kwargs.pop("model_name_or_path", None)
Expand Down Expand Up @@ -178,8 +179,8 @@ def _from_pretrained(
model, config=config, model_save_dir=model_save_dir, q_config=q_config, inc_config=inc_config, **kwargs
)

def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = WEIGHTS_NAME):
output_path = os.path.join(save_directory, file_name)
def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)

if isinstance(self.model, torch.nn.Module):
state_dict = self.model.state_dict()
Expand Down
8 changes: 3 additions & 5 deletions tests/neural_compressor/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
INCStableDiffusionPipeline,
INCTrainer,
)
from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS

from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME

os.environ["CUDA_VISIBLE_DEVICES"] = ""
set_seed(1009)
Expand Down Expand Up @@ -94,10 +93,9 @@ def test_compare_to_transformers(self, model_id, task):
config = config_class(inc_model.config)
model_inputs = config.generate_dummy_inputs(framework="pt")
outputs = inc_model(**model_inputs)
file_name = "model.pt"
with tempfile.TemporaryDirectory() as tmpdirname:
inc_model.save_pretrained(tmpdirname, file_name)
loaded_model = model_class.from_pretrained(tmpdirname, file_name=file_name)
inc_model.save_pretrained(tmpdirname)
loaded_model = model_class.from_pretrained(tmpdirname, file_name=WEIGHTS_NAME)
outputs_loaded = loaded_model(**model_inputs)

if task == "feature-extraction":
Expand Down

0 comments on commit 8259e83

Please sign in to comment.