From 0532243068effd61801d8771e14a90a51672a947 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 17 Oct 2023 16:56:59 +0200 Subject: [PATCH] check file name loading --- tests/neural_compressor/test_modeling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py index 854ae61ac0..9d2f07f0bd 100644 --- a/tests/neural_compressor/test_modeling.py +++ b/tests/neural_compressor/test_modeling.py @@ -94,10 +94,10 @@ def test_compare_to_transformers(self, model_id, task): config = config_class(inc_model.config) model_inputs = config.generate_dummy_inputs(framework="pt") outputs = inc_model(**model_inputs) - + file_name = "model.pt" with tempfile.TemporaryDirectory() as tmpdirname: - inc_model.save_pretrained(tmpdirname) - loaded_model = model_class.from_pretrained(tmpdirname) + inc_model.save_pretrained(tmpdirname, file_name) + loaded_model = model_class.from_pretrained(tmpdirname, file_name=file_name) outputs_loaded = loaded_model(**model_inputs) if task == "feature-extraction":