openvino causal language model refactorization (#287)
* openvino causal language model refactorization

* log a warning and ignore unused parameters given during the reshaping operation
echarlaix authored Apr 18, 2023
1 parent 1a1dab2 commit 357aa81
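
For context, here is a minimal sketch of the usage flow the updated tests below exercise after this refactoring; the checkpoint name is a placeholder, and `compile=False` only defers compilation until `.compile()` is called explicitly:

from transformers import AutoTokenizer, pipeline
from optimum.intel.openvino import OVModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Export to OpenVINO IR without compiling, pick the device, then compile once
model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False)
model.to("cpu")
model.compile()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("This is a sample", max_length=10))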
Showing 2 changed files with 110 additions and 159 deletions.
optimum/intel/openvino/modeling_decoder.py (256 changes: 102 additions & 154 deletions)
@@ -21,17 +21,16 @@
import openvino
import torch
from openvino.runtime import Core, Tensor
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.exporters import TasksManager
from optimum.exporters.onnx import export
from optimum.utils import NormalizedConfig, NormalizedConfigManager
from optimum.utils import NormalizedConfigManager

from ..utils.import_utils import is_transformers_version
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING
from .modeling_base import OVBaseModel
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
from .utils import ONNX_WEIGHTS_NAME


@@ -81,7 +80,7 @@ def _contiguous_helper(tensor: np.ndarray) -> np.ndarray:
Base OVBaseDecoderModel class.
""",
)
class OVBaseDecoderModel(OVBaseModel):
class OVBaseDecoderModel(OVModel):
def __init__(
self,
model: openvino.runtime.Model,
@@ -92,35 +91,30 @@ def __init__(
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
):
self.config = config
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs)
self.model_save_dir = model_save_dir
self._device = device.upper()
self.is_dynamic = True
self.ov_config = ov_config if ov_config is not None else {}
self.preprocessors = kwargs.get("preprocessors", [])
self.model = self._reshape(model, -1, -1)
self.device = torch.device("cpu")
self.main_input_name = "input_ids"
enable_compilation = kwargs.get("compile", True)
normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.decoder = OVDecoder(self.model, self._device, self.use_cache, self.ov_config, normalized_config)

if enable_compilation:
self.compile()

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
else:
from transformers import GenerationConfig

self.generation_config = GenerationConfig.from_model_config(config)
if not dynamic_shapes:
raise ValueError(
"`dynamic_shapes` was set to `False` but static shapes are not supported for causal language model. Please set `dynamic_shapes=True`."
)

# Avoid warnings when creating a transformers pipeline
AutoConfig.register(self.base_model_prefix, AutoConfig)
self.auto_model_class.register(AutoConfig, self.__class__)
super().__init__(
model,
config,
device=device,
dynamic_shapes=True,
ov_config=ov_config,
model_save_dir=model_save_dir,
**kwargs,
)

use_cache = kwargs.pop("use_cache", True)
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs)
self.main_input_name = "input_ids"
self.num_pkv = 2
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.key_value_output_names = [key for key in self.output_names if "present" in key]

if use_cache ^ self.use_cache:
raise ValueError(
f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. "
@@ -129,14 +123,6 @@ def __init__(
"To export your model, simply set `export=True`."
)

if not dynamic_shapes:
logger.warning(
"`dynamic_shapes` was set to `False` but static shapes are not supported for causal language model and will be ignored."
)

def compile(self):
self.decoder._create_inference_request()
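
As a side note on the constructor above: cache support is now detected from the exported graph itself rather than passed around explicitly. A rough standalone sketch of that detection with the OpenVINO runtime (the IR path is hypothetical):

from openvino.runtime import Core

core = Core()
ov_model = core.read_model("decoder_model.xml")  # hypothetical path to an exported decoder IR
# A decoder exported with cache support exposes `past_key_values.*` inputs and `present.*` outputs
use_cache = any("past_key_values" in inp.get_any_name() for inp in ov_model.inputs)
key_value_input_names = [inp.get_any_name() for inp in ov_model.inputs if "key_values" in inp.get_any_name()]
key_value_output_names = [out.get_any_name() for out in ov_model.outputs if "present" in out.get_any_name()]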

@classmethod
def _from_transformers(
cls,
@@ -188,7 +174,20 @@ def _from_transformers(
**kwargs,
)

def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
def _reshape(
self,
model: openvino.runtime.Model,
batch_size: int,
sequence_length: int,
height: int = None,
width: int = None,
):
if height is not None:
logger.warning(f"`height` set to `{height}` will be ignored during reshaping operation.")

if width is not None:
logger.warning(f"`width` set to `{width}` will be ignored during reshaping operation.")

shapes = {}
for inputs in model.inputs:
shapes[inputs] = inputs.get_partial_shape()
@@ -208,19 +207,6 @@ def reshape(self, batch_size: int, sequence_length: int):
logger.warning("Static shapes are not supported for causal language model.")
return self
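
To make the warn-and-ignore behavior from the commit message concrete, a self-contained sketch follows; the function and logger names are placeholders, not the repository's API:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("reshape_sketch")

def reshape_decoder(batch_size: int, sequence_length: int, height=None, width=None):
    # Decoder-only language models have no spatial dimensions, so unused image
    # arguments are reported and dropped instead of raising an error.
    if height is not None:
        logger.warning(f"`height` set to `{height}` will be ignored during reshaping operation.")
    if width is not None:
        logger.warning(f"`width` set to `{width}` will be ignored during reshaping operation.")
    return batch_size, sequence_length

reshape_decoder(-1, -1, height=224, width=224)  # logs two warnings and keeps dynamic shapes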

def to(self, device: str):
"""
Use the specified `device` for inference. For example: "cpu" or "gpu". `device` can
be in upper or lower case. To speed up first inference, call `.compile()` after `.to()`.
"""
self._device = device.upper()
self.decoder._device = device.upper()
self.decoder.request = None
return self

def forward(self, *args, **kwargs):
raise NotImplementedError


@add_start_docstrings(
"""
@@ -243,20 +229,74 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
**kwargs,
) -> CausalLMOutputWithPast:
self.compile()

if self.use_cache and past_key_values is not None:
input_ids = input_ids[:, -1:]

outputs = self.decoder(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
)
return CausalLMOutputWithPast(logits=outputs.logits, past_key_values=outputs.past_key_values)
inputs = {}
if past_key_values is not None:
# Flatten the past_key_values
past_key_values = tuple(
_contiguous_helper(np.array(past_key_value))
for pkv_per_layer in past_key_values
for past_key_value in pkv_per_layer
)
# Add the past_key_values to the decoder inputs
inputs = {
input_name: Tensor(past_key_value, shared_memory=True)
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values)
}

# Create empty past_key_values for decoder_with_past first generation step
elif self.use_cache:
shape_input_ids = input_ids.shape
num_attention_heads = (
self.normalized_config.num_attention_heads if self.config.model_type == "bloom" else 1
)
for input_name in self.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
shape[0] = shape_input_ids[0] * num_attention_heads
if shape[2].is_dynamic:
shape[2] = 0
if shape[1].is_dynamic:
shape[1] = 0
inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())

inputs["input_ids"] = np.array(input_ids)

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names and attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)

# Run inference
self.request.start_async(inputs)
self.request.wait()

outputs = {
key.get_any_name(): value.data for key, value in zip(self.request.model_outputs, self.request.outputs)
}
logits = torch.from_numpy(outputs["logits"]).to(self.device)

if self.use_cache:
# Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the key and value of the self-attention layer)
past_key_values = tuple(
torch.from_numpy(outputs[key]).to(self.device) for key in self.key_value_output_names
)
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
else:
past_key_values = None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
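
For readers following the cache plumbing in `forward` above, here is a small self-contained sketch of the flatten/regroup round trip with `num_pkv = 2`; the tensor shapes are illustrative only:

import numpy as np

num_layers, num_pkv = 3, 2  # two tensors (key, value) per self-attention layer
# Nested layout used by transformers: a tuple of `num_layers` (key, value) pairs
past_key_values = tuple(
    (np.zeros((1, 4, 5, 8), dtype=np.float32), np.zeros((1, 4, 5, 8), dtype=np.float32))
    for _ in range(num_layers)
)

# Flatten into one contiguous array per OpenVINO `past_key_values.*` input, in layer order
flat = tuple(np.ascontiguousarray(t) for layer_pkv in past_key_values for t in layer_pkv)
assert len(flat) == num_layers * num_pkv

# After inference, the flat `present.*` outputs are regrouped into per-layer (key, value) pairs
regrouped = tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))
assert len(regrouped) == num_layers and all(len(pair) == num_pkv for pair in regrouped)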

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
@@ -362,95 +402,3 @@ def _convert_to_standard_cache(
def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True


class OVDecoder:
def __init__(
self, model: openvino.runtime.Model, device: str, use_cache: bool, ov_config: Dict, config: NormalizedConfig
):
self.model = model
self._device = device
self.device = torch.device("cpu")
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.key_value_output_names = [key for key in self.output_names if "present" in key]
self.use_cache = use_cache
self.num_pkv = 2
self.ov_config = ov_config
self.config = config
self.request = None

def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> CausalLMOutputWithPast:
self._create_inference_request()

inputs = {}
if past_key_values is not None:
# Flatten the past_key_values
past_key_values = tuple(
_contiguous_helper(np.array(past_key_value))
for pkv_per_layer in past_key_values
for past_key_value in pkv_per_layer
)
# Add the past_key_values to the decoder inputs
inputs = {
input_name: Tensor(past_key_value, shared_memory=True)
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values)
}

# Create empty past_key_values for decoder_with_past first generation step
elif self.use_cache:
shape_input_ids = input_ids.shape
num_attention_heads = self.config.num_attention_heads if self.config.config.model_type == "bloom" else 1
for input_name in self.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
shape[0] = shape_input_ids[0] * num_attention_heads
if shape[2].is_dynamic:
shape[2] = 0
if shape[1].is_dynamic:
shape[1] = 0
inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())

inputs["input_ids"] = np.array(input_ids)

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names and attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)

# Run inference
self.request.start_async(inputs)
self.request.wait()

outputs = {
key.get_any_name(): value.data for key, value in zip(self.request.model_outputs, self.request.outputs)
}
logits = torch.from_numpy(outputs["logits"]).to(self.device)

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(
torch.from_numpy(outputs[key]).to(self.device) for key in self.key_value_output_names
)
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
else:
past_key_values = None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def _create_inference_request(self):
if self.request is None:
logger.info("Compiling the decoder and creating the inference request ...")
compiled_model = core.compile_model(self.model, self._device, self.ov_config)
self.request = compiled_model.create_infer_request()
tests/openvino/test_modeling.py (13 changes: 8 additions & 5 deletions)
@@ -443,9 +443,11 @@ def test_compare_to_transformers(self, model_arch):
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
model_id = MODEL_NAMES[model_arch]
model = OVModelForCausalLM.from_pretrained(model_id, from_transformers=True, use_cache=False)
model.to("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, from_transformers=True, use_cache=False, compile=False)
model.to("cpu")
model.half()
model.compile()
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
outputs = pipe("This is a sample", max_length=10)
self.assertEqual(pipe.device, model.device)
@@ -456,7 +458,7 @@ def test_multiple_inputs(self, model_arch):
def test_multiple_inputs(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
@@ -469,8 +471,9 @@ def test_model_and_decoder_same_device(self):
model_id = MODEL_NAMES["gpt2"]
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
model.to("TEST")
self.assertEqual(model._device, model.decoder._device)
self.assertEqual(model.decoder._device, "TEST")
self.assertEqual(model._device, "TEST")
# Verify that request is being reset
self.assertEqual(model.request, None)
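
In other words, moving the model now only records the target device and resets the compiled request, so the next compile or forward call rebuilds it. A rough sketch mirroring the assertions above (the checkpoint is a placeholder):

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)  # placeholder checkpoint
model.to("TEST")                 # nothing is recompiled here
assert model._device == "TEST"   # only the target device string is recorded
assert model.request is None     # the inference request was reset and will be recreated lazily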

def test_compare_with_and_without_past_key_values(self):
model_id = MODEL_NAMES["gpt2"]