Add INC modeling position_ids generation (#456)
* add position_ids in forward

* check if jit model needs position_ids

* use MODEL_TYPES_REQUIRING_POSITION_IDS

* fix has_position_ids

* fix position_ids length

* rm useless params

* check model inputs by input names

* fix format

* check input names in graph model

* fix style

* consider eager model in input_names

* add input names

* add text input names

* fix style

* Update optimum/intel/generation/modeling.py

* fix format

* Update optimum/intel/generation/modeling.py

---------

Co-authored-by: Ella Charlaix <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
3 people authored Jan 8, 2024
1 parent 03e1fa6 commit c64025d
Showing 2 changed files with 34 additions and 16 deletions.
38 changes: 22 additions & 16 deletions optimum/intel/generation/modeling.py
@@ -21,26 +21,20 @@

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.import_utils import is_torch_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


if is_transformers_version("<", "4.25.0"):
from transformers.generation_utils import GenerationMixin
else:
from transformers.generation import GenerationMixin


logger = logging.getLogger(__name__)


@@ -112,12 +106,14 @@ def __init__(
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
if isinstance(model, torch.jit.ScriptModule):
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
else:
from transformers import GenerationConfig
self.input_names = set()

self.generation_config = GenerationConfig.from_model_config(config)
self.generation_config = GenerationConfig.from_model_config(config)

# Avoid warnings when creating a transformers pipeline
AutoConfig.register(self.base_model_prefix, AutoConfig)
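
To make the runtime check possible, the constructor now records the input names of the traced graph. Below is a minimal sketch of the same mechanism using a toy module; the module, shapes, and values are assumptions for illustration, not part of the commit.

```python
import torch


class TinyModel(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids * attention_mask


example_inputs = (
    torch.ones(1, 4, dtype=torch.long),
    torch.ones(1, 4, dtype=torch.long),
)
traced = torch.jit.trace(TinyModel(), example_inputs)

# The traced graph lists "self" plus one value per forward argument; debug
# names may carry a ".<id>" suffix, hence the split on ".".
input_names = {
    i.debugName().split(".")[0] for i in traced.graph.inputs() if i.debugName() != "self"
}
print(input_names)  # -> {'input_ids', 'attention_mask'} (set order may vary)
```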
@@ -267,6 +263,7 @@ def forward(
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
# 1. Prepare model inputs
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

@@ -275,6 +272,15 @@
"attention_mask": attention_mask,
}

if "position_ids" in self.input_names and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

if "position_ids" in self.input_names or not self.input_names:
inputs["position_ids"] = position_ids

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
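
This is the core of the change: when the compiled graph expects position_ids (or when no input names are known, as for eager models), they are rebuilt from the attention mask. A small worked example of that computation on a hypothetical left-padded batch:

```python
import torch

# Hypothetical left-padded batch: the first sequence has two padding tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Positions count attended tokens (0-indexed); padding slots get a dummy 1,
# mirroring the masked_fill_ above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# With past_key_values present, only the position of the newest token is fed:
print(position_ids[:, -1].unsqueeze(-1))
# tensor([[2],
#         [4]])
```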
@@ -308,17 +314,17 @@

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

# 2. Model forward
outputs = self.model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):
logits = outputs[0]
past_key_values = outputs[1] if self.use_cache else None
else:
logits = outputs["logits"]
past_key_values = outputs["past_key_values"] if self.use_cache else None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


12 changes: 12 additions & 0 deletions tests/generation/test_modeling.py
@@ -160,3 +160,15 @@ def test_compare_with_and_without_past_key_values(self):
f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_input_names(self, model_arch):
model_id = MODEL_NAMES[model_arch]
model = TSModelForCausalLM.from_pretrained(model_id, export=True)
self.assertTrue(isinstance(model.input_names, set))
self.assertTrue("input_ids" in model.input_names)
self.assertTrue("attention_mask" in model.input_names)
if model.use_cache:
self.assertTrue("past_key_values" in model.input_names)
else:
self.assertTrue("past_key_values" not in model.input_names)
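
A hedged end-to-end sketch of how this change is exercised in practice: after loading, position_ids are filled in by forward() itself whenever the traced graph asks for them, so callers pass nothing extra. The checkpoint id here is an assumption (any small causal LM should work), and the import path follows the module edited in this commit.

```python
from transformers import AutoTokenizer

from optimum.intel.generation.modeling import TSModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # assumed small test checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = TSModelForCausalLM.from_pretrained(model_id, export=True)  # jit-traces the model

# Populated from the traced graph; position_ids are generated inside forward()
# whenever this set contains "position_ids".
print(model.input_names)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```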
