Skip to content

Commit

Permalink
use MODEL_TYPES_REQUIRING_POSITION_IDS
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Oct 18, 2023
1 parent 9a00f0c commit 02b1bf8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
10 changes: 4 additions & 6 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from optimum.exporters import TasksManager
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
Expand Down Expand Up @@ -70,7 +71,6 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals

def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
has_position_ids = True if "position_ids" in model_inputs else False
# check if the model_inputs is correct.
model(**model_inputs)
torch._C._jit_set_texpr_fuser_enabled(False)
Expand All @@ -89,7 +89,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
traced_model(**model_inputs)
traced_model(**model_inputs)

return traced_model, has_position_ids
return traced_model


class PreTrainedModel(OptimizedModel):
Expand All @@ -108,7 +108,6 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
has_position_ids: bool = False,
**kwargs,
):
super(BaseModelForCausalLM, self).__init__(model=model, config=config)
Expand All @@ -118,7 +117,7 @@ def __init__(
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.has_position_ids = has_position_ids
self.has_position_ids = True if config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
Expand Down Expand Up @@ -441,7 +440,7 @@ def _from_transformers(
if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

traced_model, has_position_ids = jit_trace(model, task, use_cache)
traced_model = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
Expand All @@ -456,6 +455,5 @@ def _from_transformers(
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
has_position_ids=has_position_ids,
**kwargs,
)
3 changes: 1 addition & 2 deletions optimum/intel/ipex/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __enter__(self):
use_cache = False
if hasattr(self._original.config, "use_cache") and self._original.config.use_cache:
use_cache = True
model, has_position_ids = jit_trace(
model = jit_trace(
model=model,
task=self._model.task,
use_cache=use_cache,
Expand All @@ -126,7 +126,6 @@ def __enter__(self):
config=self._original.config,
use_cache=use_cache,
model_dtype=self._original.dtype,
has_position_ids=has_position_ids,
)
except Exception as e:
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
Expand Down
3 changes: 1 addition & 2 deletions optimum/intel/neural_compressor/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _from_transformers(
if task == "text-generation":
model = patch_decoder_attention_mask(model)

traced_model, has_position_ids = jit_trace(model, task, use_cache)
traced_model = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
Expand All @@ -250,7 +250,6 @@ def _from_transformers(
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
has_position_ids=has_position_ids,
**kwargs,
)

Expand Down

0 comments on commit 02b1bf8

Please sign in to comment.