Fix compatibility for optimum v1.14.0 (#460)
* Enable openvino inference for gpt big code models

* fix

* format

* fix input names

* Fix export optimum modifications
echarlaix authored Nov 6, 2023
1 parent c5ed584 commit bf8e95c
Showing 11 changed files with 169 additions and 119 deletions.
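
For context, a minimal usage sketch of what this change enables: exporting a GPT-BigCode checkpoint to OpenVINO and generating with it. The checkpoint name is illustrative, not taken from the diff.

```python
# Minimal sketch, assuming optimum-intel is installed with the OpenVINO extras.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "bigcode/tiny_starcoder_py"  # illustrative gpt_bigcode checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # export to OpenVINO IR on the fly

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```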
29 changes: 11 additions & 18 deletions optimum/exporters/openvino/__main__.py
@@ -27,7 +27,6 @@
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models


@@ -257,24 +256,18 @@ class StoreAttr(object):
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
if not task.startswith("text-generation"):
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
)
else:
# TODO : ModelPatcher will be added in next optimum release
model = patch_decoder_attention_mask(model)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
legacy=False,
)

if int8 is None:
int8 = False
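
With the special-case branch for text-generation removed, every task now goes through optimum's `_get_submodels_and_onnx_configs` with `legacy=False`. A minimal sketch of invoking the export entry point defined in this module, assuming a `main_export` function with these parameters; the checkpoint and output directory are illustrative:

```python
# Sketch only; parameter names follow the surrounding module, values are illustrative.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="bigcode/tiny_starcoder_py",  # any supported checkpoint
    output="ov_export",                              # directory receiving the OpenVINO IR
    task="text-generation-with-past",
)
```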
131 changes: 60 additions & 71 deletions optimum/intel/generation/modeling.py
@@ -26,12 +26,13 @@
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


if is_transformers_version("<", "4.25.0"):
@@ -47,55 +48,37 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_class(model.config)
if task == "text-generation" and use_cache:
onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
if "text-generation" in task:
onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
else:
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
# WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
pkv = []
for i in range(len(model_inputs["past_key_values"])):
pkv.append([])
for j in range(len(model_inputs["past_key_values"][0])):
pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype))
pkv[i] = tuple(pkv[i])
model_inputs["past_key_values"] = tuple(pkv)
i = model_inputs["input_ids"]
a = model_inputs["attention_mask"]
model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype)
model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype)
return model_inputs

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
# check if the model_inputs is correct.
model(**model_inputs)

torch._C._jit_set_texpr_fuser_enabled(False)
if "past_key_values" in model_inputs.keys():
model.config.return_dict = False
if is_torch_version(">", "2.0.1"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
if is_torch_version(">=", "2.1.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
if is_torch_version(">=", "2.0.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)

traced_model = torch.jit.freeze(traced_model.eval())
traced_model(**model_inputs)
traced_model(**model_inputs)

return traced_model
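
A minimal sketch of how `prepare_jit_inputs` and `jit_trace` above fit together; the checkpoint is illustrative:

```python
# Sketch only: trace a causal LM with cache enabled and run the frozen TorchScript module.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model.eval()

example_inputs = prepare_jit_inputs(model, task="text-generation", use_cache=True)
traced = jit_trace(model, task="text-generation", use_cache=True)

with torch.no_grad():
    outputs = traced(**example_inputs)  # logits and present key/values
```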


class PreTrainedModel(OptimizedModel):
pass


class BaseModelForCausalLM(PreTrainedModel, GenerationMixin):
class BaseModelForCausalLM(OptimizedModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
export_feature = "text-generation"
main_input_name = "input_ids"
@@ -156,12 +139,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

position_ids = kwargs.get("position_ids", None)

attention_mask = kwargs.get("attention_mask", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": self.use_cache,
"position_ids": None,
"attention_mask": kwargs.get("attention_mask", None),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": None,
}
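
A worked illustration of the position-id logic added above: positions are the cumulative sum of the attention mask shifted to start at zero, padding positions are clamped to 1, and only the last position is kept once a KV cache exists. Values below are illustrative.

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# -> tensor([[1, 0, 1, 2],
#            [0, 1, 2, 3]])

# With a non-empty cache, only the position of the newly generated token is passed on:
position_ids = position_ids[:, -1].unsqueeze(-1)  # shape (batch_size, 1)
```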

@@ -258,6 +252,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
if attention_mask is None:
@@ -268,43 +263,42 @@ def forward(
"attention_mask": attention_mask,
}

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
if past_key_values is None:
nb_pkv = 2
num_layers = self.normalized_config.num_layers
num_attention_heads = self.normalized_config.num_attention_heads
num_key_value_heads = num_attention_heads
if hasattr(self.normalized_config, "num_key_value_heads"):
num_key_value_heads = self.normalized_config.num_key_value_heads
hidden_size = self.normalized_config.hidden_size
d_k = hidden_size // num_attention_heads
if self.config.model_type == "gpt_bigcode":
new_shape = [input_ids.shape[0], 0, d_k * 2]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple([empty_tensor] * num_layers)
elif self.config.model_type != "bloom":
new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = tuple(empty_tensor for _ in range(nb_pkv))
d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
batch_size = input_ids.shape[0]

if model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
pkv = ()
for nb_pkv in range(nb_pkv):
if nb_pkv % 2 == 0:
new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
else:
new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = pkv + (empty_tensor,)
if past_key_values is None:
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
num_attention_heads = self.normalized_config.num_attention_heads

if model_type == "bloom":
shape_key = (batch_size * num_attention_heads, d_k, 0)
shape_value = (batch_size * num_attention_heads, 0, d_k)
key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(
tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)
)
elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
shape = (batch_size, 0, d_k * 2)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(pkv for _ in range(num_layers))
else:
shape = (batch_size, num_attention_heads, 0, d_k)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

outputs = self.model(**inputs)

if isinstance(outputs, (list, tuple)):
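
A minimal sketch of the empty KV-cache layouts built above for the first forward pass (sequence length 0); the sizes are illustrative, the shapes follow the code:

```python
# Illustrative sizes only.
import torch

batch_size, num_attention_heads, d_k, num_layers, nb_pkv = 2, 4, 8, 2, 2
dtype = torch.float32

# bloom: keys and values are flattened over heads and use transposed layouts
key = torch.empty(batch_size * num_attention_heads, d_k, 0, dtype=dtype)
value = torch.empty(batch_size * num_attention_heads, 0, d_k, dtype=dtype)
bloom_pkv = tuple((key, value) for _ in range(num_layers))

# multi-query attention models (e.g. gpt_bigcode): one fused key/value tensor per layer
mqa_pkv = tuple(torch.empty(batch_size, 0, d_k * 2, dtype=dtype) for _ in range(num_layers))

# default multi-head layout
mha_pkv = tuple(
    tuple(torch.empty(batch_size, num_attention_heads, 0, d_k, dtype=dtype) for _ in range(nb_pkv))
    for _ in range(num_layers)
)
```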
@@ -389,7 +383,7 @@ def _from_transformers(
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
**kwargs,
):
if is_torch_version("<", "2.0.0"):
if is_torch_version("<", "2.1.0"):
raise ImportError("`torch>=2.0.0` is needed to trace your model")

task = cls.export_feature
@@ -405,12 +399,7 @@ def _from_transformers(
}

model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)

if model.config.model_type == "bloom":
model.transformer._prepare_attn_mask = _prepare_attn_mask

if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
model = patch_decoder_attention_mask(model)

traced_model = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
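A minimal sketch of the consolidated patching used in `_from_transformers` above: instead of patching bloom and llama attention-mask helpers by hand, a single helper from `optimum.intel.utils.modeling_utils` is applied. The checkpoint is illustrative.

```python
# Sketch only; the helper is imported at the top of this file in the diff above.
from transformers import AutoModelForCausalLM
from optimum.intel.utils.modeling_utils import patch_decoder_attention_mask

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
model = patch_decoder_attention_mask(model)  # applies the model-type-specific attention-mask patch when needed
```
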
15 changes: 8 additions & 7 deletions optimum/intel/openvino/modeling_base.py
@@ -43,17 +43,12 @@
logger = logging.getLogger(__name__)


# workaround to enable compatibility between openvino models and transformers pipelines
class PreTrainedModel(OptimizedModel):
pass


@add_start_docstrings(
"""
Base OVModel class.
""",
)
class OVBaseModel(PreTrainedModel):
class OVBaseModel(OptimizedModel):
auto_model_class = None
export_feature = None

@@ -86,6 +81,12 @@ def __init__(
input_names[next((name for name in names if "/" not in name), names[0])] = idx
self.input_names = input_names

output_names = {}
for idx, key in enumerate(model.outputs):
names = tuple(key.get_names())
output_names[next((name for name in names if "/" not in name), names[0])] = idx
self.output_names = output_names

self.model = model
self.request = None
if enable_compilation:
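
A small illustration of the input/output name resolution added to `__init__` above: for every model port, a tensor name without "/" is preferred and mapped to the port index. The IR path is an assumption.

```python
# Sketch only; assumes an OpenVINO IR on disk.
import openvino.runtime as ov

core = ov.Core()
model = core.read_model("openvino_model.xml")  # illustrative path

output_names = {}
for idx, key in enumerate(model.outputs):
    names = tuple(key.get_names())
    output_names[next((name for name in names if "/" not in name), names[0])] = idx
```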
@@ -302,7 +303,7 @@
@classmethod
def _to_load(
cls,
model: PreTrainedModel,
model,
config: PretrainedConfig,
onnx_config: OnnxConfig,
use_auth_token: Optional[Union[bool, str]] = None,
22 changes: 18 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
@@ -129,7 +129,6 @@ def __init__(
self.main_input_name = "input_ids"
self.num_pkv = 2
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.key_value_output_names = [key for key in self.output_names if "present" in key]
self._original_model = self.model.clone() # keep original model for serialization
@@ -313,6 +312,7 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
self.compile()
@@ -362,14 +362,28 @@

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)
attention_mask = np.array(attention_mask)
else:
inputs["attention_mask"] = np.ones(
attention_mask = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)

if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = np.array(position_ids)
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)

inputs["position_ids"] = position_ids

# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
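For reference, a minimal sketch of the start-async / wait inference pattern used in `forward` above; the IR path and the "logits" output name are assumptions.

```python
# Sketch only; paths and tensor names are illustrative.
import numpy as np
import openvino.runtime as ov

core = ov.Core()
compiled = core.compile_model("openvino_model.xml", "CPU")  # illustrative IR path
request = compiled.create_infer_request()

inputs = {
    "input_ids": np.array([[1, 2, 3]], dtype=np.int64),
    "attention_mask": np.ones((1, 3), dtype=np.int64),
}
request.start_async(inputs, shared_memory=True)
request.wait()
logits = request.get_tensor("logits").data  # assumes the model exposes a "logits" output
```
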
8 changes: 4 additions & 4 deletions optimum/intel/openvino/quantization.py
@@ -39,7 +39,6 @@

from ...exporters.openvino import export, export_pytorch_via_onnx
from ..utils.constant import _TASK_ALIASES
from ..utils.modeling_utils import patch_decoder_attention_mask
from .configuration import OVConfig
from .modeling_base import OVBaseModel
from .modeling_decoder import OVBaseDecoderModel
@@ -394,9 +393,10 @@ def _quantize_torchmodel(
task = self.task
model = self.model
self.model.config.save_pretrained(save_directory)
model = patch_decoder_attention_mask(model)
if task == "text-generation":
onnx_config = onnx_config_class(model.config, use_past=model.config.use_cache)
if task.startswith("text-generation"):
onnx_config = onnx_config_class(
model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
)
else:
onnx_config = onnx_config_class(model.config)
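
A minimal sketch of the quantization path touched above: `OVQuantizer` drives `_quantize_torchmodel` when given a PyTorch model. The checkpoint, dataset, and preprocessing are illustrative.

```python
# Sketch only; names and dataset are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import OVQuantizer

model_id = "gpt2"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantizer = OVQuantizer.from_pretrained(model, task="text-generation")
calibration_dataset = quantizer.get_calibration_dataset(
    "wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    preprocess_function=lambda examples: tokenizer(examples["text"]),
    num_samples=50,
)
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_quantized")
```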
