refactor compile
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 11, 2024
1 parent de501f4 commit 4225bf0
119 changes: 64 additions & 55 deletions optimum/intel/ipex/modeling_base.py
@@ -58,12 +58,11 @@
logger = logging.getLogger(__name__)


_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation")
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit", "llama", "falcon", "gpt2")


def _is_patched_with_ipex(model, task, use_cache: bool = True):
@@ -86,15 +85,21 @@ def __init__(
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: Optional[bool] = True,
**kwargs,
):
config = config or model.config
OptimizedModel.__init__(self, model=model, config=config)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)
self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
self.use_cache = kwargs.get("use_cache", False)
self.model_save_dir = model_save_dir
self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
self.compiled = False

self.input_names = set(inspect.signature(model.forward).parameters)

@@ -106,25 +111,10 @@ def __init__(
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)

# Non-generation tasks can use torch.compile to get acceleration.
if (
self.model.device.type == "cpu"
and self.export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS
and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization, start warm up")
self.model.forward = torch.compile(self.model.forward)
inputs = prepare_jit_inputs(model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")
self.maybe_apply_torch_compile()

if warmup:
self._init_warmup()

@classmethod
def _from_transformers(cls, *args, **kwargs):
@@ -194,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

def maybe_apply_torch_compile(self):
if (
            self.model.device.type != "cpu"
or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
return
if self.use_cache and not self._supports_static_cache:
return
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization")
self.model.forward = torch.compile(self.model.forward)
self.compiled = True

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
with torch.no_grad():
self.model(**inputs)
self.model(**inputs)
logger.info("Warm up end")


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
@@ -238,16 +253,10 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
if self._add_patch:
self._supports_cache_class = True
GenerationMixin.__init__(self)
@@ -271,6 +280,9 @@ def __init__(
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
self,
@@ -285,6 +297,9 @@ def _prepare_generation_config(
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if self.compiled:
# Use static cache for torch compile
generation_config.cache_implementation = "static"
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
raise ValueError(
f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -337,6 +352,12 @@ def generate(self, *args, **kwargs):

return result

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")


class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForSeq2SeqLM
@@ -348,15 +369,10 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: Optional[bool] = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)

self._supports_cache_class = getattr(model, "_supports_cache_class", None)
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
self._supports_static_cache = getattr(model, "_supports_static_cache", None)

super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
GenerationMixin.__init__(self)

model_type = self.config.model_type.replace("_", "-")
@@ -375,23 +391,9 @@ def __init__(

if hasattr(self.model_cls, "_convert_to_standard_cache"):
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if (
self._supports_static_cache
and self.model.device.type == "cpu"
and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
):
from torch._inductor import config

# System level optimization
torch._inductor.config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
logger.info("Enable torch.compile optimization, start warm up")
self.model.forward = torch.compile(self.model.forward)
inputs = prepare_jit_inputs(model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4)
logger.info("Warm up end")

if warmup:
self._init_warmup()

@torch.no_grad()
def forward(
@@ -407,7 +409,8 @@ def _prepare_generation_config(
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
# Use static cache for torch.compile
setattr(generation_config, "cache_implementation", "static")
if self.compiled:
generation_config.cache_implementation = "static"

return generation_config, model_kwargs

@@ -420,6 +423,12 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
def get_encoder(self, *args, **kwargs):
return self.model.get_encoder(*args, **kwargs)

def _init_warmup(self):
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
logger.info("Warm up end")


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):

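For orientation, here is a minimal usage sketch of the flow this commit refactors, assuming the optimum-intel API as of this change (the gpt2 checkpoint and prompt are illustrative only): from_pretrained drives __init__, which calls maybe_apply_torch_compile() and, unless warmup=False is passed to the constructor, _init_warmup(); the new compiled flag then decides whether generation is forced onto a static cache.

# Illustrative usage sketch, not part of the commit: exercises the refactored
# warmup / torch.compile flow shown in the diff above.
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # example checkpoint; gpt2 is in _COMPILE_NOT_READY_MODEL_TYPES in this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)

# from_pretrained runs __init__, which calls maybe_apply_torch_compile() and,
# unless warmup=False is passed to the constructor, _init_warmup().
model = IPEXModelForCausalLM.from_pretrained(model_id)

print("torch.compile applied:", model.compiled)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# When model.compiled is True, _prepare_generation_config switches
# generation_config.cache_implementation to "static" before generating.
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))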