diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 7038515ad3..55b6e93522 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -58,12 +58,11 @@
 logger = logging.getLogger(__name__)
 
 
-_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation")
 _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
-# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
+# TODO: Some models are already fixed in torch 2.6, will enable them once torch is upgraded to 2.6
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit", "llama", "falcon", "gpt2")
 
 
 def _is_patched_with_ipex(model, task, use_cache: bool = True):
@@ -86,15 +85,21 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
         config = config or model.config
         OptimizedModel.__init__(self, model=model, config=config)
+        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
+        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
+        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
+        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
         self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
+        self.compiled = False
 
         self.input_names = set(inspect.signature(model.forward).parameters)
@@ -106,25 +111,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
-        # Non-generation tasks can use torch.compile to get acceleration.
-        if (
-            self.model.device.type == "cpu"
-            and self.export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS
-            and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
-            and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
-        ):
-            from torch._inductor import config
-
-            # System level optimization
-            torch._inductor.config.cpp_wrapper = True
-            os.environ["TORCHINDUCTOR_FREEZING"] = "1"
-            logger.info("Enable torch.compile optimization, start warm up")
-            self.model.forward = torch.compile(self.model.forward)
-            inputs = prepare_jit_inputs(model, self.export_feature, False)
-            with torch.no_grad():
-                self.model(**inputs)
-                self.model(**inputs)
-            logger.info("Warm up end")
+        self.maybe_apply_torch_compile()
+
+        if warmup:
+            self._init_warmup()
 
     @classmethod
     def _from_transformers(cls, *args, **kwargs):
@@ -194,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)
 
+    def maybe_apply_torch_compile(self):
+        if (
+            self.model.device.type != "cpu"
+            or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
+            or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
+        ):
+            return
+        if self.use_cache and not self._supports_static_cache:
+            return
+        from torch._inductor import config
+
+        # System level optimization
+        torch._inductor.config.cpp_wrapper = True
+        os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+        logger.info("Enable torch.compile optimization")
+        self.model.forward = torch.compile(self.model.forward)
+        self.compiled = True
+
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        with torch.no_grad():
+            self.model(**inputs)
+            self.model(**inputs)
+        logger.info("Warm up end")
+
 
 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
@@ -238,16 +253,10 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)
-
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
-        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
-
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
         if self._add_patch:
             self._supports_cache_class = True
         GenerationMixin.__init__(self)
@@ -271,6 +280,9 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
+        if warmup:
+            self._init_warmup()
+
     @torch.no_grad()
     def forward(
         self,
@@ -285,6 +297,9 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
+        if self.compiled:
+            # Use static cache for torch compile
+            generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
             raise ValueError(
                 f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -337,6 +352,12 @@ def generate(self, *args, **kwargs):
         return result
 
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+
 
 class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
     auto_model_class = AutoModelForSeq2SeqLM
@@ -348,15 +369,10 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)
-
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
-        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
-        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
-
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
         GenerationMixin.__init__(self)
 
         model_type = self.config.model_type.replace("_", "-")
@@ -375,23 +391,9 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
-        if (
-            self._supports_static_cache
-            and self.model.device.type == "cpu"
-            and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
-            and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
-        ):
-            from torch._inductor import config
-
-            # System level optimization
-            torch._inductor.config.cpp_wrapper = True
-            os.environ["TORCHINDUCTOR_FREEZING"] = "1"
-            logger.info("Enable torch.compile optimization, start warm up")
-            self.model.forward = torch.compile(self.model.forward)
-            inputs = prepare_jit_inputs(model, self.export_feature, False)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
-            logger.info("Warm up end")
+
+        if warmup:
+            self._init_warmup()
 
     @torch.no_grad()
     def forward(
@@ -407,7 +409,8 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         # Use static cache for torch.compile
-        setattr(generation_config, "cache_implementation", "static")
+        if self.compiled:
+            generation_config.cache_implementation = "static"
 
         return generation_config, model_kwargs
 
@@ -420,6 +423,12 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
     def get_encoder(self, *args, **kwargs):
         return self.model.get_encoder(*args, **kwargs)
 
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+
 
 def _ipex_crop_past_key_values(model, past_key_values, max_length):
     if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
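
Below is a minimal usage sketch (not part of the patch) of the loading path this diff touches. It assumes the public optimum-intel API; "gpt2" is only a placeholder checkpoint, and whether torch.compile is actually applied depends on the device, the model type, and the installed torch/IPEX versions.

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint; any supported causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)

# At construction time the model calls maybe_apply_torch_compile(), which wraps
# model.forward in torch.compile only on CPU, for model types outside
# _COMPILE_NOT_READY_MODEL_TYPES, and with ipex >= 2.5.0. With the default
# warmup=True, _init_warmup() then runs two short generate() calls so compilation
# happens before the first user request.
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# If the forward pass was compiled, _prepare_generation_config() forces
# cache_implementation="static" so the compiled graph sees fixed-shape KV caches.
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))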