Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 3, 2023
1 parent 7c7c65a commit 1f9b163
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 1 deletion.
16 changes: 15 additions & 1 deletion optimum/intel/openvino/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,25 @@
from transformers.trainer_utils import (
EvalPrediction,
HPSearchBackend,
ShardedDDPOption,
TrainOutput,
has_length,
speed_metrics,
)


# Prefer the enum shipped with transformers; when that import is not
# available, define a local stand-in with the same member names and values so
# references to ``ShardedDDPOption`` in this module keep resolving.
try:
    from transformers.trainer_utils import ShardedDDPOption
except ImportError:
    from transformers.utils import ExplicitEnum

    class ShardedDDPOption(ExplicitEnum):
        """Local fallback enum used when transformers does not export ``ShardedDDPOption``."""

        SIMPLE = "simple"
        ZERO_DP_2 = "zero_dp_2"
        ZERO_DP_3 = "zero_dp_3"
        OFFLOAD = "offload"
        AUTO_WRAP = "auto_wrap"


from transformers.utils import (
WEIGHTS_NAME,
is_apex_available,
Expand Down
67 changes: 67 additions & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
AutoTokenizer,
GenerationConfig,
Expand All @@ -64,6 +65,7 @@
OVModelForQuestionAnswering,
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
OVStableDiffusionPipeline,
)
Expand Down Expand Up @@ -1199,3 +1201,68 @@ def test_compare_with_and_without_past_key_values(self):
del model_with_pkv
del model_without_pkv
gc.collect()


class OVModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase):
    """Integration tests for ``OVModelForSpeechSeq2Seq``.

    Covers two things for every architecture in ``SUPPORTED_ARCHITECTURES``:
    logits produced by the exported OpenVINO model are compared against the
    transformers reference model, and the model is driven through the
    ``automatic-speech-recognition`` pipeline.
    """

    SUPPORTED_ARCHITECTURES = ("whisper", "speech_to_text")

    def _generate_random_audio_data(self):
        """Return 5 seconds of a 220 Hz sine wave sampled at 22050 Hz.

        Despite the method name the returned signal is deterministic: no
        random draw happens after the seed is set.
        """
        # The seed does not influence the returned array; it is kept so the
        # global NumPy RNG state seen by code running after this call is the
        # same as before this change.
        np.random.seed(10)
        t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False)
        # generate pure sine wave at 220 Hz
        audio_data = 0.5 * np.sin(2 * np.pi * 220 * t)
        return audio_data

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_compare_to_transformers(self, model_arch):
        """Check that OpenVINO logits match the transformers logits within ``atol=1e-3``."""
        model_id = MODEL_NAMES[model_arch]
        set_seed(SEED)
        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
        self.assertIsInstance(ov_model.config, PretrainedConfig)
        transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
        processor = get_preprocessor(model_id)

        # The signal is deterministic, so generate it once and reuse it.
        audio = self._generate_random_audio_data()
        inputs = processor(audio, return_tensors="pt")

        # Encoder-decoder speech models need decoder inputs for a plain forward
        # pass; feed a single decoder-start token to both implementations so
        # the logits being compared come from identical inputs.
        decoder_start_token_id = transformers_model.config.decoder_start_token_id
        decoder_inputs_by_type = {
            "pt": torch.full((1, 1), decoder_start_token_id, dtype=torch.long),
            "np": np.full((1, 1), decoder_start_token_id, dtype=np.int64),
        }

        with torch.no_grad():
            transformers_outputs = transformers_model(**inputs, decoder_input_ids=decoder_inputs_by_type["pt"])

        for input_type in ["pt", "np"]:
            inputs = processor(audio, return_tensors=input_type)
            ov_outputs = ov_model(**inputs, decoder_input_ids=decoder_inputs_by_type[input_type])
            self.assertIn("logits", ov_outputs)
            self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
            # Compare tensor outputs
            self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))

        del transformers_model
        del ov_model
        gc.collect()

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_pipeline(self, model_arch):
        """Run the exported model through the ASR pipeline and check the output structure."""
        model_id = MODEL_NAMES[model_arch]
        model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
        processor = get_preprocessor(model_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
        )
        data = self._generate_random_audio_data()

        if model_arch == "whisper":
            # Only the whisper branch exercises the timestamp option.
            outputs = pipe(data, return_timestamps=True)
            self.assertIn("chunks", outputs)
            self.assertIsInstance(outputs["text"], str)

            outputs = pipe(data, return_timestamps=False)
            self.assertNotIn("chunks", outputs)
            self.assertIsInstance(outputs["text"], str)
        else:
            outputs = pipe(data)
            self.assertIsInstance(outputs["text"], str)

        del pipe
        del model
        gc.collect()
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"roberta": "hf-internal-testing/tiny-random-roberta",
"roformer": "hf-internal-testing/tiny-random-roformer",
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
"speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel",
"squeezebert": "hf-internal-testing/tiny-random-squeezebert",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
Expand All @@ -83,6 +84,7 @@
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
"wav2vec2-hf": "hf-internal-testing/tiny-random-Wav2Vec2Model",
"wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer",
"whisper": "openai/whisper-tiny.en",
"xlm": "hf-internal-testing/tiny-random-xlm",
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
}
Expand Down

0 comments on commit 1f9b163

Please sign in to comment.