From 2148358af2940886a4da0a02b79ff7581d84e453 Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 7 Nov 2023 09:02:29 +0400 Subject: [PATCH] allow to run on GPU --- optimum/intel/openvino/modeling_base_seq2seq.py | 2 -- tests/openvino/test_modeling.py | 7 +++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 527adc4347..e3dd1d7aa0 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -68,8 +68,6 @@ def __init__( self.ov_config = ov_config if ov_config is not None else {} self.preprocessors = kwargs.get("preprocessors", []) - if "GPU" in self._device: - raise ValueError("Support of dynamic shapes for GPU devices is not yet available.") if self.is_dynamic: encoder = self._reshape(encoder, -1, -1, is_decoder=False) decoder = self._reshape(decoder, -1, -1) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 3a7cf02a00..dc33b39f2a 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -494,7 +494,7 @@ def test_compare_to_transformers(self, model_arch): set_seed(SEED) ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True) self.assertIsInstance(ov_model.config, PretrainedConfig) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None @@ -510,8 +510,7 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Compare tensor outputs - self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4), - f"Max diff {torch.abs(ov_outputs.logits - transformers_outputs.logits).max()}") + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) del transformers_model del ov_model gc.collect() @@ -1211,7 +1210,7 @@ def test_compare_with_and_without_past_key_values(self): class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES = ("whisper", "speech_to_text") + SUPPORTED_ARCHITECTURES = ("whisper",) def _generate_random_audio_data(self): np.random.seed(10)