diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index f0b40fa5d1..cb58f412a6 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -32,7 +32,6 @@ jobs:
         python -m pip install --upgrade pip
         # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
         pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-        pip install git+https://github.com/huggingface/optimum.git
         pip install .[openvino,nncf,tests,diffusers]
     - name: Test with Pytest
       run: |
diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index b4c41e0be1..65f83bebf9 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -31,7 +31,9 @@
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_torch_version, is_transformers_version
-from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
+from ..utils.modeling_utils import patch_decoder_attention_mask, MULTI_QUERY_ATTN_MODELS
+
+from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS


 if is_transformers_version("<", "4.25.0"):
@@ -47,43 +51,29 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
     task = _TASK_ALIASES.get(task, task)
     signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-    onnx_config = onnx_config_class(model.config)
-    if task == "text-generation" and use_cache:
-        onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
+    if "text-generation" in task:
+        onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
+    else:
+        onnx_config = onnx_config_class(model.config)
+
     dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
-    model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-    if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
-        # WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
-        pkv = []
-        for i in range(len(model_inputs["past_key_values"])):
-            pkv.append([])
-            for j in range(len(model_inputs["past_key_values"][0])):
-                pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype))
-            pkv[i] = tuple(pkv[i])
-        model_inputs["past_key_values"] = tuple(pkv)
-        i = model_inputs["input_ids"]
-        a = model_inputs["attention_mask"]
-        model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype)
-        model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype)
-    return model_inputs
+
+    return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
     # check if the model_inputs is correct.
     model(**model_inputs)
+    torch._C._jit_set_texpr_fuser_enabled(False)
     if "past_key_values" in model_inputs.keys():
         model.config.return_dict = False
-        if is_torch_version(">", "2.0.1"):
-            traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
-        else:
-            traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
+    if is_torch_version(">=", "2.1.0"):
+        traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
     else:
-        if is_torch_version(">=", "2.0.0"):
-            traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
-        else:
-            traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
+        traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
+
     traced_model = torch.jit.freeze(traced_model.eval())
     traced_model(**model_inputs)
     traced_model(**model_inputs)
@@ -91,11 +81,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     return traced_model


-class PreTrainedModel(OptimizedModel):
-    pass
-
-
-class BaseModelForCausalLM(PreTrainedModel, GenerationMixin):
+class BaseModelForCausalLM(OptimizedModel, GenerationMixin):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
     main_input_name = "input_ids"
@@ -156,12 +142,28 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)

+        position_ids = kwargs.get("position_ids", None)
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": self.use_cache,
-            "position_ids": None,
-            "attention_mask": kwargs.get("attention_mask", None),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
             "token_type_ids": None,
         }
@@ -258,6 +260,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        position_ids: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         if attention_mask is None:
@@ -268,43 +271,40 @@
             "attention_mask": attention_mask,
         }

+        model_type = self.config.model_type.replace("_", "-")
+
         if self.use_cache:
             if past_key_values is None:
                 nb_pkv = 2
                 num_layers = self.normalized_config.num_layers
-                num_attention_heads = self.normalized_config.num_attention_heads
-                num_key_value_heads = num_attention_heads
-                if hasattr(self.normalized_config, "num_key_value_heads"):
-                    num_key_value_heads = self.normalized_config.num_key_value_heads
-                hidden_size = self.normalized_config.hidden_size
-                d_k = hidden_size // num_attention_heads
-                if self.config.model_type == "gpt_bigcode":
-                    new_shape = [input_ids.shape[0], 0, d_k * 2]
-                    empty_tensor = torch.empty(size=new_shape)
-                    if self.model_dtype is not None:
-                        empty_tensor = empty_tensor.to(self.model_dtype)
-                    past_key_values = tuple([empty_tensor] * num_layers)
-                elif self.config.model_type != "bloom":
-                    new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
-                    empty_tensor = torch.empty(size=new_shape)
-                    if self.model_dtype is not None:
-                        empty_tensor = empty_tensor.to(self.model_dtype)
-                    pkv = tuple(empty_tensor for _ in range(nb_pkv))
+                d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+                batch_size = input_ids.shape[0]
+
+                if model_type in {"mistral", "llama"}:
+                    num_attention_heads = self.normalized_config.num_key_value_heads
+                else:
+                    num_attention_heads = self.normalized_config.num_attention_heads
+
+                if model_type == "bloom":
+                    shape_key = (batch_size * num_attention_heads, d_k, 0)
+                    shape_value = (batch_size * num_attention_heads, 0, d_k)
+                    key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
+                    value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers))
+                elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
+                    shape = (batch_size, 0, d_k * 2)
+                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(pkv for _ in range(num_layers))
                 else:
-                    pkv = ()
-                    for nb_pkv in range(nb_pkv):
-                        if nb_pkv % 2 == 0:
-                            new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
-                        else:
-                            new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
-                        empty_tensor = torch.empty(size=new_shape)
-                        if self.model_dtype is not None:
-                            empty_tensor = empty_tensor.to(self.model_dtype)
-                        pkv = pkv + (empty_tensor,)
-                if past_key_values is None:
-                    past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
+                    shape = (batch_size, num_attention_heads, 0, d_k)
+                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))
             inputs["past_key_values"] = past_key_values
+
+        if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
+            inputs["position_ids"] = position_ids
+
         outputs = self.model(**inputs)

         if isinstance(outputs, (list, tuple)):
@@ -389,7 +389,7 @@ def _from_transformers(
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         **kwargs,
     ):
