Skip to content

Commit

Permalink
add test for use_cache=False
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 27, 2023
1 parent f55417c commit 532130a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,11 @@ def test_load_from_hub_and_save_model(self):
del model
gc.collect()

def test_load_from_hub_and_save_decoder_model(self):
@parameterized.expand((True, False))
def test_load_from_hub_and_save_decoder_model(self, use_cache):
tokenizer = AutoTokenizer.from_pretrained(self.OV_DECODER_MODEL_ID)
tokens = tokenizer("This is a sample input", return_tensors="pt")
loaded_model = OVModelForCausalLM.from_pretrained(self.OV_DECODER_MODEL_ID, use_cache=True)
loaded_model = OVModelForCausalLM.from_pretrained(self.OV_DECODER_MODEL_ID, use_cache=use_cache)
self.assertIsInstance(loaded_model.config, PretrainedConfig)
loaded_model_outputs = loaded_model(**tokens)

Expand All @@ -133,7 +134,8 @@ def test_load_from_hub_and_save_decoder_model(self):
folder_contents = os.listdir(tmpdirname)
self.assertTrue(OV_XML_FILE_NAME in folder_contents)
self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents)
model = OVModelForCausalLM.from_pretrained(tmpdirname, use_cache=True)
model = OVModelForCausalLM.from_pretrained(tmpdirname, use_cache=use_cache)
self.assertEqual(model.use_cache, use_cache)

outputs = model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
Expand Down

0 comments on commit 532130a

Please sign in to comment.