diff --git a/docs/source/ipex/inference.mdx b/docs/source/ipex/inference.mdx
index 54b586924d..72826da595 100644
--- a/docs/source/ipex/inference.mdx
+++ b/docs/source/ipex/inference.mdx
@@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m
 
 ## Loading
 
-You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+You can load your model and apply IPEX optimizations (torch.compile is applied for all tasks except text generation). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators. For now, support is enabled for Intel CPUs and GPUs. Models previously converted to TorchScript will be deprecated in v1.22.
 
 ```diff
@@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au
 | `IPEXModelForMaskedLM`            | `fill-mask`              |
 | `IPEXModelForAudioClassification` | `audio-classification`   |
 | `IPEXModelForCausalLM`            | `text-generation`        |
+| `IPEXModelForSeq2SeqLM`           | `text2text-generation`   |
diff --git a/docs/source/ipex/models.mdx b/docs/source/ipex/models.mdx
index 346ca26599..b8cd6c482f 100644
--- a/docs/source/ipex/models.mdx
+++ b/docs/source/ipex/models.mdx
@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
 - Roberta
 - Roformer
 - SqueezeBert
+- T5
 - UniSpeech
 - Vit
 - Wav2Vec2
diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index b441b76f93..ad9fdca078 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -54,6 +54,7 @@
     _import_structure["utils.dummy_ipex_objects"] = []
     _import_structure["ipex"] = [
         "IPEXModelForCausalLM",
+        "IPEXModelForSeq2SeqLM",
         "IPEXModelForSequenceClassification",
         "IPEXModelForMaskedLM",
         "IPEXModelForTokenClassification",
@@ -248,6 +249,7 @@
         IPEXModelForImageClassification,
         IPEXModelForMaskedLM,
         IPEXModelForQuestionAnswering,
+        IPEXModelForSeq2SeqLM,
         IPEXModelForSequenceClassification,
         IPEXModelForTokenClassification,
     )
diff --git a/optimum/intel/ipex/__init__.py b/optimum/intel/ipex/__init__.py
index 62e6afcf6b..9aae96b08a 100644
--- a/optimum/intel/ipex/__init__.py
+++ b/optimum/intel/ipex/__init__.py
@@ -20,6 +20,7 @@
     IPEXModelForImageClassification,
     IPEXModelForMaskedLM,
     IPEXModelForQuestionAnswering,
+    IPEXModelForSeq2SeqLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index d8f830e519..af36d06f4d 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -30,6 +30,7 @@
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     GenerationConfig,
@@ -60,8 +61,8 @@
 _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
-# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
+# TODO: Some models are already fixed in torch 2.6, enable them once torch is upgraded to 2.6
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")
 
 
 def
_is_patched_with_ipex(model, task, use_cache: bool = True): @@ -84,15 +85,21 @@ def __init__( model, config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + warmup: Optional[bool] = True, **kwargs, ): config = config or model.config OptimizedModel.__init__(self, model=model, config=config) + self._supports_cache_class = getattr(model, "_supports_cache_class", None) + self._supports_sdpa = getattr(model, "_supports_sdpa", None) + self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) + self._supports_static_cache = getattr(model, "_supports_static_cache", None) self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32 self.use_cache = kwargs.get("use_cache", False) self.model_save_dir = model_save_dir self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache) + self.compiled = False self.input_names = set(inspect.signature(model.forward).parameters) @@ -104,25 +111,10 @@ def __init__( if hasattr(self.auto_model_class, "register"): self.auto_model_class.register(AutoConfig, self.__class__) - # Non-generation tasks can use torch.compile to get acceleration. - if ( - model.device.type == "cpu" - and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS - and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES - and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) - ): - from torch._inductor import config - - # System level optimization - torch._inductor.config.cpp_wrapper = True - os.environ["TORCHINDUCTOR_FREEZING"] = "1" - logger.info("Enable torch.compile optimization, start warm up") - self.model.forward = torch.compile(self.model.forward) - inputs = prepare_jit_inputs(model, self.export_feature, False) - with torch.no_grad(): - self.model(**inputs) - self.model(**inputs) - logger.info("Warm up end") + self.maybe_apply_torch_compile() + + if warmup: + self._init_warmup() @classmethod def _from_transformers(cls, *args, **kwargs): @@ -192,6 +184,31 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) + def maybe_apply_torch_compile(self): + if ( + self.model.device.type != "cpu" + or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES + or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE) + ): + return + if self.use_cache and not self._supports_static_cache: + return + from torch._inductor import config as inductor_config + + # System level optimization + inductor_config.cpp_wrapper = True + os.environ["TORCHINDUCTOR_FREEZING"] = "1" + logger.info("Enable torch.compile optimization") + self.model.forward = torch.compile(self.model.forward) + self.compiled = True + + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + with torch.no_grad(): + self.model(**inputs) + self.model(**inputs) + logger.info("Warm up end") + class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -236,16 +253,10 @@ def __init__( config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, + warmup: Optional[bool] = True, **kwargs, ): - super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache) - - self._supports_cache_class = getattr(model, "_supports_cache_class", None) - self._supports_sdpa = getattr(model, "_supports_sdpa", None) - self._supports_cache_class = getattr(model, "_supports_cache_class", None) - 
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) - self._supports_static_cache = getattr(model, "_supports_static_cache", None) - + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) if self._add_patch: self._supports_cache_class = True GenerationMixin.__init__(self) @@ -269,6 +280,9 @@ def __init__( if hasattr(self.model_cls, "_convert_to_bloom_cache"): self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache + if warmup: + self._init_warmup() + @torch.no_grad() def forward( self, @@ -285,6 +299,9 @@ def _prepare_generation_config( ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) generation_method = generation_config.get_generation_mode().value + if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache: + # Use static cache for torch compile + generation_config.cache_implementation = "static" if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS: raise ValueError( f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" @@ -337,6 +354,83 @@ def generate(self, *args, **kwargs): return result + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") + + +class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin): + auto_model_class = AutoModelForSeq2SeqLM + export_feature = "text2text-generation" + + def __init__( + self, + model, + config: PretrainedConfig = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + use_cache: bool = True, + warmup: Optional[bool] = True, + **kwargs, + ): + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) + GenerationMixin.__init__(self) + + model_type = self.config.model_type.replace("_", "-") + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config) + + self.config.is_decoder = False + self.config.is_encoder_decoder = True + + self.generation_config = GenerationConfig.from_model_config(self.config) + try: + self.model_cls = get_class_from_dynamic_module( + self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir + ) + except AttributeError: + self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping) + + if hasattr(self.model_cls, "_convert_to_standard_cache"): + self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache + + if warmup: + self._init_warmup() + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) + # Use static cache for torch.compile + if self.compiled: + generation_config.cache_implementation = "static" + + return generation_config, 
model_kwargs + + def _reorder_cache(self, *args, **kwargs): + return self.model._reorder_cache(*args, **kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def get_encoder(self, *args, **kwargs): + return self.model.get_encoder(*args, **kwargs) + + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") + def _ipex_crop_past_key_values(model, past_key_values, max_length): if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): diff --git a/optimum/intel/ipex/utils.py b/optimum/intel/ipex/utils.py index 3d3feb3db2..23126bcd4c 100644 --- a/optimum/intel/ipex/utils.py +++ b/optimum/intel/ipex/utils.py @@ -16,6 +16,7 @@ _HEAD_TO_AUTOMODELS = { "feature-extraction": "IPEXModel", "text-generation": "IPEXModelForCausalLM", + "text2text-generation": "IPEXModelForSeq2SeqLM", "text-classification": "IPEXModelForSequenceClassification", "token-classification": "IPEXModelForTokenClassification", "question-answering": "IPEXModelForQuestionAnswering", diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py index 5b8531c674..04390ba3b1 100644 --- a/optimum/intel/pipelines/pipeline_base.py +++ b/optimum/intel/pipelines/pipeline_base.py @@ -58,6 +58,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) @@ -69,6 +70,24 @@ "default": "gpt2", "type": "text", }, + "summarization": { + "impl": SummarizationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-base", + "type": "text", + }, + "translation": { + "impl": TranslationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-small", + "type": "text", + }, + "text2text-generation": { + "impl": Text2TextGenerationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-small", + "type": "text", + }, "fill-mask": { "impl": FillMaskPipeline, "class": (IPEXModelForMaskedLM,), diff --git a/optimum/intel/utils/dummy_ipex_objects.py b/optimum/intel/utils/dummy_ipex_objects.py index de68e40023..7c1922305b 100644 --- a/optimum/intel/utils/dummy_ipex_objects.py +++ b/optimum/intel/utils/dummy_ipex_objects.py @@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["ipex"]) +class IPEXModelForSeq2SeqLM(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) + + class IPEXModelForQuestionAnswering(metaclass=DummyObject): _backends = ["ipex"] diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index b595f6139f..419e1bb42a 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -26,6 +26,7 @@ from transformers import ( AutoFeatureExtractor, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, AutoTokenizer, GenerationConfig, @@ -37,6 +38,7 @@ IPEXModel, IPEXModelForAudioClassification, IPEXModelForCausalLM, + IPEXModelForSeq2SeqLM, IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, @@ -45,7 
+47,7 @@ IPEXSentenceTransformer, ) from optimum.utils.testing_utils import grid_parameters, require_sentence_transformers -from optimum.intel.utils.import_utils import is_sentence_transformers_available +from optimum.intel.utils.import_utils import is_sentence_transformers_available, is_torch_version if is_sentence_transformers_available(): from sentence_transformers import SentenceTransformer @@ -360,6 +362,9 @@ def test_ipex_beam_search(self, test_name, model_arch, use_cache): model = IPEXModelForCausalLM.from_pretrained( model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE ) + # It will be removed when torch 2.6 released + if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"): + return if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: self.assertTrue(model.add_patch) transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) @@ -554,6 +559,126 @@ def test_patched_model(self): self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) +class IPEXModelForSeq2SeqLMTest(unittest.TestCase): + IPEX_MODEL_CLASS = IPEXModelForSeq2SeqLM + SUPPORTED_ARCHITECTURES = ("t5",) + GENERATION_LENGTH = 2 + SPEEDUP_CACHE = 1.0 + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + # Test model forward do not need cache. + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample", + return_tensors="pt", + return_token_type_ids=False if model_arch in ("llama", "llama2") else None, + ) + decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2 + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} + outputs = ipex_model(**tokens, **decoder_inputs) + + self.assertIsInstance(outputs.logits, torch.Tensor) + + with torch.no_grad(): + transformers_outputs = transformers_model(**tokens, **decoder_inputs) + + # Test re-load model + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype) + loaded_model_outputs = loaded_model(**tokens, **decoder_inputs) + + # Test init method + init_model = self.IPEX_MODEL_CLASS(transformers_model) + init_model_outputs = init_model(**tokens, **decoder_inputs) + + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + # To avoid float pointing error + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7)) + self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + model.config.encoder_no_repeat_ngram_size = 0 + # model.to("cpu") + pipe = 
pipeline("text2text-generation", model=model, tokenizer=tokenizer) + outputs = pipe("This is a sample", max_new_tokens=10, do_sample=False) + self.assertEqual(pipe.device, model.device) + + def test_compare_with_and_without_past_key_values(self): + model_id = "hf-internal-testing/tiny-random-t5" + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_with_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=True, torch_dtype=dtype) + device = model_with_pkv.device + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt").to(device) + # Warmup + model_with_pkv.generate(**tokens) + with Timer() as with_pkv_timer: + outputs_model_with_pkv = model_with_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + model_without_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=False, torch_dtype=dtype) + # Warmup + model_without_pkv.generate(**tokens) + with Timer() as without_pkv_timer: + outputs_model_without_pkv = model_without_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) + self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1) + self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, + "use_cache": [True, False], + } + ) + ) + def test_ipex_beam_search(self, test_name, model_arch, use_cache): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype) + device = model.device + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device) + self.assertEqual(model.use_cache, use_cache) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + # Test with batch_size is 1 and 2. 
+ texts = ["This is a sample", ["This is the first input", "This is the second input"]] + generation_configs = ( + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False), + GenerationConfig( + max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id + ), + ) + for text in texts: + tokens = tokenizer(text, padding=True, return_tensors="pt").to(device) + for generation_config in generation_configs: + outputs = model.generate(**tokens, generation_config=generation_config) + transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config) + self.assertIsInstance(outputs, torch.Tensor) + self.assertTrue(torch.equal(outputs, transformers_outputs)) + + class IPEXSTModel(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "st-bert", diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index d9ddaf2586..f376c6050a 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -28,6 +28,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) @@ -83,6 +84,7 @@ class PipelinesIntegrationTest(unittest.TestCase): "resnet", "vit", ) + TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("t5",) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_token_classification_pipeline_inference(self, model_arch): @@ -224,3 +226,45 @@ def test_pipeline_load_from_jit_model(self, model_arch): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_text2text_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("text2text-generation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("text2text-generation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_summarization_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("summarization", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("summarization", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." 
+ with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["summary_text"], ipex_output[0]["summary_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_translation_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("translation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("translation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["translation_text"], ipex_output[0]["translation_text"])
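For reference, a minimal usage sketch of the new `IPEXModelForSeq2SeqLM` class and the matching IPEX pipelines, assembled from the API exercised in the tests above. The checkpoint, prompt, and dtype are illustrative only (`t5-small` is simply the default registered for the text2text/translation pipelines in this patch), and the `optimum.intel.pipelines.pipeline` entry point is assumed to be importable as it is in `tests/ipex/test_pipelines.py`.

```python
import torch
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForSeq2SeqLM
from optimum.intel.pipelines import pipeline  # IPEX-accelerated pipeline factory

model_id = "t5-small"  # illustrative checkpoint; any supported seq2seq model should work
tokenizer = AutoTokenizer.from_pretrained(model_id)

# use_cache=True keeps the KV cache; torch.compile is applied automatically when the
# installed torch/IPEX versions allow it (see maybe_apply_torch_compile in modeling_base.py).
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.float32, use_cache=True)

inputs = tokenizer("translate English to German: This is a sample", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

# The same class backs the "text2text-generation", "summarization" and "translation" tasks.
pipe = pipeline("text2text-generation", model_id, accelerator="ipex", torch_dtype=torch.float32)
print(pipe("translate English to German: This is a sample", max_new_tokens=10))
```

Generation goes through the regular `GenerationMixin` path, so `generate()` arguments such as `num_beams` or an explicit `GenerationConfig` behave as in the beam-search test above; when the model is compiled, the cache implementation is switched to `"static"` as done in `_prepare_generation_config`.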