Fix compatibility for optimum next release #460

Merged (35 commits, Nov 6, 2023)

Commits
5ee1628
Enable openvino inference for gpt big code models
echarlaix Oct 18, 2023
e904f81
merge main in branch
echarlaix Oct 23, 2023
fec7655
fix
echarlaix Oct 23, 2023
327533e
format
echarlaix Oct 23, 2023
36d482e
fix
echarlaix Oct 23, 2023
f8ba216
fix
echarlaix Oct 23, 2023
f52960b
fix input names
echarlaix Oct 23, 2023
764b23b
Fix export optimum modifications
echarlaix Oct 23, 2023
2016d2a
merge main in branch
echarlaix Oct 23, 2023
b111ca9
add test
echarlaix Oct 31, 2023
d60fe98
Fix compatibility
echarlaix Oct 31, 2023
e1ca1d6
style
echarlaix Oct 31, 2023
1873601
fix compatibility
echarlaix Oct 31, 2023
7a90d64
remove bigcode
echarlaix Oct 31, 2023
0e17883
fix
echarlaix Oct 31, 2023
79b75b9
style
echarlaix Oct 31, 2023
dd1cbff
Merge branch 'main' into release-optimum-fix
echarlaix Oct 31, 2023
8983274
fix test
echarlaix Nov 2, 2023
d6cdc10
fixes
echarlaix Nov 2, 2023
667809e
trigger test
echarlaix Nov 3, 2023
deff847
fix trainer
echarlaix Nov 3, 2023
0ad7dc2
fix trainer
echarlaix Nov 3, 2023
a328aa4
fix trainer
echarlaix Nov 3, 2023
bb8925e
fix test
echarlaix Nov 3, 2023
c21f736
fix trainer
echarlaix Nov 3, 2023
1343a34
fix conflicts
echarlaix Nov 6, 2023
339605f
fix conflicts
echarlaix Nov 6, 2023
f188c0f
format
echarlaix Nov 6, 2023
3245732
format
echarlaix Nov 6, 2023
398aaa9
fix transformers version
echarlaix Nov 6, 2023
892374d
fix version
echarlaix Nov 6, 2023
5311077
version
echarlaix Nov 6, 2023
ae99a1b
fix
echarlaix Nov 6, 2023
99a0615
fix
echarlaix Nov 6, 2023
65fa4b2
fix
echarlaix Nov 6, 2023
29 changes: 11 additions & 18 deletions optimum/exporters/openvino/__main__.py
@@ -27,7 +27,6 @@
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models


@@ -257,24 +256,18 @@ class StoreAttr(object):
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
if not task.startswith("text-generation"):
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
)
else:
# TODO : ModelPatcher will be added in next optimum release
model = patch_decoder_attention_mask(model)

onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant="default",
legacy=False,
)

if int8 is None:
int8 = False
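With this change, text-generation models no longer take a separate export path: every task now goes through `optimum_main._get_submodels_and_onnx_configs` with `legacy=False`. A minimal usage sketch of the export entry point defined in this file follows; the function name `main_export`, its import path, and its parameters are assumed from optimum-intel and are not confirmed by the excerpt above, and the checkpoint is a placeholder.

```python
# Hedged sketch: exporting a causal LM to OpenVINO IR through the entry point
# in this file. Names and parameters below are assumptions, not taken from the diff.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="gpt2",          # placeholder checkpoint
    output="ov_gpt2",                   # directory that will receive the IR files
    task="text-generation-with-past",   # decoder exported with KV-cache inputs
)
```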
131 changes: 60 additions & 71 deletions optimum/intel/generation/modeling.py
@@ -26,12 +26,13 @@
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


if is_transformers_version("<", "4.25.0"):
@@ -47,55 +48,37 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_class(model.config)
if task == "text-generation" and use_cache:
onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
if "text-generation" in task:
onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
else:
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
# WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
pkv = []
for i in range(len(model_inputs["past_key_values"])):
pkv.append([])
for j in range(len(model_inputs["past_key_values"][0])):
pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype))
pkv[i] = tuple(pkv[i])
model_inputs["past_key_values"] = tuple(pkv)
i = model_inputs["input_ids"]
a = model_inputs["attention_mask"]
model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype)
model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype)
return model_inputs

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
# check if the model_inputs is correct.
model(**model_inputs)

torch._C._jit_set_texpr_fuser_enabled(False)
if "past_key_values" in model_inputs.keys():
model.config.return_dict = False
if is_torch_version(">", "2.0.1"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
if is_torch_version(">=", "2.1.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
if is_torch_version(">=", "2.0.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)

traced_model = torch.jit.freeze(traced_model.eval())
traced_model(**model_inputs)
traced_model(**model_inputs)

return traced_model


class PreTrainedModel(OptimizedModel):
pass


class BaseModelForCausalLM(PreTrainedModel, GenerationMixin):
class BaseModelForCausalLM(OptimizedModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
export_feature = "text-generation"
main_input_name = "input_ids"
@@ -156,12 +139,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

position_ids = kwargs.get("position_ids", None)

attention_mask = kwargs.get("attention_mask", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": self.use_cache,
"position_ids": None,
"attention_mask": kwargs.get("attention_mask", None),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": None,
}
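The position_ids computed above follow the usual batch-generation recipe: cumulative sum of the attention mask minus one, with padded positions set to a dummy value, and only the last position kept once a cache exists. A small self-contained sketch of that recipe, with illustrative tensors only:

```python
import torch

# Left-padded batch: the first sequence has two padding tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# When past_key_values are present, only the last position is fed to the model:
last_position = position_ids[:, -1].unsqueeze(-1)  # shape (batch, 1)
```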

@@ -258,6 +252,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
if attention_mask is None:
@@ -268,43 +263,42 @@
"attention_mask": attention_mask,
}

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
if past_key_values is None:
nb_pkv = 2
num_layers = self.normalized_config.num_layers
num_attention_heads = self.normalized_config.num_attention_heads
num_key_value_heads = num_attention_heads
if hasattr(self.normalized_config, "num_key_value_heads"):
num_key_value_heads = self.normalized_config.num_key_value_heads
hidden_size = self.normalized_config.hidden_size
d_k = hidden_size // num_attention_heads
if self.config.model_type == "gpt_bigcode":
new_shape = [input_ids.shape[0], 0, d_k * 2]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple([empty_tensor] * num_layers)
elif self.config.model_type != "bloom":
new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = tuple(empty_tensor for _ in range(nb_pkv))
d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
batch_size = input_ids.shape[0]

if model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
pkv = ()
for nb_pkv in range(nb_pkv):
if nb_pkv % 2 == 0:
new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
else:
new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = pkv + (empty_tensor,)
if past_key_values is None:
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
num_attention_heads = self.normalized_config.num_attention_heads

if model_type == "bloom":
shape_key = (batch_size * num_attention_heads, d_k, 0)
shape_value = (batch_size * num_attention_heads, 0, d_k)
key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(
tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)
)
elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
shape = (batch_size, 0, d_k * 2)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(pkv for _ in range(num_layers))
else:
shape = (batch_size, num_attention_heads, 0, d_k)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

outputs = self.model(**inputs)

if isinstance(outputs, (list, tuple)):
@@ -389,7 +383,7 @@ def _from_transformers(
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
**kwargs,
):
if is_torch_version("<", "2.0.0"):
if is_torch_version("<", "2.1.0"):
raise ImportError("`torch>=2.0.0` is needed to trace your model")

task = cls.export_feature
@@ -405,12 +399,7 @@
}

model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)

if model.config.model_type == "bloom":
model.transformer._prepare_attn_mask = _prepare_attn_mask

if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
model = patch_decoder_attention_mask(model)

traced_model = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
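The rewritten forward above pre-allocates empty KV-cache tensors whose layout depends on the architecture: Bloom uses flattened (batch * heads) keys and values, multi-query models use a single fused tensor per layer, and everything else uses the standard (batch, heads, 0, head_dim) pair. A small sketch of the standard case, using made-up dimensions:

```python
import torch

# Illustrative dimensions only; real values come from the normalized config.
batch_size, num_layers, num_attention_heads, d_k, nb_pkv = 1, 2, 4, 16, 2

# Empty key/value placeholders with a zero-length sequence dimension.
empty = torch.empty(batch_size, num_attention_heads, 0, d_k)
past_key_values = tuple(tuple(empty for _ in range(nb_pkv)) for _ in range(num_layers))

assert len(past_key_values) == num_layers
assert past_key_values[0][0].shape == (1, 4, 0, 16)
```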
15 changes: 8 additions & 7 deletions optimum/intel/openvino/modeling_base.py
@@ -43,17 +43,12 @@
logger = logging.getLogger(__name__)


# workaround to enable compatibility between openvino models and transformers pipelines
class PreTrainedModel(OptimizedModel):
pass


@add_start_docstrings(
"""
Base OVModel class.
""",
)
class OVBaseModel(PreTrainedModel):
class OVBaseModel(OptimizedModel):
auto_model_class = None
export_feature = None

@@ -86,6 +81,12 @@ def __init__(
input_names[next((name for name in names if "/" not in name), names[0])] = idx
self.input_names = input_names

output_names = {}
for idx, key in enumerate(model.outputs):
names = tuple(key.get_names())
output_names[next((name for name in names if "/" not in name), names[0])] = idx
self.output_names = output_names

self.model = model
self.request = None
if enable_compilation:
@@ -302,7 +303,7 @@
@classmethod
def _to_load(
cls,
model: PreTrainedModel,
model,
config: PretrainedConfig,
onnx_config: OnnxConfig,
use_auth_token: Optional[Union[bool, str]] = None,
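The base class now records output names next to input names, using the same "prefer a tensor name without a slash" heuristic. A hedged standalone sketch of that lookup against an OpenVINO model read from disk; the openvino.runtime calls are assumed from the 2023.x Python API and the IR path is a placeholder:

```python
import openvino.runtime as ov

core = ov.Core()
ov_model = core.read_model("model.xml")  # placeholder path to an OpenVINO IR

output_names = {}
for idx, port in enumerate(ov_model.outputs):
    names = tuple(port.get_names())
    # Prefer a friendly tensor name without "/" when several names are attached.
    output_names[next((name for name in names if "/" not in name), names[0])] = idx
```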
22 changes: 18 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
@@ -129,7 +129,6 @@ def __init__(
self.main_input_name = "input_ids"
self.num_pkv = 2
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.key_value_output_names = [key for key in self.output_names if "present" in key]
self._original_model = self.model.clone() # keep original model for serialization
@@ -313,6 +312,7 @@ def forward(
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
self.compile()
@@ -362,14 +362,28 @@

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)
attention_mask = np.array(attention_mask)
else:
inputs["attention_mask"] = np.ones(
attention_mask = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)

if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = np.array(position_ids)
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)

inputs["position_ids"] = position_ids

# Run inference
self.request.start_async(inputs, shared_memory=True)
self.request.wait()
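End to end, models whose OpenVINO graph exposes a position_ids input now have it filled in automatically from the attention mask during generation. A hedged usage sketch with a placeholder checkpoint and the standard optimum-intel loading API:

```python
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```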
8 changes: 4 additions & 4 deletions optimum/intel/openvino/quantization.py
@@ -39,7 +39,6 @@

from ...exporters.openvino import export, export_pytorch_via_onnx
from ..utils.constant import _TASK_ALIASES
from ..utils.modeling_utils import patch_decoder_attention_mask
from .configuration import OVConfig
from .modeling_base import OVBaseModel
from .modeling_decoder import OVBaseDecoderModel
@@ -394,9 +393,10 @@ def _quantize_torchmodel(
task = self.task
model = self.model
self.model.config.save_pretrained(save_directory)
model = patch_decoder_attention_mask(model)
if task == "text-generation":
onnx_config = onnx_config_class(model.config, use_past=model.config.use_cache)
if task.startswith("text-generation"):
onnx_config = onnx_config_class(
model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
)
else:
onnx_config = onnx_config_class(model.config)

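For text-generation tasks, the torch-model quantization path now builds the ONNX config with the cache both enabled and exposed as an input. A minimal sketch of that construction, following the same TasksManager calls used elsewhere in this PR, with a placeholder checkpoint:

```python
from transformers import AutoModelForCausalLM
from optimum.exporters import TasksManager

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
constructor = TasksManager.get_exporter_config_constructor(
    model=model, exporter="onnx", task="text-generation"
)
# use_past_in_inputs=True makes past_key_values part of the dummy inputs,
# so the traced graph keeps its KV-cache inputs.
onnx_config = constructor(model.config, use_past=True, use_past_in_inputs=True)
dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
print(sorted(dummy_inputs))  # expected to include "past_key_values"
```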