Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 28, 2024
1 parent f14fb91 commit 7cc52a7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
5 changes: 5 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,11 @@ def half(self):
compress_model_transformation(model)
return self

def to(self, device):
self.language_model.to(device)
super().to(device)
return self

def forward(
self,
input_ids,
Expand Down
32 changes: 31 additions & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,14 +1903,35 @@ def test_compare_to_transformers(self, model_arch):
set_seed(SEED)
with torch.no_grad():
transformers_outputs = transformers_model(**inputs)
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, compile=False)
self.assertIsInstance(ov_model, MODEL_TYPE_TO_CLS_MAPPING[ov_model.config.model_type])
self.assertIsInstance(ov_model.vision_embeddings, OVVisionEmbedding)
self.assertIsInstance(ov_model.language_model, OVModelWithEmbedForCausalLM)
for additional_part in ov_model.additional_parts:
self.assertTrue(hasattr(ov_model, additional_part))
self.assertIsInstance(getattr(ov_model, additional_part), MODEL_PARTS_CLS_MAPPING[additional_part])
self.assertIsInstance(ov_model.config, PretrainedConfig)
ov_model.to("AUTO")
self.assertTrue("AUTO" in ov_model._device)
self.assertTrue("AUTO" in ov_model.vision_embeddings._device)
self.assertTrue(ov_model.vision_embeddings.request is None)
self.assertTrue("AUTO" in ov_model.language_model._device)
self.assertTrue(ov_model.language_model.request is None)
self.assertTrue(ov_model.language_model.text_emb_request is None)
for additional_part in ov_model.additional_parts:
self.assertTrue("AUTO" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is None)
ov_model.to("CPU")
ov_model.compile()
self.assertTrue("CPU" in ov_model._device)
self.assertTrue("CPU" in ov_model.vision_embeddings._device)
self.assertTrue(ov_model.vision_embeddings.request is not None)
self.assertTrue("CPU" in ov_model.language_model._device)
self.assertTrue(ov_model.language_model.request is not None)
self.assertTrue(ov_model.language_model.text_emb_request is not None)
for additional_part in ov_model.additional_parts:
self.assertTrue("CPU" in getattr(ov_model, additional_part)._device)
self.assertTrue(getattr(ov_model, additional_part).request is not None)
ov_outputs = ov_model(**inputs)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))

Expand Down Expand Up @@ -1961,6 +1982,15 @@ def test_generate_utils(self, model_arch):

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_model_can_be_loaded_after_saving(self, model_arch):
model_id = MODEL_NAMES[model_arch]
with TemporaryDirectory() as save_dir:
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False)
ov_model.save_pretrained(save_dir)
ov_restored_model = OVModelForVisualCausalLM.from_pretrained(save_dir, compile=False)
self.assertIsInstance(ov_restored_model, type(ov_model))


class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("whisper",)
Expand Down

0 comments on commit 7cc52a7

Please sign in to comment.