diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index c165bbef3b..ce9e93e8f0 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -28,6 +28,7 @@
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
+from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_torch_version, is_transformers_version
@@ -70,7 +71,6 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
 
 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
-    has_position_ids = True if "position_ids" in model_inputs else False
     # check if the model_inputs is correct.
     model(**model_inputs)
     torch._C._jit_set_texpr_fuser_enabled(False)
@@ -89,7 +89,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
         traced_model(**model_inputs)
         traced_model(**model_inputs)
 
-    return traced_model, has_position_ids
+    return traced_model
 
 
 class PreTrainedModel(OptimizedModel):
@@ -108,7 +108,6 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
-        has_position_ids: bool = False,
         **kwargs,
     ):
         super(BaseModelForCausalLM, self).__init__(model=model, config=config)
@@ -118,7 +117,7 @@ def __init__(
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
-        self.has_position_ids = has_position_ids
+        self.has_position_ids = True if config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False
 
         if is_transformers_version("<=", "4.25.1"):
             self.generation_config = None
@@ -441,7 +440,7 @@ def _from_transformers(
         if model.config.model_type == "llama":
             model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
 
-        traced_model, has_position_ids = jit_trace(model, task, use_cache)
+        traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
         torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -456,6 +455,5 @@ def _from_transformers(
             force_download=force_download,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
-            has_position_ids=has_position_ids,
             **kwargs,
         )
diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py
index 69e596f48d..961c48dce7 100644
--- a/optimum/intel/ipex/inference.py
+++ b/optimum/intel/ipex/inference.py
@@ -115,7 +115,7 @@ def __enter__(self):
                 use_cache = False
                 if hasattr(self._original.config, "use_cache") and self._original.config.use_cache:
                     use_cache = True
-                model, has_position_ids = jit_trace(
+                model = jit_trace(
                     model=model,
                     task=self._model.task,
                     use_cache=use_cache,
@@ -126,7 +126,6 @@ def __enter__(self):
                     config=self._original.config,
                     use_cache=use_cache,
                     model_dtype=self._original.dtype,
-                    has_position_ids=has_position_ids,
                 )
             except Exception as e:
                 logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index 4df656b3e2..19c06c8c4c 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -235,7 +235,7 @@ def _from_transformers(
         if task == "text-generation":
            model = patch_decoder_attention_mask(model)
 
-        traced_model, has_position_ids = jit_trace(model, task, use_cache)
+        traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)
         torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -250,7 +250,6 @@ def _from_transformers(
             force_download=force_download,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
-            has_position_ids=has_position_ids,
             **kwargs,
         )
 
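Net effect of the patch: `has_position_ids` is no longer detected at trace time inside `jit_trace` and threaded through `__init__` and `_from_transformers`; it is instead derived from the model type via the exporter's `MODEL_TYPES_REQUIRING_POSITION_IDS` registry, which the patch imports from `optimum.exporters.onnx`. A minimal sketch of the new derivation, assuming `transformers` and `optimum` are installed (`"gpt2"` is an arbitrary example checkpoint, not taken from this diff):

```python
# Sketch of the config-based lookup introduced by this patch; "gpt2" is a
# hypothetical example checkpoint chosen for illustration.
from transformers import AutoConfig
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS

config = AutoConfig.from_pretrained("gpt2")

# has_position_ids is now a pure function of config.model_type, so it no
# longer needs to be returned by jit_trace() or passed to every constructor.
has_position_ids = config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS
print(has_position_ids)
```

Deriving the flag from the config rather than from the traced inputs keeps the constructor signatures stable and, plausibly, also covers the case where a model is reloaded from a saved TorchScript artifact and the original trace inputs are no longer available.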