Commit

less aggressive stateful
eaidova committed Jan 10, 2024
1 parent 2c3e934 commit 27c0f0a
Showing 4 changed files with 32 additions and 125 deletions.
11 changes: 7 additions & 4 deletions optimum/exporters/openvino/stateful.py
@@ -184,14 +184,15 @@ def make_stateful(
build_state_initializer(ov_model, batch_dim)


def ensure_stateful_is_available():
def ensure_stateful_is_available(warn=True):
"""
Check openvino version and raise error if it does not support stateful models
"""
if is_openvino_version("<", "2023.3"):
log.warn(
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0."
)
if warn:
log.warn(
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0."
)
return False
return True

@@ -217,6 +218,8 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
not_kv_inputs = [
input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
]
if not key_value_input_names or not key_value_output_names:
return

# By default, batch is the 0-th but chatglm uses 1-st dimension as batch
# TODO: Deduce from a model via ordinal reshape (?) and topology
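For context, a minimal usage sketch of the updated helper (not part of this commit); the import path follows the file shown above, and the export-kwargs dict is purely illustrative:

# Hedged sketch: probe stateful support without logging, then use the result
# as a default, mirroring the warn=False call added in _from_transformers below.
from optimum.exporters.openvino.stateful import ensure_stateful_is_available

stateful_supported = ensure_stateful_is_available(warn=False)  # bool, no log.warn emitted
export_kwargs = {"stateful": stateful_supported}  # illustrative only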
24 changes: 7 additions & 17 deletions optimum/intel/openvino/modeling_decoder.py
@@ -128,7 +128,7 @@ def __init__(
use_cache = kwargs.pop("use_cache", True)
model_has_sinks = model_has_state(self.model)
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks
stateful = kwargs.pop("stateful", self.use_cache) # stateful model True only if model converted with past
stateful = kwargs.pop("stateful", None) # stateful model only if it is converted with stateful=True
self.stateful = model_has_sinks
self.main_input_name = "input_ids"
self.num_pkv = 2
@@ -142,9 +142,10 @@ def __init__(
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)

if self.stateful or stateful:
if self.stateful:
is_stateful_supported = ensure_stateful_is_available()
stateful = False if not is_stateful_supported else stateful
if stateful is None:
stateful = is_stateful_supported
if model_has_sinks and not is_stateful_supported:
raise ValueError(
"Loaded stateful model, while OpenVINO runtime version does not support stateful model inference. "
@@ -161,21 +162,10 @@ def raise_error(model_prop, user_prop, name):
"To export your model, simply set `export=True`."
)

if stateful is not None and self.stateful and not stateful:
if stateful is not None and stateful ^ self.stateful:
# We cannot transform stateful model to stateless
raise_error(self.stateful, stateful, "stateful")

if not self.stateful and stateful:
if self.use_cache:
# We can transform stateless model to stateful
self._make_stateful()
else:
raise ValueError(
"Making stateful model is applicable only for model converted with use_cache=True, please load model with stateful=False "
"or export the original model once again with use_cache=True when calling the `from_pretrained` method."
"To export your model, simply set `export=True`."
)

if use_cache ^ self.use_cache:
raise_error(self.use_cache, use_cache, "use_cache")

@@ -266,7 +256,7 @@ def _from_transformers(
compression_option = None
if load_in_8bit is not None:
compression_option = "int8" if load_in_8bit else "fp32"
stateful = kwargs.get("stateful", True)
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
main_export(
model_name_or_path=model_id,
output=save_dir_path,
@@ -286,7 +276,7 @@ def _from_transformers(
config.is_encoder_decoder = False
config.save_pretrained(save_dir_path)
return cls._from_pretrained(
model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, **kwargs
model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, stateful=None, **kwargs
)

def _reshape(
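As an illustration of the revised loading behavior (not taken from the diff; the model id is a placeholder):

# Hedged sketch of how the new stateful default is expected to resolve.
from optimum.intel import OVModelForCausalLM

# stateful left unset (None): a model exported with state is used as-is,
# otherwise runtime support decides the default at export time.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True)

# An explicit stateful=False keeps the conventional past_key_values interface,
# as the updated tests below do for weight compression.
stateless_model = OVModelForCausalLM.from_pretrained("gpt2", export=True, stateful=False)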
118 changes: 16 additions & 102 deletions tests/openvino/test_modeling.py
@@ -496,6 +496,11 @@ def test_compare_to_transformers(self, model_arch):
set_seed(SEED)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.use_cache)
if self.IS_SUPPORT_STATEFUL:
self.assertTrue(ov_model.stateful)
else:
self.assertFalse(ov_model.stateful)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
@@ -509,6 +514,10 @@ def test_compare_to_transformers(self, model_arch):

self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
if self.IS_SUPPORT_STATEFUL:
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)
# Compare tensor outputs
@@ -564,8 +573,7 @@ def test_compare_with_and_without_past_key_values(self):
model_id = MODEL_NAMES["gpt2"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")

model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
outputs_model_with_pkv = model_with_pkv.generate(
**tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
)
@@ -576,6 +584,12 @@ def test_compare_with_and_without_past_key_values(self):
self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
if self.IS_SUPPORT_STATEFUL:
model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
outputs_model_stateful = model_stateful.generate(
**tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
)
self.assertTrue(torch.equal(outputs_model_without_pkv, outputs_model_stateful))

del model_with_pkv
del model_without_pkv
@@ -613,106 +627,6 @@ def test_default_filling_attention_mask(self):
del model_with_cache
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
def test_stateful(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.stateful)
self.assertTrue(ov_model.use_cache)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
)
position_ids = None
input_shape = tokens["input_ids"].shape
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
ov_outputs = ov_model(**tokens, position_ids=position_ids)

self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)
# Compare tensor outputs
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
next_token = torch.argmax(ov_outputs.logits[..., -1:, :], dim=2)
attention_mask = torch.ones((input_shape[0], input_shape[1] + 1), dtype=torch.long)
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
position_ids = position_ids[:, -1:] + 1
pkv = ov_outputs.past_key_values
ov_outputs = ov_model(
input_ids=next_token, position_ids=position_ids, attention_mask=attention_mask, past_key_values=pkv
)
self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
with torch.no_grad():
transformers_outputs = transformers_model(
input_ids=next_token,
attention_mask=attention_mask,
past_key_values=transformers_outputs.past_key_values,
)
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))

del transformers_model
del ov_model
gc.collect()

@unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
def test_stateful_on_converted_model(self):
model_id = "vuiseng9/ov-gpt2-fp32-kv-cache"
# reference without state
loaded_model = OVModelForCausalLM.from_pretrained(model_id, stateful=False)
self.assertIsInstance(loaded_model.config, PretrainedConfig)
self.assertFalse(loaded_model.stateful)
self.assertTrue(loaded_model.use_cache)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
loaded_model_outputs = loaded_model(**tokens)

# explicit stateful model specified during loading
loaded_stateful_model = OVModelForCausalLM.from_pretrained(model_id, stateful=True)
self.assertIsInstance(loaded_model.config, PretrainedConfig)
self.assertTrue(loaded_stateful_model.stateful)
self.assertTrue(loaded_stateful_model.use_cache)
loaded_stateful_model_outputs = loaded_stateful_model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, loaded_stateful_model_outputs.logits))
self.assertTrue("past_key_values" in loaded_stateful_model_outputs)
self.assertIsInstance(loaded_stateful_model_outputs.past_key_values, tuple)
self.assertTrue(
len(loaded_stateful_model_outputs.past_key_values) == 1
and len(loaded_stateful_model_outputs.past_key_values[0]) == 0
)

with tempfile.TemporaryDirectory() as tmpdirname:
loaded_stateful_model.save_pretrained(tmpdirname)
folder_contents = os.listdir(tmpdirname)
self.assertTrue(OV_XML_FILE_NAME in folder_contents)
self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents)
# implicit load stateful model from disk
model = OVModelForCausalLM.from_pretrained(tmpdirname)
self.assertTrue(model.stateful)
self.assertTrue(model.use_cache)

outputs = model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
self.assertTrue("past_key_values" in outputs)
self.assertIsInstance(outputs.past_key_values, tuple)
self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0)
del loaded_model
del loaded_stateful_model
del model
gc.collect()
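
For reference, a small sketch (not part of the diff) of the placeholder past_key_values a stateful model returns, matching the assertions kept in test_compare_to_transformers above:

# ov_model and tokens as in the tests above; the cache lives in the model state,
# so past_key_values is only an empty placeholder tuple.
outputs = ov_model(**tokens)
pkv = outputs.past_key_values
assert isinstance(pkv, tuple)
assert len(pkv) == 1 and len(pkv[0]) == 0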


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
4 changes: 2 additions & 2 deletions tests/openvino/test_quantization.py
@@ -221,7 +221,7 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i

with tempfile.TemporaryDirectory() as tmp_dir:
model_id = MODEL_NAMES[model_name]
transformers_model = model_cls.from_pretrained(model_id, export=True)
transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@@ -278,7 +278,7 @@ def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, e

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True)
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=False)

if model.export_feature.startswith("text2text-generation"):
models = [model.encoder, model.decoder, model.decoder_with_past]
