From 7e1ce6f3de134b35653712fa3b9eb69ed6e1265b Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 25 Dec 2023 12:26:13 +0400 Subject: [PATCH] added more tests --- .../openvino/better_transformer_patch.py | 16 +++++- optimum/exporters/openvino/stateful.py | 38 +------------- optimum/intel/openvino/modeling_decoder.py | 3 +- optimum/intel/openvino/utils.py | 32 ++++++++++++ tests/openvino/test_modeling.py | 47 +++++++++++++++++ tests/openvino/test_quantization.py | 51 +++++++++++++++++++ 6 files changed, 149 insertions(+), 38 deletions(-) diff --git a/optimum/exporters/openvino/better_transformer_patch.py b/optimum/exporters/openvino/better_transformer_patch.py index 5e89be2ef8..8cc98185f8 100644 --- a/optimum/exporters/openvino/better_transformer_patch.py +++ b/optimum/exporters/openvino/better_transformer_patch.py @@ -14,12 +14,26 @@ import logging as log +from optimum.intel.utils.import_utils import is_torch_version + def patch_model_with_bettertransformer(model): + if is_torch_version("<", "2.0"): + log.warn( + "integration Scaled Dot Product Attention optimization supported only with torch > 2.0." + "Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention" + "It is recommended to upgrade PyTorch version for using stateful model or use stateful=Flase" + ) + # model already has required SDPA implementation + if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": + return model try: model = model.to_bettertransformer() except Exception as e: - log.warn(f"Cannot apply model.to_bettertransformer because of the exception:\n{e}") + log.warn( + f"Cannot apply model.to_bettertransformer because of the exception:\n{e}." + " Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention" + ) return model return model diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index baefac2336..7720bf8c76 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -20,45 +20,11 @@ import openvino as ov from openvino.runtime import opset13 +from optimum.intel.openvino.utils import model_has_input from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version from optimum.utils.normalized_config import NormalizedConfigManager -def model_has_input_output_name(ov_model: ov.Model, name: str): - """ - Helper function for checking that model has specified input or output name - - Parameters: - ov_model (ov.Model): # TODO: Can we derive the dimensions from the model topology? - name (str): - name of input or output - - Returns: - True if input or output with requested name exists else False - """ - return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) - - -def model_has_input(ov_model: ov.Model, name: str): - """ - Helper function for checking that model has specified input name - - Parameters: - ov_model (ov.Model): - opennvino model - name (str): - name of input - - Returns: - True if input with requested name exists else False - """ - return name in sum([list(t.get_names()) for t in ov_model.inputs], []) - - -def model_has_cache_reorder(ov_model: ov.Model): - return model_has_input(ov_model, "beam_idx") - - def model_has_state(ov_model: ov.Model): # TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods return len(ov_model.get_sinks()) > 0 @@ -88,7 +54,7 @@ def fuse_cache_reorder( dimension for gathering cache during reorder pass """ - assert not model_has_input_output_name(ov_model, "beam_idx") + assert not model_has_input(ov_model, "beam_idx") input_batch = ov_model.input("input_ids").get_partial_shape()[0] beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])) beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted? diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index bbd6dc2cb4..5c0e2f7200 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -30,6 +30,7 @@ from optimum.utils import NormalizedConfigManager from ...exporters.openvino import main_export, patch_stateful, raise_if_openvino_is_too_old +from ...exporters.openvino.stateful import model_has_state from ..utils.import_utils import is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel @@ -126,7 +127,7 @@ def __init__( self.is_dynamic = dynamic_shapes use_cache = kwargs.pop("use_cache", True) stateful = kwargs.pop("stateful", None) # None means taking a model "as-is" - model_has_sinks = len(model.get_sinks()) > 0 + model_has_sinks = model_has_state(self.model) self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks self.stateful = model_has_sinks self.main_input_name = "input_ids" diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index c05ba9e374..7a2741a108 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -18,6 +18,7 @@ from glob import glob import numpy as np +import openvino as ov from huggingface_hub import model_info from openvino.runtime import Type from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size @@ -77,6 +78,37 @@ } +def model_has_input_output_name(ov_model: ov.Model, name: str): + """ + Helper function for checking that model has specified input or output name + + Parameters: + ov_model (ov.Model): # TODO: Can we derive the dimensions from the model topology? + name (str): + name of input or output + + Returns: + True if input or output with requested name exists else False + """ + return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) + + +def model_has_input(ov_model: ov.Model, name: str): + """ + Helper function for checking that model has specified input name + + Parameters: + ov_model (ov.Model): + opennvino model + name (str): + name of input + + Returns: + True if input with requested name exists else False + """ + return name in sum([list(t.get_names()) for t in ov_model.inputs], []) + + _HEAD_TO_AUTOMODELS = { "feature-extraction": "OVModelForFeatureExtraction", "fill-mask": "OVModelForMaskedLM", diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index e7b3f4abe1..cd4f324161 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -621,6 +621,7 @@ def test_stateful(self, model_arch): ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True) self.assertIsInstance(ov_model.config, PretrainedConfig) self.assertTrue(ov_model.stateful) + self.assertTrue(ov_model.use_cache) transformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( @@ -666,6 +667,52 @@ def test_stateful(self, model_arch): del ov_model gc.collect() + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + def test_stateful_on_converted_model(self): + model_id = "vuiseng9/ov-gpt2-fp32-kv-cache" + # reference without state + loaded_model = OVModelForCausalLM.from_pretrained(model_id) + self.assertIsInstance(loaded_model.config, PretrainedConfig) + self.assertFalse(loaded_model.stateful) + self.assertTrue(loaded_model.use_cache) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + loaded_model_outputs = loaded_model(**tokens) + + # explicit stateful model specified during loading + loaded_stateful_model = OVModelForCausalLM.from_pretrained(model_id, stateful=True) + self.assertIsInstance(loaded_model.config, PretrainedConfig) + self.assertTrue(loaded_model.stateful) + self.assertTrue(loaded_model.use_cache) + loaded_stateful_model_outputs = loaded_stateful_model(**tokens) + self.assertTrue(torch.equal(loaded_model_outputs.logits, loaded_stateful_model_outputs.logits)) + self.assertTrue("past_key_values" in loaded_stateful_model_outputs) + self.assertIsInstance(loaded_stateful_model_outputs.past_key_values, tuple) + self.assertTrue( + len(loaded_stateful_model_outputs.past_key_values) == 1 + and len(loaded_stateful_model_outputs.past_key_values[0]) == 0 + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + loaded_stateful_model.save_pretrained(tmpdirname) + folder_contents = os.listdir(tmpdirname) + self.assertTrue(OV_XML_FILE_NAME in folder_contents) + self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents) + # implicit load stateful model from disk + model = OVModelForCausalLM.from_pretrained(tmpdirname) + self.assertTrue(model.stateful) + self.assertTrue(model.use_cache) + + outputs = model(**tokens) + self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits)) + self.assertTrue("past_key_values" in outputs) + self.assertIsInstance(outputs.past_key_values, tuple) + self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0) + del loaded_model + del loaded_stateful_model + del model + gc.collect() + class OVModelForMaskedLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index c3378c08e6..fd63a54988 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -51,6 +51,7 @@ from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG +from optimum.intel.utils.import_utils import is_openvino_version from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8 _TASK_TO_DATASET = { @@ -166,6 +167,8 @@ class OVWeightCompressionTest(unittest.TestCase): (OVStableDiffusionXLPipeline, "stable-diffusion-xl"), ) + IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS) def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): task = model_cls.export_feature @@ -239,6 +242,40 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i outputs = model(**tokens) self.assertTrue("logits" in outputs) + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS) + def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, expected_int8, expected_int4): + task = model_cls.export_feature + + with tempfile.TemporaryDirectory() as tmp_dir: + model_id = MODEL_NAMES[model_name] + transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + quantizer.quantize( + save_directory=tmp_dir, + weights_only=True, + quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}), + ) + model = model_cls.from_pretrained(tmp_dir) + self.assertTrue(model.stateful) + self.assertTrue(model.use_cache) + + _, num_int8, num_int4 = get_num_quantized_nodes(model) + self.assertEqual(expected_int8, num_int8) + self.assertEqual(expected_int4, num_int4) + + tokens = tokenizer("This is a sample input", return_tensors="pt") + outputs = model(**tokens) + + self.assertTrue("logits" in outputs) + self.assertTrue("past_key_values" in outputs) + self.assertIsInstance(outputs.past_key_values, tuple) + self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True) @@ -256,6 +293,20 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_ov_int8[i], num_int8) + @parameterized.expand((OVModelForCausalLM, "gpt2")) + @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above") + def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type): + model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=True) + self.assertTrue(model.stateful) + self.assertTrue(model.use_cache) + + models = [model] + + expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] + for i, model in enumerate(models): + _, num_int8, _ = get_num_quantized_nodes(model) + self.assertEqual(expected_ov_int8[i], num_int8) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type): model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=False)