
Enable Text2text task on ipex #1054

Merged
13 commits merged on Dec 17, 2024
3 changes: 2 additions & 1 deletion docs/source/ipex/inference.mdx
@@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

## Loading

You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
Review comment (Member):

I don't understand what you mean here: I see torch.compile being applied to the text-generation task, so why does it say "except text-generation tasks" here?

jiqing-feng (Collaborator, Author) replied on Dec 10, 2024:

Actually, we didn't apply torch.compile to the text-generation task, which means IPEXModelForCausalLM doesn't call torch.compile in its init. Generation tasks are also excluded in IPEXModel.init when deciding whether to call torch.compile.

For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.

@@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au
| `IPEXModelForMaskedLM` | `fill-mask` |
| `IPEXModelForAudioClassification` | `audio-classification` |
| `IPEXModelForCausalLM` | `text-generation` |
| `IPEXModelForSeq2SeqLM` | `text2text-generation` |
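With the new row above, loading a model for the `text2text-generation` task follows the same pattern as the other tasks. A minimal sketch, assuming the `t5-small` checkpoint (the default used by the pipeline mapping later in this PR) and a CPU device:

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "t5-small"  # illustrative checkpoint, any supported seq2seq model should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```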
1 change: 1 addition & 0 deletions docs/source/ipex/models.mdx
@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
- Roberta
- Roformer
- SqueezeBert
- T5
- UniSpeech
- Vit
- Wav2Vec2
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
@@ -53,6 +53,7 @@
else:
_import_structure["ipex"] = [
"IPEXModelForCausalLM",
"IPEXModelForSeq2SeqLM",
"IPEXModelForSequenceClassification",
"IPEXModelForMaskedLM",
"IPEXModelForTokenClassification",
@@ -247,6 +248,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
1 change: 1 addition & 0 deletions optimum/intel/ipex/__init__.py
@@ -20,6 +20,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
152 changes: 123 additions & 29 deletions optimum/intel/ipex/modeling_base.py
@@ -30,6 +30,7 @@
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
GenerationConfig,
@@ -60,8 +61,8 @@
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")


def _is_patched_with_ipex(model, task, use_cache: bool = True):
@@ -84,15 +85,21 @@ def __init__(
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: Optional[bool] = True,
**kwargs,
):
config = config or model.config
OptimizedModel.__init__(self, model=model, config=config)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)
self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
self.use_cache = kwargs.get("use_cache", False)
self.model_save_dir = model_save_dir
self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
self.compiled = False

self.input_names = set(inspect.signature(model.forward).parameters)

@@ -104,25 +111,10 @@ def __init__(
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)

# Non-generation tasks can use torch.compile to get acceleration.
if (
model.device.type == "cpu"
and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS
and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization, start warm up")
self.model.forward = torch.compile(self.model.forward)
inputs = prepare_jit_inputs(model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")
self.maybe_apply_torch_compile()

if warmup:
self._init_warmup()

@classmethod
def _from_transformers(cls, *args, **kwargs):
@@ -192,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

def maybe_apply_torch_compile(self):
if (
self.model.device.type != "cpu"
or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
return
if self.use_cache and not self._supports_static_cache:
return
from torch._inductor import config as inductor_config

# System level optimization
inductor_config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization")
self.model.forward = torch.compile(self.model.forward)
self.compiled = True

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
@@ -236,16 +253,10 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
if self._add_patch:
self._supports_cache_class = True
GenerationMixin.__init__(self)
@@ -269,6 +280,9 @@ def __init__(
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
self,
@@ -283,6 +297,9 @@ def _prepare_generation_config(
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
# Use static cache for torch compile
generation_config.cache_implementation = "static"
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
raise ValueError(
f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -335,6 +352,83 @@ def generate(self, *args, **kwargs):

return result

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")
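
For the decoder-only path, warm-up just runs a short `generate` twice so that any torch.compile cost is paid before the first real request. A usage sketch after construction, assuming a `gpt2` checkpoint (used elsewhere in this PR only as the text-generation pipeline default):

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = IPEXModelForCausalLM.from_pretrained("gpt2", use_cache=True)

inputs = tokenizer("The Intel extension for PyTorch", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```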


class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForSeq2SeqLM
export_feature = "text2text-generation"

def __init__(
self,
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
GenerationMixin.__init__(self)

model_type = self.config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)

self.config.is_decoder = False
self.config.is_encoder_decoder = True

self.generation_config = GenerationConfig.from_model_config(self.config)
try:
self.model_cls = get_class_from_dynamic_module(
self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
)
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
# Use static cache for torch.compile
if self.compiled:
generation_config.cache_implementation = "static"

return generation_config, model_kwargs

def _reorder_cache(self, *args, **kwargs):
return self.model._reorder_cache(*args, **kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def get_encoder(self, *args, **kwargs):
return self.model.get_encoder(*args, **kwargs)

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
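The two consecutive calls in `_init_warmup` exist because the first invocation of a `torch.compile`-wrapped forward pays the graph-capture and compilation cost, while the second runs against the already-compiled graph so later calls start from a steady state. A standalone sketch of that pattern, with illustrative names that are not part of this file:

```python
import torch


def compile_and_warm_up(forward_fn, example_inputs, passes: int = 2):
    """Wrap a callable with torch.compile and run a few throwaway passes.

    The first call triggers tracing/compilation; subsequent calls reuse the cached graph.
    """
    compiled = torch.compile(forward_fn)
    with torch.no_grad():
        for _ in range(passes):
            compiled(**example_inputs)
    return compiled
```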
1 change: 1 addition & 0 deletions optimum/intel/ipex/utils.py
@@ -16,6 +16,7 @@
_HEAD_TO_AUTOMODELS = {
"feature-extraction": "IPEXModel",
"text-generation": "IPEXModelForCausalLM",
"text2text-generation": "IPEXModelForSeq2SeqLM",
"text-classification": "IPEXModelForSequenceClassification",
"token-classification": "IPEXModelForTokenClassification",
"question-answering": "IPEXModelForQuestionAnswering",
19 changes: 19 additions & 0 deletions optimum/intel/pipelines/pipeline_base.py
@@ -58,6 +58,7 @@
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSeq2SeqLM,
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
@@ -69,6 +70,24 @@
"default": "gpt2",
"type": "text",
},
"summarization": {
"impl": SummarizationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-base",
"type": "text",
},
"translation": {
"impl": TranslationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
"class": (IPEXModelForSeq2SeqLM,),
"default": "t5-small",
"type": "text",
},
"fill-mask": {
"impl": FillMaskPipeline,
"class": (IPEXModelForMaskedLM,),
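With these entries registered, the new task is also reachable through the optimum-intel `pipeline` helper. A hedged sketch, assuming the `accelerator="ipex"` keyword exposed by this module's `pipeline` wrapper:

```python
from optimum.intel.pipelines import pipeline

# "t5-small" matches the defaults registered above for translation/text2text-generation.
generator = pipeline("text2text-generation", model="t5-small", accelerator="ipex")
print(generator("translate English to German: The table is green."))
```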
11 changes: 11 additions & 0 deletions optimum/intel/utils/dummy_ipex_objects.py
@@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
_backends = ["ipex"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["ipex"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["ipex"])


class IPEXModelForQuestionAnswering(metaclass=DummyObject):
_backends = ["ipex"]

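The dummy class mirrors the real one so that `from optimum.intel import IPEXModelForSeq2SeqLM` still succeeds when intel-extension-for-pytorch is not installed, while any actual use fails with a backend error. A hedged sketch of that behaviour; the exact exception type and message come from `requires_backends` and are an assumption here:

```python
# In an environment without the "ipex" backend, the dummy class is what gets imported.
from optimum.intel import IPEXModelForSeq2SeqLM

try:
    IPEXModelForSeq2SeqLM.from_pretrained("t5-small")
except Exception as err:
    # requires_backends raises with a message pointing at the missing ipex dependency.
    print(err)
```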