check if jit model needs position_ids
jiqing-feng committed Oct 17, 2023
Parent e3f87a7 · commit 9a00f0c
Showing 3 changed files with 29 additions and 10 deletions.
optimum/intel/generation/modeling.py (25 additions, 8 deletions)
@@ -70,6 +70,7 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):

 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
+    has_position_ids = "position_ids" in model_inputs
     # check if the model_inputs is correct.
     model(**model_inputs)
     torch._C._jit_set_texpr_fuser_enabled(False)
@@ -88,7 +89,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     traced_model(**model_inputs)
     traced_model(**model_inputs)

-    return traced_model
+    return traced_model, has_position_ids


 class PreTrainedModel(OptimizedModel):
@@ -107,6 +108,7 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        has_position_ids: bool = False,
         **kwargs,
     ):
         super(BaseModelForCausalLM, self).__init__(model=model, config=config)
@@ -116,6 +118,7 @@
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
+        self.has_position_ids = has_position_ids

         if is_transformers_version("<=", "4.25.1"):
             self.generation_config = None
@@ -276,10 +279,6 @@ def forward(
             "attention_mask": attention_mask,
         }

-        position_ids = kwargs.get("position_ids", None)
-        if position_ids is not None:
-            inputs.update({"position_ids": position_ids})
-
         if self.use_cache:
             if past_key_values is None:
                 nb_pkv = 2
@@ -304,8 +303,8 @@ def forward(
                     pkv = tuple(empty_tensor for _ in range(nb_pkv))
                 else:
                     pkv = ()
-                    for nb_pkv in range(nb_pkv):
-                        if nb_pkv % 2 == 0:
+                    for i in range(nb_pkv):
+                        if i % 2 == 0:
                             new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
                         else:
                             new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
@@ -318,6 +317,22 @@

             inputs["past_key_values"] = past_key_values

+        position_ids = kwargs.get("position_ids", None)
+        if self.has_position_ids and position_ids is not None:
+            inputs.update({"position_ids": position_ids})
+        elif self.has_position_ids and position_ids is None:
+            seq_length = input_ids.shape[-1]
+            if not self.use_cache:
+                past_key_values_length = 0
+            else:
+                past_key_values_length = past_key_values[0][1].shape[-2]
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
+            ).unsqueeze(0)
+            inputs.update({"position_ids": position_ids})
+        elif not self.has_position_ids and position_ids is not None:
+            logger.warning("position_ids were passed but the traced model does not accept them; they will be ignored.")
+
         outputs = self.model(**inputs)

         if isinstance(outputs, (list, tuple)):
@@ -326,6 +341,7 @@
         else:
             logits = outputs["logits"]
             past_key_values = outputs["past_key_values"] if self.use_cache else None
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


@@ -425,7 +441,7 @@ def _from_transformers(
         if model.config.model_type == "llama":
             model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

-        traced_model = jit_trace(model, task, use_cache)
+        traced_model, has_position_ids = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
         torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -440,5 +456,6 @@ def _from_transformers(
             force_download=force_download,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
+            has_position_ids=has_position_ids,
             **kwargs,
         )
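
The core behavioral change is the position_ids fallback in forward: when the graph was traced with position_ids but the caller supplies none, they are rebuilt from the cached sequence length. Below is a minimal, self-contained sketch of that arithmetic; the helper name build_position_ids and the toy tensors are illustrative, not part of the commit.

import torch

def build_position_ids(input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    # Positions continue where the KV cache left off: [past_length, past_length + seq_length).
    seq_length = input_ids.shape[-1]
    return torch.arange(
        past_length, past_length + seq_length, dtype=torch.long, device=input_ids.device
    ).unsqueeze(0)

# Prefill: nothing cached yet, so positions start at 0.
prompt = torch.randint(0, 100, (1, 5))
print(build_position_ids(prompt, past_length=0))      # tensor([[0, 1, 2, 3, 4]])

# Decode step: one new token, five tokens already cached.
next_token = torch.randint(0, 100, (1, 1))
print(build_position_ids(next_token, past_length=5))  # tensor([[5]])

The committed code reads past_length off the cache itself (past_key_values[0][1].shape[-2], i.e. the sequence axis of a cached value tensor) rather than tracking a counter, which amounts to the same arithmetic.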
optimum/intel/ipex/inference.py (2 additions, 1 deletion)
@@ -115,7 +115,7 @@ def __enter__(self):
                 use_cache = False
                 if hasattr(self._original.config, "use_cache") and self._original.config.use_cache:
                     use_cache = True
-                model = jit_trace(
+                model, has_position_ids = jit_trace(
                     model=model,
                     task=self._model.task,
                     use_cache=use_cache,
@@ -126,6 +126,7 @@ def __enter__(self):
                     config=self._original.config,
                     use_cache=use_cache,
                     model_dtype=self._original.dtype,
+                    has_position_ids=has_position_ids,
                 )
             except Exception as e:
                 logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
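
Recording has_position_ids at trace time matters because a torch.jit traced graph keeps only the inputs it was recorded with; feeding position_ids to a model traced without them fails at call time. A small sketch of that behavior, assuming PyTorch >= 2.0 for the example_kwarg_inputs argument; the Toy module is invented for illustration.

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, position_ids=None):
        out = input_ids * 2
        if position_ids is not None:  # this branch is not taken while tracing
            out = out + position_ids
        return out

example = {"input_ids": torch.ones(1, 4, dtype=torch.long)}
traced = torch.jit.trace(Toy(), example_kwarg_inputs=example, strict=False)

traced(input_ids=torch.ones(1, 4, dtype=torch.long))  # works: same inputs as at trace time
# traced(input_ids=..., position_ids=...) raises, because the recorded graph has no
# position_ids input -- which is exactly why jit_trace now returns the flag.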
optimum/intel/neural_compressor/modeling_base.py (2 additions, 1 deletion)
@@ -235,7 +235,7 @@ def _from_transformers(
         if task == "text-generation":
             model = patch_decoder_attention_mask(model)

-        traced_model = jit_trace(model, task, use_cache)
+        traced_model, has_position_ids = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
         torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -250,6 +250,7 @@ def _from_transformers(
             force_download=force_download,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
+            has_position_ids=has_position_ids,
             **kwargs,
         )

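All three call sites apply the same pattern: unpack the tuple now returned by jit_trace and thread the flag into the wrapper class, so that forward later knows whether to pass, rebuild, or ignore position_ids. A condensed, hypothetical version of that flow; TracedWrapper stands in for the real model classes, and unlike the committed code it drops the argument silently instead of logging a warning.

import torch

class TracedWrapper:
    def __init__(self, model, has_position_ids: bool = False):
        self.model = model
        self.has_position_ids = has_position_ids

    def __call__(self, **inputs):
        # Drop position_ids when the traced graph was recorded without them.
        if not self.has_position_ids:
            inputs.pop("position_ids", None)
        return self.model(**inputs)

# Usage mirrors the diff above:
#   traced_model, has_position_ids = jit_trace(model, task, use_cache)
#   wrapper = TracedWrapper(traced_model, has_position_ids=has_position_ids)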
