Skip to content

Commit

Permalink
allow to run on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 14, 2023
1 parent daefcf3 commit 2148358
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
2 changes: 0 additions & 2 deletions optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ def __init__(
self.ov_config = ov_config if ov_config is not None else {}
self.preprocessors = kwargs.get("preprocessors", [])

if "GPU" in self._device:
raise ValueError("Support of dynamic shapes for GPU devices is not yet available.")
if self.is_dynamic:
encoder = self._reshape(encoder, -1, -1, is_decoder=False)
decoder = self._reshape(decoder, -1, -1)
Expand Down
7 changes: 3 additions & 4 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_compare_to_transformers(self, model_arch):
set_seed(SEED)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
self.assertIsInstance(ov_model.config, PretrainedConfig)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
Expand All @@ -510,8 +510,7 @@ def test_compare_to_transformers(self, model_arch):
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)
# Compare tensor outputs
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4),
f"Max diff {torch.abs(ov_outputs.logits - transformers_outputs.logits).max()}")
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
del transformers_model
del ov_model
gc.collect()
Expand Down Expand Up @@ -1211,7 +1210,7 @@ def test_compare_with_and_without_past_key_values(self):


class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("whisper", "speech_to_text")
SUPPORTED_ARCHITECTURES = ("whisper",)

def _generate_random_audio_data(self):
np.random.seed(10)
Expand Down

0 comments on commit 2148358

Please sign in to comment.