diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index b17d93aa5e..cb011706c8 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -27,7 +27,6 @@
 from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
 from ...intel.utils.import_utils import is_nncf_available
-from ...intel.utils.modeling_utils import patch_decoder_attention_mask
 from .convert import export_models
@@ -257,24 +256,18 @@ class StoreAttr(object):
     preprocessors = maybe_load_preprocessors(
         model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
     )
-    if not task.startswith("text-generation"):
-        onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
-            model=model,
-            task=task,
-            monolith=False,
-            custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
-            custom_architecture=custom_architecture,
-            fn_get_submodels=fn_get_submodels,
-            preprocessors=preprocessors,
-            _variant="default",
-        )
-    else:
-        # TODO : ModelPatcher will be added in next optimum release
-        model = patch_decoder_attention_mask(model)
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config)
-        models_and_onnx_configs = {"model": (model, onnx_config)}
+    onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
+        model=model,
+        task=task,
+        monolith=False,
+        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+        custom_architecture=custom_architecture,
+        fn_get_submodels=fn_get_submodels,
+        preprocessors=preprocessors,
+        _variant="default",
+        legacy=False,
+    )
 
     if int8 is None:
         int8 = False
diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index b4c41e0be1..bbfc3db63d 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -26,12 +26,13 @@
 from transformers.utils import WEIGHTS_NAME
 
 from optimum.exporters import TasksManager
+from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_torch_version, is_transformers_version
-from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
 
 if is_transformers_version("<", "4.25.0"):
@@ -47,43 +48,29 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
     task = _TASK_ALIASES.get(task, task)
     signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-    onnx_config = onnx_config_class(model.config)
-    if task == "text-generation" and use_cache:
-        onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
+    if "text-generation" in task:
+        onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
+    else:
+        onnx_config = onnx_config_class(model.config)
+
     dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
-    model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-    if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
"text-generation" and use_cache and model.config.model_type != "gpt_bigcode": - # WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect - pkv = [] - for i in range(len(model_inputs["past_key_values"])): - pkv.append([]) - for j in range(len(model_inputs["past_key_values"][0])): - pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype)) - pkv[i] = tuple(pkv[i]) - model_inputs["past_key_values"] = tuple(pkv) - i = model_inputs["input_ids"] - a = model_inputs["attention_mask"] - model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype) - model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype) - return model_inputs + + return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False): model_inputs = prepare_jit_inputs(model, task, use_cache) # check if the model_inputs is correct. model(**model_inputs) + torch._C._jit_set_texpr_fuser_enabled(False) if "past_key_values" in model_inputs.keys(): model.config.return_dict = False - if is_torch_version(">", "2.0.1"): - traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) - else: - traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + if is_torch_version(">=", "2.1.0"): + traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) else: - if is_torch_version(">=", "2.0.0"): - traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) - else: - traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + traced_model = torch.jit.freeze(traced_model.eval()) traced_model(**model_inputs) traced_model(**model_inputs) @@ -91,11 +78,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False): return traced_model -class PreTrainedModel(OptimizedModel): - pass - - -class BaseModelForCausalLM(PreTrainedModel, GenerationMixin): +class BaseModelForCausalLM(OptimizedModel, GenerationMixin): auto_model_class = AutoModelForCausalLM export_feature = "text-generation" main_input_name = "input_ids" @@ -156,12 +139,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) + position_ids = kwargs.get("position_ids", None) + + attention_mask = kwargs.get("attention_mask", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": self.use_cache, - "position_ids": None, - "attention_mask": kwargs.get("attention_mask", None), + "position_ids": position_ids, + "attention_mask": attention_mask, "token_type_ids": None, } @@ -258,6 +252,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = 
+        position_ids: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         if attention_mask is None:
@@ -268,43 +263,42 @@ def forward(
             "attention_mask": attention_mask,
         }
 
+        model_type = self.config.model_type.replace("_", "-")
+
         if self.use_cache:
             if past_key_values is None:
                 nb_pkv = 2
                 num_layers = self.normalized_config.num_layers
-                num_attention_heads = self.normalized_config.num_attention_heads
-                num_key_value_heads = num_attention_heads
-                if hasattr(self.normalized_config, "num_key_value_heads"):
-                    num_key_value_heads = self.normalized_config.num_key_value_heads
-                hidden_size = self.normalized_config.hidden_size
-                d_k = hidden_size // num_attention_heads
-                if self.config.model_type == "gpt_bigcode":
-                    new_shape = [input_ids.shape[0], 0, d_k * 2]
-                    empty_tensor = torch.empty(size=new_shape)
-                    if self.model_dtype is not None:
-                        empty_tensor = empty_tensor.to(self.model_dtype)
-                    past_key_values = tuple([empty_tensor] * num_layers)
-                elif self.config.model_type != "bloom":
-                    new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
-                    empty_tensor = torch.empty(size=new_shape)
-                    if self.model_dtype is not None:
-                        empty_tensor = empty_tensor.to(self.model_dtype)
-                    pkv = tuple(empty_tensor for _ in range(nb_pkv))
+                d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+                batch_size = input_ids.shape[0]
+
+                if model_type in {"mistral", "llama"}:
+                    num_attention_heads = self.normalized_config.num_key_value_heads
                 else:
-                    pkv = ()
-                    for nb_pkv in range(nb_pkv):
-                        if nb_pkv % 2 == 0:
-                            new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
-                        else:
-                            new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
-                        empty_tensor = torch.empty(size=new_shape)
-                        if self.model_dtype is not None:
-                            empty_tensor = empty_tensor.to(self.model_dtype)
-                        pkv = pkv + (empty_tensor,)
-                if past_key_values is None:
-                    past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
+                    num_attention_heads = self.normalized_config.num_attention_heads
+
+                if model_type == "bloom":
+                    shape_key = (batch_size * num_attention_heads, d_k, 0)
+                    shape_value = (batch_size * num_attention_heads, 0, d_k)
+                    key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
+                    value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(
+                        tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)
+                    )
+                elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
+                    shape = (batch_size, 0, d_k * 2)
+                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(pkv for _ in range(num_layers))
+                else:
+                    shape = (batch_size, num_attention_heads, 0, d_k)
+                    pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
+                    past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))
             inputs["past_key_values"] = past_key_values
+
+        if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
+            inputs["position_ids"] = position_ids
+
         outputs = self.model(**inputs)
 
         if isinstance(outputs, (list, tuple)):
@@ -389,7 +383,7 @@ def _from_transformers(
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         **kwargs,
     ):
-        if is_torch_version("<", "2.0.0"):
-            raise ImportError("`torch>=2.0.0` is needed to trace your model")
+        if is_torch_version("<", "2.1.0"):
+            raise ImportError("`torch>=2.1.0` is needed to trace your model")
 
         task = cls.export_feature
@@ -405,12 +399,7 @@ def _from_transformers(
         }
 
         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-
-        if model.config.model_type == "bloom":
-            model.transformer._prepare_attn_mask = _prepare_attn_mask
-
-        if model.config.model_type == "llama":
-            model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        model = patch_decoder_attention_mask(model)
         traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index eb1ed88467..67e8d20502 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -43,17 +43,12 @@
 logger = logging.getLogger(__name__)
 
 
-# workaround to enable compatibility between openvino models and transformers pipelines
-class PreTrainedModel(OptimizedModel):
-    pass
-
-
 @add_start_docstrings(
     """
     Base OVModel class.
     """,
 )
-class OVBaseModel(PreTrainedModel):
+class OVBaseModel(OptimizedModel):
     auto_model_class = None
     export_feature = None
@@ -86,6 +81,12 @@ def __init__(
             input_names[next((name for name in names if "/" not in name), names[0])] = idx
         self.input_names = input_names
 
+        output_names = {}
+        for idx, key in enumerate(model.outputs):
+            names = tuple(key.get_names())
+            output_names[next((name for name in names if "/" not in name), names[0])] = idx
+        self.output_names = output_names
+
         self.model = model
         self.request = None
         if enable_compilation:
@@ -302,7 +303,7 @@ def _from_transformers(
     @classmethod
     def _to_load(
         cls,
-        model: PreTrainedModel,
+        model,
         config: PretrainedConfig,
         onnx_config: OnnxConfig,
         use_auth_token: Optional[Union[bool, str]] = None,
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 4d87b7eec2..0fa21e3a4a 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -129,7 +129,6 @@ def __init__(
         self.main_input_name = "input_ids"
         self.num_pkv = 2
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
-        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
         self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
         self.key_value_output_names = [key for key in self.output_names if "present" in key]
         self._original_model = self.model.clone()  # keep original model for serialization
@@ -313,6 +312,7 @@ def forward(
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         self.compile()
@@ -362,14 +362,28 @@ def forward(
         inputs["input_ids"] = np.array(input_ids)
 
         # Add the attention_mask inputs when needed
-        if "attention_mask" in self.input_names:
+        if "attention_mask" in self.input_names or "position_ids" in self.input_names:
             if attention_mask is not None:
-                inputs["attention_mask"] = np.array(attention_mask)
+                attention_mask = np.array(attention_mask)
             else:
-                inputs["attention_mask"] = np.ones(
+                attention_mask = np.ones(
                     (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                 )
 
+        if "attention_mask" in self.input_names:
+            inputs["attention_mask"] = attention_mask
+
+        if "position_ids" in self.input_names:
+            if position_ids is not None:
+                position_ids = np.array(position_ids)
+            else:
+                position_ids = np.cumsum(attention_mask, axis=1) - 1
+                position_ids[attention_mask == 0] = 1
+                if past_key_values:
+                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+
+            inputs["position_ids"] = position_ids
+
         # Run inference
         self.request.start_async(inputs, shared_memory=True)
         self.request.wait()
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index bcc7c2908b..b94f61214d 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -39,7 +39,6 @@
 from ...exporters.openvino import export, export_pytorch_via_onnx
 from ..utils.constant import _TASK_ALIASES
-from ..utils.modeling_utils import patch_decoder_attention_mask
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
 from .modeling_decoder import OVBaseDecoderModel
@@ -394,9 +393,10 @@ def _quantize_torchmodel(
         task = self.task
         model = self.model
         self.model.config.save_pretrained(save_directory)
-        model = patch_decoder_attention_mask(model)
-        if task == "text-generation":
-            onnx_config = onnx_config_class(model.config, use_past=model.config.use_cache)
+        if task.startswith("text-generation"):
+            onnx_config = onnx_config_class(
+                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+            )
         else:
             onnx_config = onnx_config_class(model.config)
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index b56e5e4f2d..1a3b6fbede 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -18,9 +18,6 @@
 from transformers.modeling_utils import PreTrainedModel
 
 
-# from ...utils.modeling_utils import _prepare_decoder_sliding_window_attention_mask
-
-
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
@@ -98,6 +95,40 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
     return combined_attention_mask
 
 
+# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask
+def _prepare_decoder_sliding_window_attention_mask(
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, int],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: int,
+):
+    from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask
+
+    # create causal mask
+    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+    combined_attention_mask = None
+
+    combined_attention_mask = _make_sliding_window_causal_mask(
+        input_shape,
+        device=inputs_embeds.device,
+        dtype=inputs_embeds.dtype,
+        past_key_values_length=past_key_values_length,
+        sliding_window=sliding_window,
+    )
+
+    if attention_mask is not None:
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+            inputs_embeds.device
+        )
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+        )
+
+    return combined_attention_mask
+
+
 def patch_decoder_attention_mask(model: "PreTrainedModel"):
     """
     Apply patch on decoder with past model forward to resolve first inference based on model architecture
@@ -112,8 +143,8 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
         model.transformer._prepare_attn_mask = _prepare_attn_mask
     elif model.config.model_type == "llama":
         model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    # elif model.config.model_type == "mistral":
-    #     model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask
+    elif model.config.model_type == "mistral":
"mistral": + model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask return model diff --git a/setup.py b/setup.py index 0c1feace30..7949f6d11d 100644 --- a/setup.py +++ b/setup.py @@ -12,8 +12,8 @@ assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) INSTALL_REQUIRE = [ - "optimum>=1.13.0", - "transformers", + "optimum>=1.14.0", + "transformers>=4.20.0", "datasets>=1.4.0", "sentencepiece", "scipy", diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py index 0fd668ad8f..db36b924f4 100644 --- a/tests/generation/test_modeling.py +++ b/tests/generation/test_modeling.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, pipeline, set_seed +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel.generation.modeling import TSModelForCausalLM @@ -28,6 +29,9 @@ "gptj": "hf-internal-testing/tiny-random-gptj", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "mistral": "echarlaix/tiny-random-mistral", + "llama": "fxmarty/tiny-llama-fast-tokenizer", + "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", } SEED = 42 @@ -48,7 +52,11 @@ class ModelingIntegrationTest(unittest.TestCase): "gpt2", "gptj", "gpt_neo", + "mistral", + "llama", + # "gpt_bigcode", ) + GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 @@ -61,7 +69,12 @@ def test_compare_to_transformers(self, model_arch): trfs_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer("This is a sample", return_tensors="pt") - outputs = model(**tokens) + + position_ids = None + if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: + input_shape = tokens["input_ids"].shape + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) + outputs = model(**tokens, position_ids=position_ids) self.assertIsInstance(outputs.logits, torch.Tensor) with torch.no_grad(): trfs_outputs = trfs_model(**tokens) @@ -71,7 +84,8 @@ def test_compare_to_transformers(self, model_arch): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) loaded_model = TSModelForCausalLM.from_pretrained(tmpdirname) - loaded_model_outputs = loaded_model(**tokens) + loaded_model_outputs = loaded_model(**tokens, position_ids=position_ids) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -120,7 +134,6 @@ def test_compare_with_and_without_past_key_values(self): model_id = MODEL_NAMES["gpt2"] tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer("This is a sample input", return_tensors="pt") - model_with_pkv = TSModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True) # Warmup _ = model_with_pkv.generate(**tokens) @@ -136,6 +149,9 @@ def test_compare_with_and_without_past_key_values(self): outputs_model_without_pkv = model_without_pkv.generate( **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 ) + self.assertTrue(model_with_pkv.use_cache) + self.assertFalse(model_without_pkv.use_cache) + self.assertTrue(torch.equal(outputs_model_with_pkv, 
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index f3978b2965..c29e8c2eef 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -51,6 +51,7 @@
 from transformers.onnx.utils import get_preprocessor
 from utils_tests import MODEL_NAMES
 
+from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.intel import (
     OVModelForAudioClassification,
     OVModelForAudioFrameClassification,
@@ -496,7 +497,12 @@ def test_compare_to_transformers(self, model_arch):
         tokens = tokenizer(
             "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
         )
-        ov_outputs = ov_model(**tokens)
+        position_ids = None
+        if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
+            input_shape = tokens["input_ids"].shape
+            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
+        ov_outputs = ov_model(**tokens, position_ids=position_ids)
+
         self.assertTrue("logits" in ov_outputs)
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
         with torch.no_grad():
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index c1ec95ea9b..3154ae1133 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -63,7 +63,7 @@ class OVQuantizerTest(unittest.TestCase):
     # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
         (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 32, 35),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 22),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 41, 23),
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
@@ -145,7 +145,7 @@ class OVWeightCompressionTest(unittest.TestCase):
     # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
         (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-BartForCausalLM", 27, 14),
     )
 
     SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 2fa77052eb..8d89d24e18 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -24,7 +24,7 @@
     "bart": "hf-internal-testing/tiny-random-bart",
     "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
     "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
-    "blenderbot": "hf-internal-testing/tiny-random-blenderbot",
+    "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
     "bloom": "hf-internal-testing/tiny-random-BloomModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
     "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
@@ -102,7 +102,7 @@
     "albert": (42,),
     "vit": (31,),
     "blenderbot": (35,),
-    "gpt2": (22,),
+    "gpt2": (23,),
     "wav2vec2": (15,),
     "distilbert": (33,),
     "t5": (32, 52, 42),
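
For context rather than as part of the patch: a minimal usage sketch of the behaviour the position_ids changes are meant to enable. The checkpoint and prompts below are only illustrative; gpt2 is one of the architectures in MODEL_TYPES_REQUIRING_POSITION_IDS, so prepare_inputs_for_generation is expected to derive position_ids from the attention mask during left-padded batched generation.

# Illustrative only -- not part of the diff above.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # assumed tiny test checkpoint, same as in the tests above
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left padding so generated tokens are appended at the end

# export=True converts the transformers checkpoint to an OpenVINO model on the fly
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

# Prompts of different lengths force padding, which is where correct position_ids matter
tokens = tokenizer(["This is a sample", "Hello"], return_tensors="pt", padding=True)
outputs = model.generate(**tokens, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))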