
add tests for IPEXModelForSeq2SeqLM
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 9, 2024
1 parent f9fa807 commit 202df43
Showing 2 changed files with 166 additions and 0 deletions.
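For context, a minimal sketch of the API these tests exercise, assuming IPEXModelForSeq2SeqLM is exported from optimum.intel as the import blocks below suggest; the model id and generation settings are illustrative, not part of the commit:

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM  # assumed export path, matching the test imports

# Tiny T5 checkpoint used by the cache-comparison test below; any seq2seq checkpoint should work in principle.
model_id = "hf-internal-testing/tiny-random-t5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.float32)

inputs = tokenizer("This is a sample", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))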
tests/ipex/test_modeling.py (122 additions, 0 deletions)
@@ -26,6 +26,7 @@
from transformers import (
    AutoFeatureExtractor,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    GenerationConfig,
@@ -38,6 +39,7 @@
    IPEXModel,
    IPEXModelForAudioClassification,
    IPEXModelForCausalLM,
    IPEXModelForSeq2SeqLM,
    IPEXModelForImageClassification,
    IPEXModelForMaskedLM,
    IPEXModelForQuestionAnswering,
@@ -510,3 +512,123 @@ def test_patched_model(self):
        transformers_outputs = transformers_model(**inputs)
        outputs = ipex_model(**inputs)
        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))


class IPEXModelForSeq2SeqLMTest(unittest.TestCase):
    IPEX_MODEL_CLASS = IPEXModelForSeq2SeqLM
    SUPPORTED_ARCHITECTURES = ("t5",)
    GENERATION_LENGTH = 2
    SPEEDUP_CACHE = 1.0

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_compare_to_transformers(self, model_arch):
        model_id = MODEL_NAMES[model_arch]
        set_seed(SEED)
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        # The model forward pass does not need the cache.
        ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype)
        transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype)
        self.assertIsInstance(ipex_model.config, PretrainedConfig)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokens = tokenizer(
            "This is a sample",
            return_tensors="pt",
            return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
        )
        decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2
        decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
        outputs = ipex_model(**tokens, **decoder_inputs)

        self.assertIsInstance(outputs.logits, torch.Tensor)

        with torch.no_grad():
            transformers_outputs = transformers_model(**tokens, **decoder_inputs)

        # Test re-loading the model
        with tempfile.TemporaryDirectory() as tmpdirname:
            ipex_model.save_pretrained(tmpdirname)
            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype)
            loaded_model_outputs = loaded_model(**tokens, **decoder_inputs)

        # Test the init method
        init_model = self.IPEX_MODEL_CLASS(transformers_model)
        init_model_outputs = init_model(**tokens, **decoder_inputs)

        # Compare tensor outputs
        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
        # Tight tolerance to avoid floating-point error
        self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
        self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_pipeline(self, model_arch):
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        model_id = MODEL_NAMES[model_arch]
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype)
        model.config.encoder_no_repeat_ngram_size = 0
        # model.to("cpu")
        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
        outputs = pipe("This is a sample", max_new_tokens=10, do_sample=False)
        self.assertEqual(pipe.device, model.device)

    def test_compare_with_and_without_past_key_values(self):
        model_id = "hf-internal-testing/tiny-random-t5"
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        model_with_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=True, torch_dtype=dtype)
        device = model_with_pkv.device
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokens = tokenizer("This is a sample input", return_tensors="pt").to(device)
        # Warmup
        model_with_pkv.generate(**tokens)
        with Timer() as with_pkv_timer:
            outputs_model_with_pkv = model_with_pkv.generate(
                **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
            )
        model_without_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=False, torch_dtype=dtype)
        # Warmup
        model_without_pkv.generate(**tokens)
        with Timer() as without_pkv_timer:
            outputs_model_without_pkv = model_without_pkv.generate(
                **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
            )
        self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1)
        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1)

    @parameterized.expand(
        grid_parameters(
            {
                "model_arch": SUPPORTED_ARCHITECTURES,
                "use_cache": [True, False],
            }
        )
    )
    def test_ipex_beam_search(self, test_name, model_arch, use_cache):
        model_id = MODEL_NAMES[model_arch]
        set_seed(SEED)
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype)
        device = model.device
        transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
        self.assertEqual(model.use_cache, use_cache)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
        # Test with batch sizes 1 and 2.
        texts = ["This is a sample", ["This is the first input", "This is the second input"]]
        generation_configs = (
            GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False),
            GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False),
            GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False),
            GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False),
            GenerationConfig(
                max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id
            ),
        )
        for text in texts:
            tokens = tokenizer(text, padding=True, return_tensors="pt").to(device)
            for generation_config in generation_configs:
                outputs = model.generate(**tokens, generation_config=generation_config)
                transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config)
                self.assertIsInstance(outputs, torch.Tensor)
                self.assertTrue(torch.equal(outputs, transformers_outputs))
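Assuming the repository's usual pytest layout, the new modeling tests above can be run on their own with, for example: pytest tests/ipex/test_modeling.py -k IPEXModelForSeq2SeqLMTest (the path comes from this diff; the -k filter simply matches the new class name).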
tests/ipex/test_pipelines.py (44 additions, 0 deletions)
@@ -28,6 +28,7 @@
    IPEXModelForImageClassification,
    IPEXModelForMaskedLM,
    IPEXModelForQuestionAnswering,
    IPEXModelForSeq2SeqLM,
    IPEXModelForSequenceClassification,
    IPEXModelForTokenClassification,
)
@@ -82,6 +83,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
"resnet",
"vit",
)
TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("t5",)

    @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES)
    def test_token_classification_pipeline_inference(self, model_arch):
@@ -215,3 +217,45 @@ def test_pipeline_load_from_jit_model(self, model_arch):
        ipex_output = ipex_generator(inputs)
        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification))
        self.assertGreaterEqual(ipex_output[0]["score"], 0.0)

    @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
    def test_text2text_generation_pipeline_inference(self, model_arch):
        model_id = MODEL_NAMES[model_arch]
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        transformers_generator = transformers_pipeline("text2text-generation", model_id, torch_dtype=dtype)
        ipex_generator = ipex_pipeline("text2text-generation", model_id, accelerator="ipex", torch_dtype=dtype)
        inputs = "Describe a real-world application of AI."
        with torch.inference_mode():
            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
        with torch.inference_mode():
            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM))
        self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])

    @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
    def test_summarization_generation_pipeline_inference(self, model_arch):
        model_id = MODEL_NAMES[model_arch]
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        transformers_generator = transformers_pipeline("summarization", model_id, torch_dtype=dtype)
        ipex_generator = ipex_pipeline("summarization", model_id, accelerator="ipex", torch_dtype=dtype)
        inputs = "Describe a real-world application of AI."
        with torch.inference_mode():
            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
        with torch.inference_mode():
            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM))
        self.assertEqual(transformers_output[0]["summary_text"], ipex_output[0]["summary_text"])

    @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
    def test_translation_generation_pipeline_inference(self, model_arch):
        model_id = MODEL_NAMES[model_arch]
        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
        transformers_generator = transformers_pipeline("translation", model_id, torch_dtype=dtype)
        ipex_generator = ipex_pipeline("translation", model_id, accelerator="ipex", torch_dtype=dtype)
        inputs = "Describe a real-world application of AI."
        with torch.inference_mode():
            transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10)
        with torch.inference_mode():
            ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10)
        self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM))
        self.assertEqual(transformers_output[0]["translation_text"], ipex_output[0]["translation_text"])
