Enable Text2text task on ipex #1054

Merged (13 commits) on Dec 17, 2024
3 changes: 2 additions & 1 deletion docs/source/ipex/inference.mdx
@@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

## Loading

- You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+ You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.

Reviewer (Member):

I don't understand what you mean here. I see torch.compile being applied to the text-generation task, so why does it say "except text-generation tasks" here?

@jiqing-feng (Collaborator, Author) replied on Dec 10, 2024:

Actually, we don't apply torch.compile to the text-generation task, which means IPEXModelForCausalLM does not call torch.compile in its __init__. Generation tasks are also excluded in IPEXModel.__init__ when deciding whether to call torch.compile.
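To make the gating concrete, here is a minimal sketch using the constant and attribute names from the modeling_base.py diff further down; the helper name should_compile is hypothetical, and the IPEX version check is omitted for brevity:

```python
# Illustrative sketch only: when IPEXModel applies torch.compile in its __init__.
# The real condition in this PR also requires a recent enough IPEX version.
_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation")
_COMPILE_NOT_READY_MODEL_TYPES = ()  # placeholder; the real tuple lives in modeling_base.py


def should_compile(model, export_feature: str, config) -> bool:
    # torch.compile is only used on CPU, for non-generation tasks,
    # and only for model types that are ready to be compiled.
    return (
        model.device.type == "cpu"
        and export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS
        and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
    )
```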

For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.

@@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au
| `IPEXModelForMaskedLM` | `fill-mask` |
| `IPEXModelForAudioClassification` | `audio-classification` |
| `IPEXModelForCausalLM` | `text-generation` |
| `IPEXModelForSeq2SeqLM` | `text2text-generation` |
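
To complement the table above, the sketch below shows the new mapping end to end; t5-small is only an example checkpoint (it is the default this PR declares for the task in the pipeline registry), and any supported seq2seq checkpoint should behave the same way:

```python
# Minimal sketch: load a seq2seq model through the new IPEX class and generate.
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "t5-small"  # example checkpoint; matches the pipeline default added in this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```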
1 change: 1 addition & 0 deletions docs/source/ipex/models.mdx
@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
- Roberta
- Roformer
- SqueezeBert
- T5
- UniSpeech
- Vit
- Wav2Vec2
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
@@ -53,6 +53,7 @@
else:
_import_structure["ipex"] = [
"IPEXModelForCausalLM",
"IPEXModelForSeq2SeqLM",
"IPEXModelForSequenceClassification",
"IPEXModelForMaskedLM",
"IPEXModelForTokenClassification",
@@ -247,6 +248,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
1 change: 1 addition & 0 deletions optimum/intel/ipex/__init__.py
@@ -20,6 +20,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
91 changes: 88 additions & 3 deletions optimum/intel/ipex/modeling_base.py
@@ -30,6 +30,7 @@
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
GenerationConfig,
@@ -57,6 +58,7 @@
logger = logging.getLogger(__name__)


_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation")
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
@@ -106,9 +108,9 @@ def __init__(

# Non-generation tasks can use torch.compile to get acceleration.
if (
model.device.type == "cpu"
and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS
and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
self.model.device.type == "cpu"
and self.export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved
and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config
@@ -336,6 +338,89 @@ def generate(self, *args, **kwargs):
return result


class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForSeq2SeqLM
export_feature = "text2text-generation"

def __init__(
self,
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

GenerationMixin.__init__(self)

model_type = self.config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)

self.config.is_decoder = False
self.config.is_encoder_decoder = True

self.generation_config = GenerationConfig.from_model_config(self.config)
try:
self.model_cls = get_class_from_dynamic_module(
self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
)
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if (
self._supports_static_cache
and self.model.device.type == "cpu"
and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization, start warm up")
self.model.forward = torch.compile(self.model.forward)
inputs = prepare_jit_inputs(model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
logger.info("Warm up end")

@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
# Use static cache for torch.compile
setattr(generation_config, "cache_implementation", "static")

return generation_config, model_kwargs

def _reorder_cache(self, *args, **kwargs):
return self.model._reorder_cache(*args, **kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def get_encoder(self, *args, **kwargs):
return self.model.get_encoder(*args, **kwargs)


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
if isinstance(past_key_values, IPEXPagedCache):
1 change: 1 addition & 0 deletions optimum/intel/ipex/utils.py
@@ -16,6 +16,7 @@
_HEAD_TO_AUTOMODELS = {
"feature-extraction": "IPEXModel",
"text-generation": "IPEXModelForCausalLM",
"text2text-generation": "IPEXModelForSeq2SeqLM",
"text-classification": "IPEXModelForSequenceClassification",
"token-classification": "IPEXModelForTokenClassification",
"question-answering": "IPEXModelForQuestionAnswering",
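The mapping above ties each task name to the exported class name; below is one minimal sketch of how it could be consumed (the underscore-prefixed import is internal, and the real consumers elsewhere in the codebase may resolve classes differently):

```python
# Sketch: resolve the new "text2text-generation" task to its IPEX model class.
import optimum.intel
from optimum.intel.ipex.utils import _HEAD_TO_AUTOMODELS  # internal mapping edited above

class_name = _HEAD_TO_AUTOMODELS["text2text-generation"]  # "IPEXModelForSeq2SeqLM"
model_class = getattr(optimum.intel, class_name)
print(model_class.__name__)
```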
19 changes: 19 additions & 0 deletions optimum/intel/pipelines/pipeline_base.py
@@ -58,6 +58,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
@@ -69,6 +70,24 @@
"default": "gpt2",
"type": "text",
},
"summarization": {
"impl": SummarizationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-base",
"type": "text",
},
"translation": {
"impl": TranslationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"fill-mask": {
"impl": FillMaskPipeline,
"class": (IPEXModelForMaskedLM,),
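The new registry entries above can be exercised through the pipeline factory; the import path below is an assumption based on the file edited here and on how the tests further down call pipeline(...), and t5-small is the default checkpoint this diff declares for the task:

```python
# Sketch only: run the newly registered "text2text-generation" task end to end.
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForSeq2SeqLM
from optimum.intel.pipelines import pipeline  # assumed import path for the factory edited above

model_id = "t5-small"  # default checkpoint declared for this task in the diff above
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
print(pipe("translate English to French: I love machine learning.", max_new_tokens=16))
```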
11 changes: 11 additions & 0 deletions optimum/intel/utils/dummy_ipex_objects.py
@@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForQuestionAnswering(metaclass=DummyObject):
_backends = ["ipex"]

122 changes: 122 additions & 0 deletions tests/ipex/test_modeling.py
@@ -26,6 +26,7 @@
from transformers import (
AutoFeatureExtractor,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForQuestionAnswering,
AutoTokenizer,
GenerationConfig,
@@ -37,6 +38,7 @@
IPEXModel,
IPEXModelForAudioClassification,
IPEXModelForCausalLM,
IPEXModelForSeq2SeqLM,
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
@@ -523,6 +525,126 @@ def test_patched_model(self):
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))


class IPEXModelForSeq2SeqLMTest(unittest.TestCase):
IPEX_MODEL_CLASS = IPEXModelForSeq2SeqLM
SUPPORTED_ARCHITECTURES = ("t5",)
GENERATION_LENGTH = 2
SPEEDUP_CACHE = 1.0

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
# Test model forward do not need cache.
ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype)
transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype)
self.assertIsInstance(ipex_model.config, PretrainedConfig)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
"This is a sample",
return_tensors="pt",
return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
)
decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
outputs = ipex_model(**tokens, **decoder_inputs)

self.assertIsInstance(outputs.logits, torch.Tensor)

with torch.no_grad():
transformers_outputs = transformers_model(**tokens, **decoder_inputs)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype)
loaded_model_outputs = loaded_model(**tokens, **decoder_inputs)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**tokens, **decoder_inputs)

# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
# To avoid float pointing error
self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype)
model.config.encoder_no_repeat_ngram_size = 0
# model.to("cpu")
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
outputs = pipe("This is a sample", max_new_tokens=10, do_sample=False)
self.assertEqual(pipe.device, model.device)

def test_compare_with_and_without_past_key_values(self):
model_id = "hf-internal-testing/tiny-random-t5"
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
model_with_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=True, torch_dtype=dtype)
device = model_with_pkv.device
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt").to(device)
# Warmup
model_with_pkv.generate(**tokens)
with Timer() as with_pkv_timer:
outputs_model_with_pkv = model_with_pkv.generate(
**tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
)
model_without_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=False, torch_dtype=dtype)
# Warmup
model_without_pkv.generate(**tokens)
with Timer() as without_pkv_timer:
outputs_model_without_pkv = model_without_pkv.generate(
**tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
)
self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1)
self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1)

@parameterized.expand(
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True, False],
}
)
)
def test_ipex_beam_search(self, test_name, model_arch, use_cache):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype)
device = model.device
transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
self.assertEqual(model.use_cache, use_cache)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Test with batch_size is 1 and 2.
texts = ["This is a sample", ["This is the first input", "This is the second input"]]
generation_configs = (
GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False),
GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False),
GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False),
GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False),
GenerationConfig(
max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id
),
)
for text in texts:
tokens = tokenizer(text, padding=True, return_tensors="pt").to(device)
for generation_config in generation_configs:
outputs = model.generate(**tokens, generation_config=generation_config)
transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config)
self.assertIsInstance(outputs, torch.Tensor)
self.assertTrue(torch.equal(outputs, transformers_outputs))


class IPEXSTModel(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"st-bert",