diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 58eb2163d0..9384477eb9 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -79,7 +79,13 @@ def __init__(
             height = -1 if self.export_feature == "image-classification" else None
             width = -1 if self.export_feature == "image-classification" else None
             model = self._reshape(model, -1, -1, height, width)
-        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(model.inputs)}
+
+        input_names = {}
+        for idx, key in enumerate(model.inputs):
+            names = tuple(key.get_names())
+            input_names[next((name for name in names if "/" not in name), names[0])] = idx
+        self.input_names = input_names
+
         self.model = model
         self.request = None
         if enable_compilation:
@@ -153,6 +159,7 @@ def _from_pretrained(
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         file_name: Optional[str] = None,
+        subfolder: str = "",
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
@@ -184,38 +191,59 @@ def _from_pretrained(
             local_files_only(`bool`, *optional*, defaults to `False`):
                 Whether or not to only look at local files (i.e., do not try to download the model).
         """
+
+        model_path = Path(model_id)
         default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME
         file_name = file_name or default_file_name
 
-        # Load the model from local directory
-        if os.path.isdir(model_id):
-            file_name = os.path.join(model_id, file_name)
-            model_save_dir = model_id
-        # Download the model from the hub
-        else:
-            model_file_names = [file_name]
-            # If not ONNX then OpenVINO IR
+        model_cache_path = cls._cached_file(
+            model_path=model_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            file_name=file_name,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
+        )
+        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
+        return cls(model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
 
-            if not from_onnx:
-                model_file_names.append(file_name.replace(".xml", ".bin"))
-            file_names = []
+    @staticmethod
+    def _cached_file(
+        model_path: Union[Path, str],
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        file_name: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+    ):
+        # locates a file in a local folder and repo, downloads and cache it if necessary.
+        model_path = Path(model_path)
+        if model_path.is_dir():
+            model_cache_path = model_path / file_name
+        else:
+            file_name = Path(file_name)
+            if file_name.suffix != "onnx":
+                model_file_names = [file_name.with_suffix(".bin"), file_name]
+            else:
+                model_file_names = [file_name]
             for file_name in model_file_names:
                 model_cache_path = hf_hub_download(
-                    repo_id=model_id,
-                    filename=file_name,
+                    repo_id=model_path.as_posix(),
+                    filename=file_name.as_posix(),
+                    subfolder=subfolder,
                     use_auth_token=use_auth_token,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     local_files_only=local_files_only,
                 )
-                file_names.append(model_cache_path)
-            model_save_dir = Path(model_cache_path).parent
-            file_name = file_names[0]
-
-        model = cls.load_model(file_name, load_in_8bit=load_in_8bit)
+            model_cache_path = Path(model_cache_path)
 
-        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return model_cache_path
 
     @classmethod
     def _from_transformers(
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 2645408158..0e018f9f62 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -31,8 +31,9 @@
 
 from ...exporters.openvino import main_export
 from ..utils.import_utils import is_transformers_version
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
-from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE
+from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
 
 if is_transformers_version("<", "4.25.0"):
@@ -78,8 +79,9 @@
     "bloom",
     "codegen",
     "gpt2",
-    "gpt_neo",
-    "gpt_neox",
+    "gpt-bigcode",
+    "gpt-neo",
+    "gpt-neox",
     "llama",
     "marian",
     "opt",
@@ -213,7 +215,7 @@ def _from_transformers(
         load_in_8bit: bool = False,
         **kwargs,
     ):
-        if config.model_type not in _SUPPORTED_ARCHITECTURES:
+        if config.model_type.replace("_", "-") not in _SUPPORTED_ARCHITECTURES:
             logger.warning(
                 f"This architecture : {config.model_type} was not validated, only :{', '.join(_SUPPORTED_ARCHITECTURES)} architectures were "
                 "validated, use at your own risk."
@@ -347,6 +349,7 @@ def forward(
         **kwargs,
     ) -> CausalLMOutputWithPast:
         self.compile()
+        inputs = {}
 
         if self.use_cache and past_key_values is not None:
             input_ids = input_ids[:, -1:]
@@ -354,41 +357,43 @@ def forward(
         inputs = {}
         past_len = 0
         if past_key_values is not None:
-            seq_len_dim = 1 if self.model.input(self.key_value_input_names[0]).get_partial_shape()[1].is_dynamic else 2
-            past_len = past_key_values[0][0].shape[seq_len_dim]
-            if self._pkv_precision == Type.bf16:
-                # numpy does not support bf16, pretending f16, should change to bf16
-                past_key_values = tuple(
-                    Tensor(past_key_value, past_key_value.shape, Type.bf16)
-                    for pkv_per_layer in past_key_values
-                    for past_key_value in pkv_per_layer
-                )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                past_len = past_key_values[0][1].shape[-2]
+                if self._pkv_precision == Type.bf16:
+                    # numpy does not support bf16, pretending f16, should change to bf16
+                    past_key_values = tuple(
+                        Tensor(past_key_value, past_key_value.shape, Type.bf16)
+                        for pkv_per_layer in past_key_values
+                        for past_key_value in pkv_per_layer
+                    )
+                else:
+                    # Flatten the past_key_values
+                    past_key_values = tuple(
+                        past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
+                    )
             else:
-                # Flatten the past_key_values
-                past_key_values = tuple(
-                    past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
-                )
+                past_len = past_key_values[0].shape[-2]
+
             # Add the past_key_values to the decoder inputs
             inputs = dict(zip(self.key_value_input_names, past_key_values))
 
         # Create empty past_key_values for decoder_with_past first generation step
         elif self.use_cache:
-            shape_input_ids = input_ids.shape
-            num_attention_heads = (
-                self.normalized_config.num_attention_heads if self.config.model_type == "bloom" else 1
-            )
+            batch_size = input_ids.shape[0]
+            if self.config.model_type == "bloom":
+                batch_size *= self.normalized_config.num_attention_heads
+
             for input_name in self.key_value_input_names:
                 model_inputs = self.model.input(input_name)
                 shape = model_inputs.get_partial_shape()
-                shape[0] = shape_input_ids[0] * num_attention_heads
+                shape[0] = batch_size
                 if shape[2].is_dynamic:
                     shape[2] = 0
-                if shape[1].is_dynamic:
+                else:
                     shape[1] = 0
                 inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
 
         inputs["input_ids"] = np.array(input_ids)
-
         # Add the attention_mask inputs when needed
         if "attention_mask" in self.input_names:
             if attention_mask is not None:
@@ -401,16 +406,16 @@ def forward(
         # Run inference
         self.request.start_async(inputs, shared_memory=True)
         self.request.wait()
-
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
 
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
-            past_key_values = tuple(
-                past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
+                past_key_values = tuple(
+                    past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
+                )
         else:
             past_key_values = None
 
@@ -418,41 +423,114 @@ def forward(
 
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        past_key_values = past_key_values or kwargs.get("past", None)
-
-        # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed
-        if past_key_values is not None and self.config.model_type == "bloom":
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        attention_mask = kwargs.get("attention_mask", None)
+        use_cache = kwargs.get("use_cache", None)
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
 
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
-            "use_cache": self.use_cache,
-            "position_ids": None,
-            "attention_mask": kwargs.get("attention_mask", None),
-            "token_type_ids": None,
+            "use_cache": use_cache,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
         }
 
+    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
+    @staticmethod
     def _reorder_cache(
-        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called.
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
-
-        if self.config.model_type == "bloom":
-            return self._reorder_cache_bloom(past_key_values, beam_idx)
-
-        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
         return tuple(
             tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
         )
 
-    # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    def _reorder_cache_bloom(
+    def can_generate(self):
+        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
+        return True
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: PretrainedConfig,
+        use_auth_token: Optional[Union[bool, str, None]] = None,
+        revision: Optional[Union[str, None]] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        file_name: Optional[str] = None,
+        subfolder: str = "",
+        from_onnx: bool = False,
+        local_files_only: bool = False,
+        load_in_8bit: bool = False,
+        **kwargs,
+    ):
+        model_path = Path(model_id)
+        default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME
+        file_name = file_name or default_file_name
+
+        model_cache_path = cls._cached_file(
+            model_path=model_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            file_name=file_name,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
+        )
+
+        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
+
+        model_type = config.model_type.replace("_", "-")
+        if model_type == "bloom":
+            init_cls = OVBloomForCausalLM
+        elif model_type == "mpt":
+            init_cls = OVMPTForCausalLM
+        elif model_type == "opt":
+            init_cls = OVOPTForCausalLM
+        elif model_type == "gpt-bigcode":
+            init_cls = OVGPTBigCodeForCausalLM
+        else:
+            init_cls = OVModelForCausalLM
+
+        return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
+
+
+class OVBloomForCausalLM(OVModelForCausalLM):
+    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        attention_mask = kwargs.get("attention_mask", None)
+        use_cache = kwargs.get("use_cache", None)
+
+        # only last token for input_ids if past is not None
+        if past_key_values:
+            # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "position_ids": None,
+            "attention_mask": attention_mask,
+        }
+
+    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
+    def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         """
@@ -461,7 +539,6 @@ def _reorder_cache_bloom(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
-
         reordered_past = tuple(
             (
                 np.take(layer_past[0], beam_idx, 0),
@@ -496,9 +573,6 @@ def _convert_to_standard_cache(
         """
         Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
         """
-        if self.config.model_type != "bloom":
-            return past_key_value
-
         batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
         num_heads = batch_size_times_num_heads // batch_size
         # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
@@ -511,6 +585,41 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )
 
-    def can_generate(self):
-        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
-        return True
+
+class OVOPTForCausalLM(OVModelForCausalLM):
+    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        attention_mask = kwargs.get("attention_mask", None)
+        use_cache = kwargs.get("use_cache", None)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "position_ids": None,
+            "attention_mask": attention_mask,
+        }
+
+
+class OVMPTForCausalLM(OVModelForCausalLM):
+    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        attention_mask = kwargs.get("attention_mask", None)
+        use_cache = kwargs.get("use_cache", None)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "position_ids": None,
+            "attention_mask": attention_mask,
+        }
+
+
+class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
+    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 17abf1059e..b56e5e4f2d 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -18,6 +18,12 @@
 from transformers.modeling_utils import PreTrainedModel
 
 
+# from ...utils.modeling_utils import _prepare_decoder_sliding_window_attention_mask
+
+
+MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
+
+
 # Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -106,6 +112,8 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
         model.transformer._prepare_attn_mask = _prepare_attn_mask
     elif model.config.model_type == "llama":
         model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    # elif model.config.model_type == "mistral":
+    #     model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask
     elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
         model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
     return model
diff --git a/setup.py b/setup.py
index 5fc16692ee..6d81b98b2a 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@
     ],
     "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"],
    "nncf": ["nncf>=2.6.0"],
-    "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx", "torch<2.1.0"],
+    "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index bcd1bdb903..bf1a007844 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -449,6 +449,7 @@ def test_pipeline(self, model_arch):
 class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
         "bart",
+        "gpt_bigcode",
         "blenderbot",
         "blenderbot-small",
         "bloom",
@@ -459,6 +460,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "gpt_neox",
         "llama",
         "marian",
+        # "mistral",
         "mpt",
         "opt",
         "pegasus",
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index c2feb4d264..2fa77052eb 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -39,7 +39,7 @@
     "distilbert": "hf-internal-testing/tiny-random-distilbert",
     "electra": "hf-internal-testing/tiny-random-electra",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
-    # "gpt_bigcode": "bigcode/tiny_starcoder_py",
+    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
     "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
@@ -51,8 +51,9 @@
     "llama": "fxmarty/tiny-llama-fast-tokenizer",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
-    "marian": "sshleifer/tiny-marian-en-de",  # hf-internal-testing ones are broken
+    "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
+    "mistral": "echarlaix/tiny-random-mistral",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
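
Usage sketch (for reference only, not part of the patch): the snippet below shows how the gpt_bigcode support added above is expected to be exercised through OVModelForCausalLM.from_pretrained, whose new _from_pretrained dispatch returns an OVGPTBigCodeForCausalLM instance. The checkpoint name is the tiny test model referenced in tests/openvino/utils_tests.py above; the export=True flag and tokenizer usage follow the existing optimum-intel API and are assumptions about typical usage rather than part of this change.

    # Sketch only: exercises the OVGPTBigCodeForCausalLM dispatch added in modeling_decoder.py.
    # The model id is the tiny test checkpoint from tests/openvino/utils_tests.py; swap in a real
    # GPT-BigCode checkpoint for meaningful generations.
    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-GPTBigCodeModel"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # export=True converts the PyTorch checkpoint to OpenVINO IR on the fly; because
    # config.model_type is "gpt_bigcode" (normalized to "gpt-bigcode"), _from_pretrained
    # instantiates OVGPTBigCodeForCausalLM.
    model = OVModelForCausalLM.from_pretrained(model_id, export=True)

    inputs = tokenizer("def hello_world():", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0]))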