Skip to content

Commit

Permalink
stateful by default fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 10, 2024
1 parent 045cc69 commit 2c3e934
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 24 deletions.
9 changes: 6 additions & 3 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,12 @@ def parse_args_openvino(parser: "ArgumentParser"):
),
)
optional_group.add_argument(
"--stateful",
"--no-stateful",
action="store_true",
help="Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs",
help=(
"Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. "
"In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs.",
),
)


Expand Down Expand Up @@ -144,6 +147,6 @@ def run(self):
pad_token_id=self.args.pad_token_id,
compression_option=self.args.weight_format,
compression_ratio=self.args.ratio,
stateful=self.args.stateful,
stateful=not self.args.no_stateful,
# **input_shapes,
)
7 changes: 6 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def main_export(
fn_get_submodels: Optional[Callable] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: Optional[bool] = None,
stateful: Optional[bool] = True,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -280,6 +280,11 @@ class StoreAttr(object):
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")

synonyms_for_task = TasksManager.synonyms_for_task(task)
synonyms_for_task.add(task)
if stateful and "text-generation-with-past" not in synonyms_for_task:
stateful = False

preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
Expand Down
16 changes: 9 additions & 7 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def export(
model_kwargs: Optional[Dict[str, Any]] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = False,
stateful: bool = True,
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
Expand All @@ -128,7 +128,7 @@ def export(
Compression ratio between primary and backup precision (only relevant to INT4).
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the exporter.
stateful (`Optional[bool]`, defaults to `False`):
stateful (`Optional[bool]`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs
Returns:
Expand Down Expand Up @@ -166,7 +166,9 @@ def export(
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if input_shapes is not None:
logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
return export_tensorflow(model, config, opset, output)
return export_tensorflow(
model, config, opset, output, compression_option=compression_option, compression_ratio=compression_ratio
)

else:
raise RuntimeError(
Expand Down Expand Up @@ -303,7 +305,7 @@ def export_pytorch(
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`Optional[bool]`, defaults to `False`):
stateful (`Optional[bool]`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs
Returns:
Expand Down Expand Up @@ -464,7 +466,7 @@ def export_models(
model_kwargs: Optional[Dict[str, Any]] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[int] = None,
stateful: bool = False,
stateful: bool = True,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
Expand All @@ -486,7 +488,7 @@ def export_models(
Compression ratio between primary and backup precision (only relevant to INT4).
model_kwargs (Optional[Dict[str, Any]], optional):
Additional kwargs for model export.
stateful (`Optional[bool]`, defaults to `False`)
stateful (`Optional[bool]`, defaults to `True`)
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs
Raises:
Expand All @@ -497,7 +499,7 @@ def export_models(
"""
if stateful:
# This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used
ensure_stateful_is_available()
stateful = ensure_stateful_is_available()
outputs = []

if output_names is not None and len(output_names) != len(models_and_onnx_configs):
Expand Down
5 changes: 3 additions & 2 deletions optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,11 @@ def ensure_stateful_is_available():
Check openvino version and raise error if it does not support stateful models
"""
if is_openvino_version("<", "2023.3"):
raise ValueError(
log.warn(
f"Could not create or use stateful model when using old version of openvino=={_openvino_version}. Install openvino>=2023.3.0."
)
return False
return True


def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
Expand All @@ -205,7 +207,6 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
ov_model (`ov.Model`):
openvino model
"""
ensure_stateful_is_available()

key_value_input_names = [
key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names())
Expand Down
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def from_pretrained(
model = TimmForImageClassification.from_pretrained(model_id, **kwargs)
onnx_config = TimmOnnxConfig(model.config)

return cls._to_load(model=model, config=config, onnx_config=onnx_config)
return cls._to_load(model=model, config=config, onnx_config=onnx_config, stateful=False)
else:
return super().from_pretrained(
model_id=model_id,
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def _to_load(
force_download: bool = False,
cache_dir: Optional[str] = None,
local_files_only: bool = False,
stateful: bool = False,
**kwargs,
):
save_dir = TemporaryDirectory()
Expand All @@ -326,6 +327,7 @@ def _to_load(
config=onnx_config,
opset=onnx_config.DEFAULT_ONNX_OPSET,
output=save_dir_path / OV_XML_FILE_NAME,
stateful=stateful,
)

return cls._from_pretrained(
Expand Down
33 changes: 25 additions & 8 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def __init__(

self.is_dynamic = dynamic_shapes
use_cache = kwargs.pop("use_cache", True)
stateful = kwargs.pop("stateful", None) # None means taking a model "as-is"
model_has_sinks = model_has_state(self.model)
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks
stateful = kwargs.pop("stateful", self.use_cache) # stateful model True only if model converted with past
self.stateful = model_has_sinks
self.main_input_name = "input_ids"
self.num_pkv = 2
Expand All @@ -143,7 +143,15 @@ def __init__(
self.model = self._reshape(self.model, -1, -1)

if self.stateful or stateful:
ensure_stateful_is_available()
is_stateful_supported = ensure_stateful_is_available()
stateful = False if not is_stateful_supported else stateful
if model_has_sinks and not is_stateful_supported:
raise ValueError(
"Loaded stateful model, while OpenVINO runtime version does not support stateful model inference. "
"Please update OpenVINO version >= 2023.3.0 "
"or export the original model once again with `stateful=False` when calling the `from_pretrained` method."
"To export your model, simply set `export=True`."
)

def raise_error(model_prop, user_prop, name):
raise ValueError(
Expand All @@ -158,15 +166,22 @@ def raise_error(model_prop, user_prop, name):
raise_error(self.stateful, stateful, "stateful")

if not self.stateful and stateful:
# We can transform stateless model to stateful
self._make_stateful()

if enable_compilation:
self.compile()
if self.use_cache:
# We can transform stateless model to stateful
self._make_stateful()
else:
raise ValueError(
"Making stateful model is applicable only for model converted with use_cache=True, please load model with stateful=False "
"or export the original model once again with use_cache=True when calling the `from_pretrained` method."
"To export your model, simply set `export=True`."
)

if use_cache ^ self.use_cache:
raise_error(self.use_cache, use_cache, "use_cache")

if enable_compilation:
self.compile()

def update_pkv_precision(self, force_fp32=False):
if not self.use_cache or self.stateful:
return
Expand Down Expand Up @@ -251,6 +266,7 @@ def _from_transformers(
compression_option = None
if load_in_8bit is not None:
compression_option = "int8" if load_in_8bit else "fp32"
stateful = kwargs.get("stateful", True)
main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand All @@ -263,6 +279,7 @@ def _from_transformers(
force_download=force_download,
trust_remote_code=trust_remote_code,
compression_option=compression_option,
stateful=stateful,
)

config.is_decoder = True
Expand Down Expand Up @@ -345,7 +362,7 @@ def forward(
**kwargs,
) -> CausalLMOutputWithPast:
self.compile()

print(self.stateful)
if self.use_cache and past_key_values is not None:
input_ids = input_ids[:, -1:]

Expand Down
4 changes: 2 additions & 2 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,8 +682,8 @@ def test_stateful_on_converted_model(self):
# explicit stateful model specified during loading
loaded_stateful_model = OVModelForCausalLM.from_pretrained(model_id, stateful=True)
self.assertIsInstance(loaded_model.config, PretrainedConfig)
self.assertTrue(loaded_model.stateful)
self.assertTrue(loaded_model.use_cache)
self.assertTrue(loaded_stateful_model.stateful)
self.assertTrue(loaded_stateful_model.use_cache)
loaded_stateful_model_outputs = loaded_stateful_model(**tokens)
self.assertTrue(torch.equal(loaded_model_outputs.logits, loaded_stateful_model_outputs.logits))
self.assertTrue("past_key_values" in loaded_stateful_model_outputs)
Expand Down

0 comments on commit 2c3e934

Please sign in to comment.