diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index ef11870cea..f3978b2965 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -484,7 +484,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "pegasus", ) GENERATION_LENGTH = 100 - SPEEDUP_CACHE = 1.1 @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -557,29 +556,17 @@ def test_compare_with_and_without_past_key_values(self): tokens = tokenizer("This is a sample input", return_tensors="pt") model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True) - # Warmup - _ = model_with_pkv.generate(**tokens) - with Timer() as with_pkv_timer: - outputs_model_with_pkv = model_with_pkv.generate( - **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 - ) - + outputs_model_with_pkv = model_with_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 + ) model_without_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False) - - # Warmup - _ = model_without_pkv.generate(**tokens) - with Timer() as without_pkv_timer: - outputs_model_without_pkv = model_without_pkv.generate( - **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 - ) + outputs_model_without_pkv = model_without_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 + ) self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH) - self.assertTrue( - without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE, - f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," - f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", - ) + del model_with_pkv del model_without_pkv gc.collect()