
Commit

added more tests
eaidova committed Dec 25, 2023
1 parent e4b5072 commit 7e1ce6f
Showing 6 changed files with 149 additions and 38 deletions.
16 changes: 15 additions & 1 deletion optimum/exporters/openvino/better_transformer_patch.py
@@ -14,12 +14,26 @@

import logging as log

from optimum.intel.utils.import_utils import is_torch_version


def patch_model_with_bettertransformer(model):
if is_torch_version("<", "2.0"):
log.warn(
"integration Scaled Dot Product Attention optimization supported only with torch > 2.0."
"Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention"
"It is recommended to upgrade PyTorch version for using stateful model or use stateful=Flase"
)
# model already has required SDPA implementation
if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
return model
try:
model = model.to_bettertransformer()
except Exception as e:
log.warn(f"Cannot apply model.to_bettertransformer because of the exception:\n{e}")
log.warn(
f"Cannot apply model.to_bettertransformer because of the exception:\n{e}."
" Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention"
)
return model

return model
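For context, a minimal usage sketch of the patched function (the checkpoint name and surrounding calls are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model = patch_model_with_bettertransformer(model)
# the function warns on torch < 2.0, returns the model as-is when it already
# uses a native SDPA implementation, and otherwise attempts
# to_bettertransformer(), falling back to the unpatched model on failure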
38 changes: 2 additions & 36 deletions optimum/exporters/openvino/stateful.py
@@ -20,45 +20,11 @@

import openvino as ov
from openvino.runtime import opset13
from optimum.intel.openvino.utils import model_has_input
from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
from optimum.utils.normalized_config import NormalizedConfigManager


def model_has_input_output_name(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input or output name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input or output
Returns:
True if an input or output with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def model_has_input(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input
Returns:
True if an input with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs], [])


def model_has_cache_reorder(ov_model: ov.Model):
return model_has_input(ov_model, "beam_idx")


def model_has_state(ov_model: ov.Model):
# TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
return len(ov_model.get_sinks()) > 0
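Since stateful OpenVINO models expose their KV-cache through sink nodes, a quick standalone check could look like this sketch (the model path is illustrative):

import openvino as ov

core = ov.Core()
ov_model = core.read_model("openvino_model.xml")  # illustrative path
print(model_has_state(ov_model))  # True only when the graph contains sinks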
@@ -88,7 +54,7 @@ def fuse_cache_reorder(
dimension for gathering cache during reorder pass
"""

assert not model_has_input_output_name(ov_model, "beam_idx")
assert not model_has_input(ov_model, "beam_idx")
input_batch = ov_model.input("input_ids").get_partial_shape()[0]
beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted?
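# beam_idx is later used to gather (reorder) the cached key/value states when beam search reshuffles the batch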
3 changes: 2 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -30,6 +30,7 @@
from optimum.utils import NormalizedConfigManager

from ...exporters.openvino import main_export, patch_stateful, raise_if_openvino_is_too_old
from ...exporters.openvino.stateful import model_has_state
from ..utils.import_utils import is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
@@ -126,7 +127,7 @@ def __init__(
self.is_dynamic = dynamic_shapes
use_cache = kwargs.pop("use_cache", True)
stateful = kwargs.pop("stateful", None) # None means taking a model "as-is"
model_has_sinks = len(model.get_sinks()) > 0
model_has_sinks = model_has_state(self.model)
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks
self.stateful = model_has_sinks
self.main_input_name = "input_ids"
32 changes: 32 additions & 0 deletions optimum/intel/openvino/utils.py
@@ -18,6 +18,7 @@
from glob import glob

import numpy as np
import openvino as ov
from huggingface_hub import model_info
from openvino.runtime import Type
from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size
@@ -77,6 +78,37 @@
}


def model_has_input_output_name(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input or output name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input or output
Returns:
True if an input or output with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def model_has_input(ov_model: ov.Model, name: str):
"""
Helper function for checking that the model has the specified input name
Parameters:
ov_model (ov.Model):
openvino model
name (str):
name of input
Returns:
True if an input with the requested name exists, else False
"""
return name in sum([list(t.get_names()) for t in ov_model.inputs], [])


_HEAD_TO_AUTOMODELS = {
"feature-extraction": "OVModelForFeatureExtraction",
"fill-mask": "OVModelForMaskedLM",
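A hypothetical usage sketch of the relocated helpers (ov_model and the tensor names are illustrative):

if model_has_input(ov_model, "beam_idx"):
    # the cache-reorder input has already been fused, no need to patch again
    ...
if model_has_input_output_name(ov_model, "logits"):
    ...  # some tensor on the model's input or output ports is named "logits"

Both helpers simply flatten the tensor names of the relevant ports and test membership.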
47 changes: 47 additions & 0 deletions tests/openvino/test_modeling.py
@@ -621,6 +621,7 @@ def test_stateful(self, model_arch):
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.stateful)
self.assertTrue(ov_model.use_cache)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
@@ -666,6 +667,52 @@ def test_stateful(self, model_arch):
del ov_model
gc.collect()

@unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
def test_stateful_on_converted_model(self):
model_id = "vuiseng9/ov-gpt2-fp32-kv-cache"
# reference without state
loaded_model = OVModelForCausalLM.from_pretrained(model_id)
self.assertIsInstance(loaded_model.config, PretrainedConfig)
self.assertFalse(loaded_model.stateful)
self.assertTrue(loaded_model.use_cache)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
loaded_model_outputs = loaded_model(**tokens)

# explicit stateful model specified during loading
loaded_stateful_model = OVModelForCausalLM.from_pretrained(model_id, stateful=True)
self.assertIsInstance(loaded_stateful_model.config, PretrainedConfig)
self.assertTrue(loaded_stateful_model.stateful)
self.assertTrue(loaded_stateful_model.use_cache)
loaded_stateful_model_outputs = loaded_stateful_model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, loaded_stateful_model_outputs.logits))
self.assertTrue("past_key_values" in loaded_stateful_model_outputs)
self.assertIsInstance(loaded_stateful_model_outputs.past_key_values, tuple)
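# for a stateful model the KV-cache lives inside the inference request state, so past_key_values is only an empty placeholder: ((),)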
self.assertTrue(
len(loaded_stateful_model_outputs.past_key_values) == 1
and len(loaded_stateful_model_outputs.past_key_values[0]) == 0
)

with tempfile.TemporaryDirectory() as tmpdirname:
loaded_stateful_model.save_pretrained(tmpdirname)
folder_contents = os.listdir(tmpdirname)
self.assertTrue(OV_XML_FILE_NAME in folder_contents)
self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents)
# implicit load stateful model from disk
model = OVModelForCausalLM.from_pretrained(tmpdirname)
self.assertTrue(model.stateful)
self.assertTrue(model.use_cache)

outputs = model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
self.assertTrue("past_key_values" in outputs)
self.assertIsInstance(outputs.past_key_values, tuple)
self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0)
del loaded_model
del loaded_stateful_model
del model
gc.collect()


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
51 changes: 51 additions & 0 deletions tests/openvino/test_quantization.py
@@ -51,6 +51,7 @@


from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
from optimum.intel.utils.import_utils import is_openvino_version
from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8

_TASK_TO_DATASET = {
@@ -166,6 +167,8 @@ class OVWeightCompressionTest(unittest.TestCase):
(OVStableDiffusionXLPipeline, "stable-diffusion-xl"),
)

IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature
@@ -239,6 +242,40 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
outputs = model(**tokens)
self.assertTrue("logits" in outputs)

@unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, expected_int8, expected_int4):
task = model_cls.export_feature

with tempfile.TemporaryDirectory() as tmp_dir:
model_id = MODEL_NAMES[model_name]
transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
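# assumed semantics: int4_sym_g128 = symmetric INT4 weight compression with group size 128; ratio=0.8 quantizes ~80% of eligible weights to INT4, leaving the rest in INT8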
quantizer.quantize(
save_directory=tmp_dir,
weights_only=True,
quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}),
)
model = model_cls.from_pretrained(tmp_dir)
self.assertTrue(model.stateful)
self.assertTrue(model.use_cache)

_, num_int8, num_int4 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8, num_int8)
self.assertEqual(expected_int4, num_int4)

tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)

self.assertTrue("logits" in outputs)
self.assertTrue("past_key_values" in outputs)
self.assertIsInstance(outputs.past_key_values, tuple)
self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True)
@@ -256,6 +293,20 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
_, num_int8, _ = get_num_quantized_nodes(model)
self.assertEqual(expected_ov_int8[i], num_int8)

@parameterized.expand([(OVModelForCausalLM, "gpt2")])
@unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type):
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=True)
self.assertTrue(model.stateful)
self.assertTrue(model.use_cache)

models = [model]

expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type]
for i, model in enumerate(models):
_, num_int8, _ = get_num_quantized_nodes(model)
self.assertEqual(expected_ov_int8[i], num_int8)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=False)
