diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 6bfec7c221..c2f8f85758 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -92,6 +92,17 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "precision (by default 20%% in INT8). This helps to achieve better accuracy after weight compression."
         ),
     )
+    optional_group.add_argument(
+        "--disable-stateful",
+        action="store_true",
+        help=(
+            "Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. "
+            "In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. "
+            "If --disable-stateful option is used, it may result in sub-optimal inference performance. "
+            "Use it when you intentionally want to use a stateless model, for example, to be compatible with existing "
+            "OpenVINO native inference code that expects kv-cache inputs and outputs in the model."
+        ),
+    )
 
 
 class OVExportCommand(BaseOptimumCLICommand):
@@ -138,6 +149,7 @@ def run(self):
             trust_remote_code=self.args.trust_remote_code,
             pad_token_id=self.args.pad_token_id,
             compression_option=self.args.weight_format,
-            compression_ratio=self.args.ratio
+            compression_ratio=self.args.ratio,
+            stateful=not self.args.disable_stateful,
             # **input_shapes,
         )
diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py
index d87d8dda9e..6fd7970a07 100644
--- a/optimum/exporters/openvino/__init__.py
+++ b/optimum/exporters/openvino/__init__.py
@@ -1,5 +1,6 @@
 from .__main__ import main_export
 from .convert import export, export_models, export_pytorch_via_onnx
+from .stateful import ensure_stateful_is_available, patch_stateful
 
 
 __all__ = ["main_export", "export", "export_models"]
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 54fe1193e5..750005802c 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -28,6 +28,7 @@
 
 from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
 from .convert import export_models
+from .stateful import ensure_export_task_support_stateful
 
 
 if is_optimum_version(">=", "1.16.0"):
@@ -65,6 +66,7 @@ def main_export(
     fn_get_submodels: Optional[Callable] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
+    stateful: bool = True,
     **kwargs_shapes,
 ):
     """
@@ -124,6 +126,8 @@ def main_export(
             `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
         compression_ratio (`Optional[float]`, defaults to `None`):
             Compression ratio between primary and backup precision (only relevant to INT4).
+        stateful (`bool`, defaults to `True`):
+            Produce a stateful model where all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs. Applicable only to decoder models.
         **kwargs_shapes (`Dict`):
             Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
 
@@ -277,6 +281,9 @@ class StoreAttr(object):
             possible_synonyms = ""
         logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
 
+    task_support_stateful = ensure_export_task_support_stateful(task)
+    stateful = stateful and task_support_stateful
+
     preprocessors = maybe_load_preprocessors(
         model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
     )
@@ -373,6 +380,7 @@ class StoreAttr(object):
         device=device,
         compression_option=compression_option,
         compression_ratio=compression_ratio,
+        stateful=stateful,
         model_kwargs=model_kwargs,
     )
 
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 56c5a10e5d..947d8bd989 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -32,6 +32,8 @@
 from optimum.utils import is_diffusers_available
 
 from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
+from .model_patcher import patch_model_with_bettertransformer
+from .stateful import ensure_stateful_is_available, patch_stateful
 from .utils import (
     OV_XML_FILE_NAME,
     clear_class_registry,
@@ -102,6 +104,7 @@ def export(
     model_kwargs: Optional[Dict[str, Any]] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
+    stateful: bool = True,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -125,6 +128,8 @@ def export(
             Compression ratio between primary and backup precision (only relevant to INT4).
         input_shapes (`Optional[Dict]`, defaults to `None`):
             If specified, allows to use specific shapes for the example input provided to the exporter.
+        stateful (`bool`, defaults to `True`):
+            Produce a stateful model where all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs. Applicable only to decoder models.
 
     Returns:
         `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -139,6 +144,10 @@ def export(
     if "diffusers" in str(model.__class__) and not is_diffusers_available():
         raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.")
 
+    if stateful:
+        # This will be checked anyway after the model conversion, but checking it earlier saves the user time if an unsuitable OpenVINO version is installed
+        stateful = ensure_stateful_is_available()
+
     if is_torch_available() and isinstance(model, nn.Module):
         return export_pytorch(
             model,
@@ -150,6 +159,7 @@ def export(
             compression_option=compression_option,
             compression_ratio=compression_ratio,
             model_kwargs=model_kwargs,
+            stateful=stateful,
         )
 
     elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -160,7 +170,9 @@ def export(
             raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
         if input_shapes is not None:
             logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
-        return export_tensorflow(model, config, opset, output)
+        return export_tensorflow(
+            model, config, opset, output, compression_option=compression_option, compression_ratio=compression_ratio
+        )
 
     else:
         raise RuntimeError(
@@ -271,6 +283,7 @@ def export_pytorch(
     model_kwargs: Optional[Dict[str, Any]] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
+    stateful: bool = False,
 ) -> Tuple[List[str], List[str]]:
     """
     Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -291,6 +304,13 @@ def export_pytorch(
             If specified, allows to use specific shapes for the example input provided to the exporter.
         model_kwargs (optional[Dict[str, Any]], defaults to `None`):
             Additional kwargs for model export
+        compression_option (`Optional[str]`, defaults to `None`):
+            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
+        compression_ratio (`Optional[float]`, defaults to `None`):
+            Compression ratio between primary and backup precision (only relevant to INT4).
+        stateful (`bool`, defaults to `False`):
+            Produce a stateful model where all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs. Applicable only to decoder models.
 
     Returns:
         `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -302,6 +322,15 @@ def export_pytorch(
     logger.info(f"Using framework PyTorch: {torch.__version__}")
     output = Path(output)
 
+    if stateful:
+        # Apply bettertransformer together with the stateful flag because OpenVINO HW-dependent transformations expect
+        # both of them to be applied to achieve the best performance.
+        # TODO: Consider applying bettertransformer regardless of the stateful flag -- requires additional validation.
+        model = patch_model_with_bettertransformer(model)
+        # TODO: Consider unpatching the model at the end of this function, once the export is done.
+        #       It is left as-is for now because the model is not expected to be used after the call to export_pytorch, and
+        #       this function is one of the _internal_ steps in a bigger model conversion pipeline.
+
     with torch.no_grad():
         model.config.torchscript = False
         model.config.return_dict = True
@@ -380,6 +409,14 @@ def ts_patched_forward(*args, **kwargs):
             logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
             if patch_model_forward:
                 model.forward = orig_forward
+            if stateful:
+                # Cannot raise an error here because stateful is enabled by default and raising would break backward compatibility for models that cannot be converted to OV directly
+                # TODO: Implement stateful support for the ONNX path as well; not doing it right now because of lack of validation
+                logger.warning(
+                    "Making stateful models is not supported when exporting to ONNX as an intermediate step. "
+                    "A stateless model will be exported instead. It may result in sub-optimal inference performance. "
+                    "Provide a model that can be converted to OpenVINO without falling back to the ONNX conversion path."
+                )
             return export_pytorch_via_onnx(
                 model,
                 config,
@@ -411,6 +448,10 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_node().set_partial_shape(static_shape)
             inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
+
+        if stateful:
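+            # Fuse beam_idx-based cache reordering into the graph and hide kv-cache inputs/outputs as internal model state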
+            patch_stateful(model.config, ov_model)
+
         _save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio)
         clear_class_registry()
         del model
@@ -430,6 +471,7 @@ def export_models(
     model_kwargs: Optional[Dict[str, Any]] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[int] = None,
+    stateful: bool = True,
 ) -> Tuple[List[List[str]], List[List[str]]]:
     """
     Export the models to OpenVINO IR format
@@ -451,6 +493,8 @@ def export_models(
             Compression ratio between primary and backup precision (only relevant to INT4).
         model_kwargs (Optional[Dict[str, Any]], optional):
             Additional kwargs for model export.
+        stateful (`bool`, defaults to `True`):
+            Produce a stateful model where all kv-cache inputs and outputs are hidden inside the model and are not exposed as model inputs and outputs. Applicable only to decoder models.
 
     Raises:
         ValueError: if custom names set not equal of number of models
@@ -481,6 +525,7 @@ def export_models(
                 model_kwargs=model_kwargs,
                 compression_option=compression_option,
                 compression_ratio=compression_ratio,
+                stateful=stateful,
             )
         )
 
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
new file mode 100644
index 0000000000..37106eacf8
--- /dev/null
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -0,0 +1,39 @@
+#  Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import logging as log
+
+from optimum.intel.utils.import_utils import is_torch_version
+
+
+def patch_model_with_bettertransformer(model):
+    if is_torch_version("<", "2.0"):
+        log.warning(
+            "Scaled Dot Product Attention (SDPA) optimization is only integrated for torch >= 2.0. "
+            "Using the model with stateful=True may be ineffective if the model does not contain torch.nn.functional.scaled_dot_product_attention. "
+            "It is recommended to upgrade PyTorch to use a stateful model, or to use stateful=False."
+        )
+    # model already has required SDPA implementation
+    if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
+        return model
+    try:
+        model = model.to_bettertransformer()
+    except Exception as e:
+        log.warning(
+            f"Cannot apply model.to_bettertransformer because of the exception:\n{e}. "
+            "Using the model with stateful=True may be ineffective if the model does not contain torch.nn.functional.scaled_dot_product_attention."
+        )
+        return model
+
+    return model
diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py
new file mode 100644
index 0000000000..e6ec1879a5
--- /dev/null
+++ b/optimum/exporters/openvino/stateful.py
@@ -0,0 +1,225 @@
+#  Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import logging as log
+from typing import List
+
+import numpy as np
+from transformers import PretrainedConfig
+
+import openvino as ov
+from openvino.runtime import opset13
+from optimum.exporters import TasksManager
+from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
+from optimum.utils.normalized_config import NormalizedConfigManager
+
+
+def model_has_state(ov_model: ov.Model):
+    # TODO: Provide a better way based on the variables availability, but OV Python API doesn't expose required methods
+    return len(ov_model.get_sinks()) > 0
+
+
+def model_has_input_output_name(ov_model: ov.Model, name: str):
+    """
+    Helper function for checking whether the model has a specified input or output name
+
+    Parameters:
+      ov_model (ov.Model):
+          openvino model
+      name (str):
+          name of input or output
+
+    Returns:
+      True if input or output with requested name exists else False
+    """
+    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])
+
+
+def fuse_cache_reorder(
+    ov_model: ov.Model, not_kv_inputs: List[str], key_value_input_names: List[str], gather_dim: int
+):
+    """
+    Fuses reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we cannot modify the model state directly.
+
+    Adds a new beam_idx parameter and a Gather op for each kv-cache input in a given model.
+    Should be run before make_stateful. Implements optimum's _reorder_cache
+    inside the model at the beginning of each iteration.
+    Gather works along the given gather_dim dimension, which may vary from model to model.
+    KV-cache inputs are identified based on names in key_value_input_names.
+    Appends the new beam_idx parameter to not_kv_inputs.
+
+    Parameters:
+      ov_model (`ov.Model`):
+          openvino model for processing
+      not_kv_inputs (`List[str]`):
+          list of input nodes in the model that are not related to past key values
+      key_value_input_names (`List[str]`):
+          list of names for key value input layers
+      gather_dim (int):
+          dimension for gathering cache during reorder pass
+    """
+
+    if model_has_input_output_name(ov_model, "beam_idx"):
+        raise ValueError("Model already has fused cache")
+    input_batch = ov_model.input("input_ids").get_partial_shape()[0]
+    beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
+    beam_idx.output(0).get_tensor().add_names({"beam_idx"})  # why list is not accepted?
+    ov_model.add_parameters([beam_idx])
+    not_kv_inputs.append(ov_model.inputs[-1])
+    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
+    for input_name in key_value_input_names:
+        parameter_output_port = ov_model.input(input_name)
+        consumers = parameter_output_port.get_target_inputs()
+        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
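+        # Re-route every consumer of the kv-cache parameter to read from the gathered (beam-reordered) tensor instead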
+        for consumer in consumers:
+            consumer.replace_source_output(gather.output(0))
+    ov_model.validate_nodes_and_infer_types()
+
+
+def build_state_initializer(ov_model: ov.Model, batch_dim: int):
+    """
+    Build initialization ShapeOf Expression for all ReadValue ops
+
+    Parameters:
+      ov_model (ov.Model):
+          openvino model
+      batch_dim (int):
+          index of dimension corresponding to batch size
+    """
+    input_ids = ov_model.input("input_ids")
+    batch = opset13.gather(opset13.shape_of(input_ids, output_type="i64"), opset13.constant([0]), opset13.constant(0))
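+    # For each ReadValue (state) op, build a zero initializer whose batch dimension follows the runtime batch size taken from input_ids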
+    for op in ov_model.get_ops():
+        if op.get_type_name() == "ReadValue":
+            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
+            dims[batch_dim] = batch
+            dims = [opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim for dim in dims]
+            shape = opset13.concat(dims, axis=0)
+            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
+            op.set_arguments([broadcast])
+    ov_model.validate_nodes_and_infer_types()
+
+
+def make_stateful(
+    ov_model: ov.Model,
+    not_kv_inputs: List[str],
+    key_value_input_names: List[str],
+    key_value_output_names: List[str],
+    batch_dim: int,
+    num_attention_heads: int,
+    num_beams_and_batch: int = None,
+):
+    """
+    Hides kv-cache inputs and outputs inside the model as variables.
+
+    Parameters:
+        ov_model (ov.Model):
+            openvino model
+        not_kv_inputs (`List[str]`):
+            list of input nodes in the model that are not related to past key values
+        key_value_input_names (`List[str]`):
+            list of names for key value input layers
+        key_value_output_names (`List[str]`):
+            list of names for key value output layers
+        batch_dim (int):
+            index of batch dimension in key value layers
+        num_attention_heads (int):
+            number of attention heads for batch dimension initialization
+        num_beams_and_batch (int):
+            precalculated product of the number of beams and the batch size, used for shape initialization
+    """
+    from openvino._offline_transformations import apply_make_stateful_transformation
+
+    input_output_map = {}
+    # TODO: Can we derive the dimensions from the model topology?
+
+    if num_beams_and_batch is not None:
+        # Set the batch size for input_ids and attention_mask to avoid a dynamic dimension being propagated from the end of the model back to ReadValue
+        for input in not_kv_inputs:
+            shape = input.get_partial_shape()
+            if shape.rank.get_length() <= 2:  # == 1 for beam_index
+                shape[0] = num_beams_and_batch
+                input.get_node().set_partial_shape(shape)
+            else:
+                log.warn(f"Rank of {input.get_any_name()} input of the model is not 2, batch size is not set")
+
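+    # Map each kv-cache input name to its matching output name; this mapping drives the make_stateful transformation below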
+    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
+        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
+        if num_beams_and_batch is not None:
+            input = ov_model.input(kv_name_pair[0])
+            shape = input.get_partial_shape()
+            shape[batch_dim] = num_beams_and_batch * num_attention_heads
+            input.get_node().set_partial_shape(shape)
+
+    if num_beams_and_batch is not None:
+        # Re-validate the model if the shapes were altered above
+        ov_model.validate_nodes_and_infer_types()
+
+    apply_make_stateful_transformation(ov_model, input_output_map)
+    if num_beams_and_batch is None:
+        build_state_initializer(ov_model, batch_dim)
+
+
+def ensure_stateful_is_available(warn=True):
+    """
+    Check whether the installed openvino version supports stateful models; warn and return False if it does not
+    """
+    if is_openvino_version("<", "2023.3"):
+        if warn:
+            log.warning(
+                f"Could not create or use a stateful model with the old openvino=={_openvino_version} version. It may result in sub-optimal inference performance. "
+                "Install openvino>=2023.3.0."
+            )
+        return False
+    return True
+
+
+def ensure_export_task_support_stateful(task: str):
+    task = TasksManager.map_from_synonym(task)
+    return task == "text-generation-with-past"
+
+
+def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
+    """
+    Apply the stateful transformation to the model to hide kv-cache inputs and outputs inside the model.
+    Transformation parameters are selected based on the model architecture.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            model pretrained config
+        ov_model (`ov.Model`):
+            openvino model
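+
+    Example (illustrative sketch, assuming an already exported decoder IR with `past_key_values`/`present` tensor names):
+        >>> import openvino as ov
+        >>> from transformers import AutoConfig
+        >>> ov_model = ov.Core().read_model("openvino_model.xml")
+        >>> config = AutoConfig.from_pretrained("gpt2")
+        >>> patch_stateful(config, ov_model)  # kv-cache inputs/outputs are replaced by internal state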
+    """
+
+    key_value_input_names = [
+        key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names())
+    ]
+    key_value_output_names = [
+        key.get_any_name() for key in ov_model.outputs if any("present" in key_name for key_name in key.get_names())
+    ]
+    not_kv_inputs = [
+        input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
+    ]
+    if not key_value_input_names or not key_value_output_names:
+        return
+
+    # By default, batch is the 0-th but chatglm uses 1-st dimension as batch
+    # TODO: Deduce from a model via ordinal reshape (?) and topology
+    batch_dim = 1 if config.model_type == "chatglm" else 0
+
+    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
+
+    normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
+    num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1
+    make_stateful(
+        ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
+    )
diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index 8cea5eb7b6..6d0af462cc 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -554,7 +554,7 @@ def from_pretrained(
             model = TimmForImageClassification.from_pretrained(model_id, **kwargs)
             onnx_config = TimmOnnxConfig(model.config)
 
-            return cls._to_load(model=model, config=config, onnx_config=onnx_config)
+            return cls._to_load(model=model, config=config, onnx_config=onnx_config, stateful=False)
         else:
             return super().from_pretrained(
                 model_id=model_id,
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 32b6b02377..05dc3af9b5 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -315,6 +315,7 @@ def _to_load(
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         local_files_only: bool = False,
+        stateful: bool = False,
         **kwargs,
     ):
         save_dir = TemporaryDirectory()
@@ -326,6 +327,7 @@ def _to_load(
             config=onnx_config,
             opset=onnx_config.DEFAULT_ONNX_OPSET,
             output=save_dir_path / OV_XML_FILE_NAME,
+            stateful=stateful,
         )
 
         return cls._from_pretrained(
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 14f8dbcafa..8a2167eae4 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -29,7 +29,8 @@
 
 from optimum.utils import NormalizedConfigManager
 
-from ...exporters.openvino import main_export
+from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
+from ...exporters.openvino.stateful import model_has_state
 from ..utils.import_utils import is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
@@ -125,7 +126,10 @@ def __init__(
 
         self.is_dynamic = dynamic_shapes
         use_cache = kwargs.pop("use_cache", True)
-        self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs)
+        model_has_sinks = model_has_state(self.model)
+        self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks
+        stateful = kwargs.pop("stateful", None)  # stateful model only if it is converted with stateful=True
+        self.stateful = model_has_sinks
         self.main_input_name = "input_ids"
         self.num_pkv = 2
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
@@ -133,22 +137,50 @@ def __init__(
         self.key_value_output_names = [key for key in self.output_names if "present" in key]
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
+        self.next_beam_idx = None
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
-        if enable_compilation:
-            self.compile()
+        is_stateful_supported = ensure_stateful_is_available(warn=False)
 
-        if use_cache ^ self.use_cache:
+        if self.use_cache and not self.stateful:
+            logger.warning(
+                "Provided model does not contain state. It may lead to sub-optimal performance. "
+                "Please re-export the model with OpenVINO version >= 2023.3.0 by calling the `from_pretrained` method with the original model "
+                "and the `export=True` parameter."
+            )
+
+        if self.stateful:
+            if stateful is None:
+                stateful = is_stateful_supported
+            if model_has_sinks and not is_stateful_supported:
+                raise ValueError(
+                    "Loaded stateful model, while OpenVINO runtime version does not support stateful model inference. "
+                    "Please update OpenVINO version >= 2023.3.0 "
+                    "or export the original model once again with `stateful=False` when calling the `from_pretrained` method."
+                    "To export your model, simply set `export=True`."
+                )
+
+        def raise_error(model_prop, user_prop, name):
             raise ValueError(
-                f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. "
-                f"Please load your current model with `use_cache={self.use_cache}` or export the original model "
-                f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
+                f"`{name}` was set to `{user_prop}` but the loaded model only supports `{name}={model_prop}`. "
+                f"Please load your current model with `{name}={model_prop}` or export the original model "
+                f"once again with `{name}={user_prop}` when calling the `from_pretrained` method. "
                 "To export your model, simply set `export=True`."
             )
 
+        if stateful is not None and stateful ^ self.stateful:
+            # We cannot transform a stateful model into a stateless one
+            raise_error(self.stateful, stateful, "stateful")
+
+        if use_cache ^ self.use_cache:
+            raise_error(self.use_cache, use_cache, "use_cache")
+
+        if enable_compilation:
+            self.compile()
+
     def update_pkv_precision(self, force_fp32=False):
-        if not self.use_cache:
+        if not self.use_cache or self.stateful:
             return
 
         pkv_precision = Type.f32
@@ -231,6 +263,7 @@ def _from_transformers(
         compression_option = None
         if load_in_8bit is not None:
             compression_option = "int8" if load_in_8bit else "fp32"
+        stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -243,13 +276,14 @@ def _from_transformers(
             force_download=force_download,
             trust_remote_code=trust_remote_code,
             compression_option=compression_option,
+            stateful=stateful,
         )
 
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
         return cls._from_pretrained(
-            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, **kwargs
+            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, stateful=None, **kwargs
         )
 
     def _reshape(
@@ -276,6 +310,8 @@ def _reshape(
                     shapes[inputs][1] = -1
                 else:
                     shapes[inputs][2] = -1
+            elif input_name.startswith("beam_idx"):
+                shapes[inputs][0] = -1
             else:
                 shapes[inputs][1] = -1
         model.reshape(shapes)
@@ -290,6 +326,10 @@ def compile(self):
             super().compile()
             self.request = self.request.create_infer_request()
 
+    def _make_stateful(self):
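+        """Patch an already loaded stateless model in place so that kv-cache inputs/outputs become internal state."""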
+        patch_stateful(self.config, self.model)
+        self.stateful = True
+
 
 @add_start_docstrings(
     """
@@ -319,49 +359,64 @@ def forward(
         **kwargs,
     ) -> CausalLMOutputWithPast:
         self.compile()
-        inputs = {}
-
         if self.use_cache and past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
+        batch_size = input_ids.shape[0]
+        if self.config.model_type == "bloom":
+            batch_size *= self.normalized_config.num_attention_heads
+
         inputs = {}
         past_len = 0
-        if past_key_values is not None:
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                past_len = past_key_values[0][1].shape[-2]
-                if self._pkv_precision == Type.bf16:
-                    # numpy does not support bf16, pretending f16, should change to bf16
-                    past_key_values = tuple(
-                        Tensor(past_key_value, past_key_value.shape, Type.bf16)
-                        for pkv_per_layer in past_key_values
-                        for past_key_value in pkv_per_layer
-                    )
-                else:
-                    # Flatten the past_key_values
-                    past_key_values = tuple(
-                        past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
-                    )
-            else:
-                past_len = past_key_values[0].shape[-2]
-
-            # Add the past_key_values to the decoder inputs
-            inputs = dict(zip(self.key_value_input_names, past_key_values))
-
-        # Create empty past_key_values for decoder_with_past first generation step
-        elif self.use_cache:
-            batch_size = input_ids.shape[0]
-            if self.config.model_type == "bloom":
-                batch_size *= self.normalized_config.num_attention_heads
-
-            for input_name in self.key_value_input_names:
-                model_inputs = self.model.input(input_name)
-                shape = model_inputs.get_partial_shape()
-                shape[0] = batch_size
-                if shape[2].is_dynamic:
-                    shape[2] = 0
+        if not self.stateful:
+            if past_key_values is not None:
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                    past_len = past_key_values[0][1].shape[-2]
+                    if self._pkv_precision == Type.bf16:
+                        # numpy does not support bf16, pretending f16, should change to bf16
+                        past_key_values = tuple(
+                            Tensor(past_key_value, past_key_value.shape, Type.bf16)
+                            for pkv_per_layer in past_key_values
+                            for past_key_value in pkv_per_layer
+                        )
+                    else:
+                        # Flatten the past_key_values
+                        past_key_values = tuple(
+                            past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
+                        )
                 else:
-                    shape[1] = 0
-                inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+                    past_len = past_key_values[0].shape[-2]
+
+                # Add the past_key_values to the decoder inputs
+                inputs = dict(zip(self.key_value_input_names, past_key_values))
+
+            # Create empty past_key_values for decoder_with_past first generation step
+            elif self.use_cache:
+                for input_name in self.key_value_input_names:
+                    model_inputs = self.model.input(input_name)
+                    shape = model_inputs.get_partial_shape()
+                    if self.config.model_type == "chatglm":
+                        shape[0] = 0
+                        shape[1] = batch_size
+                    else:
+                        shape[0] = batch_size
+                        if shape[2].is_dynamic:
+                            shape[2] = 0
+                        else:
+                            shape[1] = 0
+                    inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+        else:
+            # past_key_values are not used explicitly, instead they are handled inside the model
+            if past_key_values is None:
+                # Need a marker to differentiate the first generate iteration from the others in
+                # the first condition at the beginning of this function.
+                # It should be something that is not None and that evaluates to True when converted to a boolean.
+                past_key_values = ((),)
+                # This is the first iteration in a sequence, reset all states
+                self.request.reset_state()
+                # Set initial value for the next beam_idx input that will be used at the current iteration
+                # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
+                self.next_beam_idx = np.arange(batch_size, dtype=int)
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -387,21 +442,27 @@ def forward(
 
             inputs["position_ids"] = position_ids
 
+        if "beam_idx" in self.input_names:
+            inputs["beam_idx"] = (
+                self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
+            )
+
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
 
-        if self.use_cache:
-            # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
-            past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
-                past_key_values = tuple(
-                    past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
-                )
-        else:
-            past_key_values = None
+        if not self.stateful:
+            if self.use_cache:
+                # Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
+                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                    # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
+                    past_key_values = tuple(
+                        past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
+                    )
+            else:
+                past_key_values = None
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
@@ -428,18 +489,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         }
 
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-    @staticmethod
     def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called.
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
-        return tuple(
-            tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-        )
+        if self.stateful:
+            # TODO: Apply it differently based on model type
+            # TODO: At least for bloom we need to replicate values for each attention head
+            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            return past_key_values
+        else:
+            return tuple(
+                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
+            )
 
     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -500,7 +566,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         use_cache = kwargs.get("use_cache", None)
 
         # only last token for input_ids if past is not None
-        if past_key_values:
+        if past_key_values and not self.stateful:
             # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
@@ -522,15 +588,23 @@ def _reorder_cache(
         [`~PreTrainedModel.beam_sample`] is called for bloom architecture.
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
-        standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
-        reordered_past = tuple(
-            (
-                np.take(layer_past[0], beam_idx, 0),
-                np.take(layer_past[1], beam_idx, 0),
+        if self.stateful:
+            beam_idx = np.array(beam_idx)
+            batch_size = beam_idx.shape[0]
+            indices = np.array(range(batch_size * self.normalized_config.num_attention_heads))
+            indices = indices.reshape([batch_size, self.normalized_config.num_attention_heads])
+            self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            return past_key_values
+        else:
+            standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
+            reordered_past = tuple(
+                (
+                    np.take(layer_past[0], beam_idx, 0),
+                    np.take(layer_past[1], beam_idx, 0),
+                )
+                for layer_past in standardized_past
             )
-            for layer_past in standardized_past
-        )
-        return self._convert_to_bloom_cache(reordered_past)
+            return self._convert_to_bloom_cache(reordered_past)
 
     # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
     @staticmethod
@@ -602,8 +676,11 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
 
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
-    @staticmethod
     def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
-        return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
+        if self.stateful:
+            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            return past_key_values
+        else:
+            return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 1d8762b000..3da486c45c 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -38,6 +38,7 @@
 from optimum.quantization_base import OptimumQuantizer
 
 from ...exporters.openvino import export, export_pytorch_via_onnx
+from ...exporters.openvino.stateful import ensure_export_task_support_stateful
 from ..utils.constant import _TASK_ALIASES
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
@@ -313,9 +314,11 @@ def start_async(
                 inputs: Any = None,
                 userdata: Any = None,
                 share_inputs: bool = False,
+                *,
+                shared_memory: Any = None,
             ):
                 data_cache.append(inputs)
-                self.request.infer(inputs, share_inputs)
+                self.request.infer(inputs, share_inputs, share_outputs=True, shared_memory=shared_memory)
 
             def wait(self):
                 pass
@@ -415,6 +418,8 @@ def _quantize_torchmodel(
             onnx_config = onnx_config_class(
                 model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
             )
+            if model.config.use_cache:
+                task = "text-generation-with-past"
         else:
             onnx_config = onnx_config_class(model.config)
 
@@ -423,7 +428,10 @@ def _quantize_torchmodel(
         export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
-        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset)
+        kwargs = {}
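+        # Stateful export is only supported on the direct OpenVINO path; the ONNX-based path stays stateless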
+        if not quantization_config.save_onnx_model:
+            kwargs = {"stateful": ensure_export_task_support_stateful(task)}
+        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
         if is_onnx:
             # Load and save the compressed model
             model = core.read_model(onnx_path)
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index f778bbfcbd..3f3fa6c55b 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -70,12 +70,17 @@
 _openvino_version = "N/A"
 if _openvino_available:
     try:
-        _openvino_version = importlib_metadata.version("openvino")
-    except importlib_metadata.PackageNotFoundError:
-        try:
-            _openvino_version = importlib_metadata.version("openvino-nightly")
-        except importlib_metadata.PackageNotFoundError:
-            _openvino_available = False
+        from openvino.runtime import get_version
+
+        version = get_version()
+        # avoid invalid format
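+        # keep only the first two dash-separated components (release and commit id) so the version string stays parseable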
+        if "-" in version:
+            major_version, dev_info = version.split("-", 1)
+            commit_id = dev_info.split("-")[0]
+            version = f"{major_version}-{commit_id}"
+        _openvino_version = version
+    except ImportError:
+        _openvino_available = False
 
 
 _nncf_available = importlib.util.find_spec("nncf") is not None
diff --git a/setup.py b/setup.py
index c54983da45..76fc40ef4f 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@
         "onnxruntime<1.15.0",
         "transformers>=4.34.0",
     ],
-    "openvino": ["openvino>=2023.2", "onnx", "onnxruntime", "transformers>=4.34.0"],
+    "openvino": ["openvino>=2023.2", "onnx", "onnxruntime", "transformers>=4.36.0", "optimum>=1.16.1"],
     "nncf": ["nncf>=2.7.0"],
     "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
     "diffusers": ["diffusers"],
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 65cd03d090..f26bb92fb8 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -493,6 +493,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "pegasus",
     )
     GENERATION_LENGTH = 100
+    IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
@@ -500,6 +501,8 @@ def test_compare_to_transformers(self, model_arch):
         set_seed(SEED)
         ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
         self.assertIsInstance(ov_model.config, PretrainedConfig)
+        self.assertTrue(ov_model.use_cache)
+        self.assertEqual(ov_model.stateful, self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode")
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer(
@@ -513,6 +516,10 @@ def test_compare_to_transformers(self, model_arch):
 
         self.assertTrue("logits" in ov_outputs)
         self.assertIsInstance(ov_outputs.logits, torch.Tensor)
+        self.assertTrue("past_key_values" in ov_outputs)
+        self.assertIsInstance(ov_outputs.past_key_values, tuple)
+        if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode":
+            self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
         # Compare tensor outputs
@@ -568,8 +575,7 @@ def test_compare_with_and_without_past_key_values(self):
         model_id = MODEL_NAMES["gpt2"]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
-
-        model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
         outputs_model_with_pkv = model_with_pkv.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
@@ -580,6 +586,12 @@ def test_compare_with_and_without_past_key_values(self):
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        if self.IS_SUPPORT_STATEFUL:
+            model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
+            outputs_model_stateful = model_stateful.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+            self.assertTrue(torch.equal(outputs_model_without_pkv, outputs_model_stateful))
 
         del model_with_pkv
         del model_without_pkv
@@ -606,7 +618,7 @@ def test_default_filling_attention_mask(self):
         attention_mask = tokens.pop("attention_mask")
         outs_without_attn_mask = model_with_cache(**tokens)
         self.assertTrue(torch.allclose(outs.logits, outs_without_attn_mask.logits))
-        input_ids = torch.argmax(outs.logits, dim=2)
+        input_ids = torch.argmax(outs.logits[:, -1:, :], dim=2)
         past_key_values = outs.past_key_values
         attention_mask = torch.ones((input_ids.shape[0], tokens.input_ids.shape[1] + 1), dtype=torch.long)
         outs_step2 = model_with_cache(
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index a08da51aab..d6da6a78ba 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -51,6 +51,7 @@
 
 
 from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
+from optimum.intel.utils.import_utils import is_openvino_version
 from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
 
 _TASK_TO_DATASET = {
@@ -166,6 +167,8 @@ class OVWeightCompressionTest(unittest.TestCase):
         (OVStableDiffusionXLPipeline, "stable-diffusion-xl"),
     )
 
+    IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
@@ -218,7 +221,7 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             model_id = MODEL_NAMES[model_name]
-            transformers_model = model_cls.from_pretrained(model_id, export=True)
+            transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=False)
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             if tokenizer.pad_token is None:
                 tokenizer.pad_token = tokenizer.eos_token
@@ -239,9 +242,43 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
             outputs = model(**tokens)
             self.assertTrue("logits" in outputs)
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
+    @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
+    def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, expected_int8, expected_int4):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_id = MODEL_NAMES[model_name]
+            transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer.quantize(
+                save_directory=tmp_dir,
+                weights_only=True,
+                quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}),
+            )
+            model = model_cls.from_pretrained(tmp_dir)
+            self.assertTrue(model.stateful)
+            self.assertTrue(model.use_cache)
+
+            _, num_int8, num_int4 = get_num_quantized_nodes(model)
+            self.assertEqual(expected_int8, num_int8)
+            self.assertEqual(expected_int4, num_int4)
+
+            tokens = tokenizer("This is a sample input", return_tensors="pt")
+            outputs = model(**tokens)
+
+            self.assertTrue("logits" in outputs)
+            self.assertTrue("past_key_values" in outputs)
+            self.assertIsInstance(outputs.past_key_values, tuple)
+            self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0)
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
     def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
-        model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True)
+        model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=False)
 
         if model.export_feature.startswith("text2text-generation"):
             models = [model.encoder, model.decoder, model.decoder_with_past]
@@ -256,6 +293,17 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
             _, num_int8, _ = get_num_quantized_nodes(model)
             self.assertEqual(expected_ov_int8[i], num_int8)
 
+    @parameterized.expand((OVModelForCausalLM, "gpt2"))
+    @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
+    def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type):
+        model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=True)
+        self.assertTrue(model.stateful)
+        self.assertTrue(model.use_cache)
+
+        expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type][0]
+        _, num_int8, _ = get_num_quantized_nodes(model)
+        self.assertEqual(expected_ov_int8, num_int8)
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
     def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
         model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=False)