Enable Text2text task on ipex (#1054)
* enable IPEXModelForSeq2SeqLM

Signed-off-by: jiqing-feng <[email protected]>

* set static cache

Signed-off-by: jiqing-feng <[email protected]>

* add tests for IPEXModelForSeq2SeqLM

Signed-off-by: jiqing-feng <[email protected]>

* add docs

Signed-off-by: jiqing-feng <[email protected]>

* fix readme

Signed-off-by: jiqing-feng <[email protected]>

* refactor compile

Signed-off-by: jiqing-feng <[email protected]>

* fix check

Signed-off-by: jiqing-feng <[email protected]>

* fix ruff check

Signed-off-by: jiqing-feng <[email protected]>

* fix check

Signed-off-by: jiqing-feng <[email protected]>

* fix tests

Signed-off-by: jiqing-feng <[email protected]>

* fix opt tests

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng authored Dec 17, 2024
1 parent 3c229fc commit a76be08
Showing 10 changed files with 330 additions and 31 deletions.
3 changes: 2 additions & 1 deletion docs/source/ipex/inference.mdx
@@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

## Loading

You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.

```diff
@@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au
| `IPEXModelForMaskedLM` | `fill-mask` |
| `IPEXModelForAudioClassification` | `audio-classification` |
| `IPEXModelForCausalLM` | `text-generation` |
| `IPEXModelForSeq2SeqLM` | `text2text-generation` |
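
As a hedged usage sketch (not part of this diff), the new `IPEXModelForSeq2SeqLM` class can be exercised like the existing IPEX classes; the `t5-small` checkpoint, the `torch_dtype` argument and the prompt below are illustrative assumptions, not part of the commit.

```python
# Hedged usage sketch for the new text2text-generation class; assumes the same
# from_pretrained interface as the other IPEX model classes. "t5-small" is an
# illustrative checkpoint, not a recommendation from the commit.
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("translate English to French: The house is wonderful.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```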
1 change: 1 addition & 0 deletions docs/source/ipex/models.mdx
@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
- Roberta
- Roformer
- SqueezeBert
- T5
- UniSpeech
- Vit
- Wav2Vec2
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
@@ -54,6 +54,7 @@
_import_structure["utils.dummy_ipex_objects"] = []
_import_structure["ipex"] = [
"IPEXModelForCausalLM",
"IPEXModelForSeq2SeqLM",
"IPEXModelForSequenceClassification",
"IPEXModelForMaskedLM",
"IPEXModelForTokenClassification",
@@ -248,6 +249,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
1 change: 1 addition & 0 deletions optimum/intel/ipex/__init__.py
@@ -20,6 +20,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
152 changes: 123 additions & 29 deletions optimum/intel/ipex/modeling_base.py
@@ -30,6 +30,7 @@
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
GenerationConfig,
@@ -60,8 +61,8 @@
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")


def _is_patched_with_ipex(model, task, use_cache: bool = True):
@@ -84,15 +85,21 @@ def __init__(
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: Optional[bool] = True,
**kwargs,
):
config = config or model.config
OptimizedModel.__init__(self, model=model, config=config)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)
self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
self.use_cache = kwargs.get("use_cache", False)
self.model_save_dir = model_save_dir
self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
self.compiled = False

self.input_names = set(inspect.signature(model.forward).parameters)

@@ -104,25 +111,10 @@ def __init__(
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)

# Non-generation tasks can use torch.compile to get acceleration.
if (
model.device.type == "cpu"
and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS
and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization, start warm up")
self.model.forward = torch.compile(self.model.forward)
inputs = prepare_jit_inputs(model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")
self.maybe_apply_torch_compile()

if warmup:
self._init_warmup()

@classmethod
def _from_transformers(cls, *args, **kwargs):
@@ -192,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

def maybe_apply_torch_compile(self):
if (
self.model.device.type != "cpu"
or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
return
if self.use_cache and not self._supports_static_cache:
return
from torch._inductor import config as inductor_config

# System level optimization
inductor_config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization")
self.model.forward = torch.compile(self.model.forward)
self.compiled = True

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
@@ -236,16 +253,10 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
if self._add_patch:
self._supports_cache_class = True
GenerationMixin.__init__(self)
@@ -269,6 +280,9 @@ def __init__(
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
self,
@@ -285,6 +299,9 @@ def _prepare_generation_config(
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
# Use static cache for torch compile
generation_config.cache_implementation = "static"
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
raise ValueError(
f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -337,6 +354,83 @@ def generate(self, *args, **kwargs):

return result

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")


class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForSeq2SeqLM
export_feature = "text2text-generation"

def __init__(
self,
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
GenerationMixin.__init__(self)

model_type = self.config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)

self.config.is_decoder = False
self.config.is_encoder_decoder = True

self.generation_config = GenerationConfig.from_model_config(self.config)
try:
self.model_cls = get_class_from_dynamic_module(
self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
)
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
# Use static cache for torch.compile
if self.compiled:
generation_config.cache_implementation = "static"

return generation_config, model_kwargs

def _reorder_cache(self, *args, **kwargs):
return self.model._reorder_cache(*args, **kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def get_encoder(self, *args, **kwargs):
return self.model.get_encoder(*args, **kwargs)

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
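
For orientation, the sketch below (not part of this diff) shows how the compile and static-cache changes above behave from the caller's side; the `compiled` attribute and `use_cache` keyword follow the code in this file, while the checkpoint and prompt are illustrative, and compilation only applies when the model type is outside `_COMPILE_NOT_READY_MODEL_TYPES`.

```python
# Hedged sketch of the compile/static-cache flow added above. maybe_apply_torch_compile()
# runs in __init__; when it compiles the forward pass it sets self.compiled = True, and
# _prepare_generation_config() then switches cache_implementation to "static" so the
# compiled graph sees fixed-shape KV caches. Checkpoint and prompt are illustrative only.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # whether compile kicks in depends on _COMPILE_NOT_READY_MODEL_TYPES above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=True)

print("torch.compile applied:", model.compiled)
inputs = tokenizer("IPEX now supports", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```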
1 change: 1 addition & 0 deletions optimum/intel/ipex/utils.py
@@ -16,6 +16,7 @@
_HEAD_TO_AUTOMODELS = {
"feature-extraction": "IPEXModel",
"text-generation": "IPEXModelForCausalLM",
"text2text-generation": "IPEXModelForSeq2SeqLM",
"text-classification": "IPEXModelForSequenceClassification",
"token-classification": "IPEXModelForTokenClassification",
"question-answering": "IPEXModelForQuestionAnswering",
19 changes: 19 additions & 0 deletions optimum/intel/pipelines/pipeline_base.py
@@ -58,6 +58,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
@@ -69,6 +70,24 @@
"default": "gpt2",
"type": "text",
},
"summarization": {
"impl": SummarizationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-base",
"type": "text",
},
"translation": {
"impl": TranslationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"fill-mask": {
"impl": FillMaskPipeline,
"class": (IPEXModelForMaskedLM,),
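
The three newly registered tasks can be reached through the pipeline factory in `optimum.intel.pipelines`; the sketch below assumes its `accelerator="ipex"` argument and reuses the `t5-small`/`t5-base` defaults from the table above.

```python
# Hedged sketch of the newly registered text2text pipeline tasks; assumes the
# pipeline factory in optimum.intel.pipelines accepts accelerator="ipex".
from optimum.intel.pipelines import pipeline

generator = pipeline("text2text-generation", "t5-small", accelerator="ipex")
print(generator("translate English to German: The new task entries are registered above."))

summarizer = pipeline("summarization", "t5-base", accelerator="ipex")
print(summarizer("Optimum Intel now routes summarization, translation and "
                 "text2text-generation to IPEXModelForSeq2SeqLM."))
```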
11 changes: 11 additions & 0 deletions optimum/intel/utils/dummy_ipex_objects.py
@@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForQuestionAnswering(metaclass=DummyObject):
_backends = ["ipex"]