-        if is_torch_version("<", "2.0.0"):
-            raise ImportError("`torch>=2.0.0` is needed to trace your model")
+        if is_torch_version("<", "2.1.0"):
+            raise ImportError("`torch>=2.1.0` is needed to trace your model")

         task = cls.export_feature
@@ -405,12 +405,7 @@ def _from_transformers(
         }

         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-
-        if model.config.model_type == "bloom":
-            model.transformer._prepare_attn_mask = _prepare_attn_mask
-
-        if model.config.model_type == "llama":
-            model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        model = patch_decoder_attention_mask(model)

         traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index 36f16524c2..599f47e511 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -30,9 +30,18 @@
 from neural_compressor.quantization import fit
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoModelForMultipleChoice,
+    AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelForVision2Seq,
     DataCollator,
     PretrainedConfig,
     PreTrainedModel,
+    XLNetLMHeadModel,
     default_data_collator,
 )
@@ -528,3 +537,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> t
     q_model = convert(q_model, mapping=q_mapping, inplace=True)

     return q_model
+
+
+class IncQuantizedModel(INCModel):
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        warnings.warn(
+            f"The class `{cls.__name__}` has been deprecated and will be removed in optimum-intel v1.12, please use "
+            f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead."
+        )
+        return super().from_pretrained(*args, **kwargs)
+
+
+class IncQuantizedModelForQuestionAnswering(IncQuantizedModel):
+    auto_model_class = AutoModelForQuestionAnswering
+
+
+class IncQuantizedModelForSequenceClassification(IncQuantizedModel):
+    auto_model_class = AutoModelForSequenceClassification
+
+
+class IncQuantizedModelForTokenClassification(IncQuantizedModel):
+    auto_model_class = AutoModelForTokenClassification
+
+
+class IncQuantizedModelForMultipleChoice(IncQuantizedModel):
+    auto_model_class = AutoModelForMultipleChoice
+
+
+class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel):
+    auto_model_class = AutoModelForSeq2SeqLM
+
+
+class IncQuantizedModelForCausalLM(IncQuantizedModel):
+    auto_model_class = AutoModelForCausalLM
+
+
+class IncQuantizedModelForMaskedLM(IncQuantizedModel):
+    auto_model_class = AutoModelForMaskedLM
+
+
+class IncQuantizedModelForXLNetLM(IncQuantizedModel):
+    auto_model_class = XLNetLMHeadModel
+
+
+class IncQuantizedModelForVision2Seq(IncQuantizedModel):
+    auto_model_class = AutoModelForVision2Seq
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 9384477eb9..694c4e58c0 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -43,17 +43,12 @@
 logger = logging.getLogger(__name__)


-# workaround to enable compatibility between openvino models and transformers pipelines
-class PreTrainedModel(OptimizedModel):
-    pass
-
-
 @add_start_docstrings(
     """
     Base OVModel class.
     """,
 )
-class OVBaseModel(PreTrainedModel):
+class OVBaseModel(OptimizedModel):
     auto_model_class = None
     export_feature = None

@@ -302,7 +297,7 @@ def _from_transformers(
     @classmethod
     def _to_load(
         cls,
-        model: PreTrainedModel,
+        model: "PreTrainedModel",
         config: PretrainedConfig,
         onnx_config: OnnxConfig,
         use_auth_token: Optional[Union[bool, str]] = None,
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 4061e1d9d9..1a3b6fbede 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -12,5 +12,139 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Tuple
+
+import torch
+from transformers.modeling_utils import PreTrainedModel
+
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
+
+
+# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size,
+    device: torch.device,
+    past_key_values_length: int,
+    dtype: torch.dtype = torch.bool,
+) -> torch.BoolTensor:
+    """
+    Make causal mask used for bi-directional self-attention.
+ """ + batch_size, target_length = input_ids_shape + mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) + seq_ids = torch.arange(target_length, device=device) + + mask[:, past_key_values_length:] = ( + (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min + if torch.is_floating_point(mask) + else seq_ids[:, None] < seq_ids[None, :] + ) + + return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + + +# Modified from transformers.models..bloom.modeling_bloom._prepare_attn_mask +def _prepare_attn_mask( + attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int +) -> torch.BoolTensor: + from transformers.models.bloom.modeling_bloom import _expand_mask + + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + from transformers.models.llama.modeling_llama import _expand_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_causal_mask( + input_shape, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + dtype=inputs_embeds.dtype, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask +def _prepare_decoder_sliding_window_attention_mask( + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: int, +): + from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_sliding_window_causal_mask( + input_shape, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 
+        )
+
+    return combined_attention_mask
+
+
+def patch_decoder_attention_mask(model: "PreTrainedModel"):
+    """
+    Apply a patch to the attention mask preparation of a decoder-with-past model so that the first
+    inference pass runs correctly. The patch applied depends on the model architecture.
+
+    Args:
+        model (`PreTrainedModel`): The model to patch.
+
+    Returns:
+        The patched model.
+    """
+    if model.config.model_type in {"bloom", "mpt"}:
+        model.transformer._prepare_attn_mask = _prepare_attn_mask
+    elif model.config.model_type == "llama":
+        model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    elif model.config.model_type == "mistral":
+        model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask
+    elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
+        model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+    return model
diff --git a/setup.py b/setup.py
index 6d81b98b2a..e2da72bf3e 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
     assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)

 INSTALL_REQUIRE = [
-    "optimum>=1.13.0",
+    "optimum @ git+https://github.com/huggingface/optimum.git",
     "transformers>=4.20.0",
     "datasets>=1.4.0",
     "sentencepiece",