
Commit

Fix compatibility
echarlaix committed Oct 31, 2023
1 parent b111ca9 commit d60fe98
Showing 6 changed files with 258 additions and 80 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test_openvino.yml
@@ -32,7 +32,6 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install git+https://github.com/huggingface/optimum.git
pip install .[openvino,nncf,tests,diffusers]
- name: Test with Pytest
run: |
137 changes: 66 additions & 71 deletions optimum/intel/generation/modeling.py
@@ -31,7 +31,11 @@

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask
from ..utils.modeling_utils import patch_decoder_attention_mask

from ..utils.modeling_utils import patch_decoder_attention_mask, MULTI_QUERY_ATTN_MODELS

from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS


if is_transformers_version("<", "4.25.0"):
@@ -47,55 +51,37 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_class(model.config)
if task == "text-generation" and use_cache:
onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True)
if "text-generation" in task:
onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
else:
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode":
# WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect
pkv = []
for i in range(len(model_inputs["past_key_values"])):
pkv.append([])
for j in range(len(model_inputs["past_key_values"][0])):
pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype))
pkv[i] = tuple(pkv[i])
model_inputs["past_key_values"] = tuple(pkv)
i = model_inputs["input_ids"]
a = model_inputs["attention_mask"]
model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype)
model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype)
return model_inputs

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
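
The rewritten prepare_jit_inputs builds the ONNX export config with use_past tied to use_cache for any text-generation task and keeps only the dummy inputs that the model's forward signature accepts. A minimal sketch of that flow, using a placeholder "gpt2" checkpoint that is not part of this commit:

import inspect

from transformers import AutoModelForCausalLM
from optimum.exporters.tasks import TasksManager

model = AutoModelForCausalLM.from_pretrained("gpt2")
task = "text-generation"

# build the ONNX config with the KV cache enabled, mirroring the use_cache=True path
constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = constructor(model.config, use_past=True, use_past_in_inputs=True)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
signature = inspect.signature(model.forward)
# keep only the dummy inputs that forward() actually accepts
model_inputs = {k: v for k, v in dummy_inputs.items() if k in signature.parameters and v is not None}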


def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
# check if the model_inputs is correct.
model(**model_inputs)

torch._C._jit_set_texpr_fuser_enabled(False)
if "past_key_values" in model_inputs.keys():
model.config.return_dict = False
if is_torch_version(">", "2.0.1"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
if is_torch_version(">=", "2.1.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
if is_torch_version(">=", "2.0.0"):
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
else:
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)
traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False)

traced_model = torch.jit.freeze(traced_model.eval())
traced_model(**model_inputs)
traced_model(**model_inputs)

return traced_model
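
A rough usage sketch of the updated tracing path, assuming both helpers stay importable from optimum.intel.generation.modeling and using a placeholder "gpt2" checkpoint; on torch >= 2.1.0 the keyword-based example_kwarg_inputs branch is taken:

import torch
from transformers import AutoModelForCausalLM
from optimum.intel.generation.modeling import jit_trace, prepare_jit_inputs

model = AutoModelForCausalLM.from_pretrained("gpt2")
model_inputs = prepare_jit_inputs(model, task="text-generation", use_cache=True)
traced_model = jit_trace(model, task="text-generation", use_cache=True)
with torch.no_grad():
    outputs = traced_model(**model_inputs)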


class PreTrainedModel(OptimizedModel):
pass


class BaseModelForCausalLM(PreTrainedModel, GenerationMixin):
class BaseModelForCausalLM(OptimizedModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
export_feature = "text-generation"
main_input_name = "input_ids"
@@ -156,12 +142,28 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)


position_ids = kwargs.get("position_ids", None)


attention_mask = kwargs.get("attention_mask", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)




return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": self.use_cache,
"position_ids": None,
"attention_mask": kwargs.get("attention_mask", None),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": None,
}
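
The position_ids added to prepare_inputs_for_generation are derived from the attention mask so that left-padded batches generate correctly; a small self-contained example of the cumsum / masked_fill_ pattern used above:

import torch

# toy batch with left padding: 0 marks padded positions
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# once past_key_values is populated, only the position of the newest token is kept
next_position_ids = position_ids[:, -1].unsqueeze(-1)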

@@ -258,6 +260,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
if attention_mask is None:
@@ -268,43 +271,40 @@
"attention_mask": attention_mask,
}

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
if past_key_values is None:
nb_pkv = 2
num_layers = self.normalized_config.num_layers
num_attention_heads = self.normalized_config.num_attention_heads
num_key_value_heads = num_attention_heads
if hasattr(self.normalized_config, "num_key_value_heads"):
num_key_value_heads = self.normalized_config.num_key_value_heads
hidden_size = self.normalized_config.hidden_size
d_k = hidden_size // num_attention_heads
if self.config.model_type == "gpt_bigcode":
new_shape = [input_ids.shape[0], 0, d_k * 2]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
past_key_values = tuple([empty_tensor] * num_layers)
elif self.config.model_type != "bloom":
new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = tuple(empty_tensor for _ in range(nb_pkv))
d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
batch_size = input_ids.shape[0]

if model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads

if model_type == "bloom":
shape_key = (batch_size * num_attention_heads, d_k, 0)
shape_value = (batch_size * num_attention_heads, 0, d_k)
key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers))
elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
shape = (batch_size, 0, d_k * 2)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(pkv for _ in range(num_layers))
else:
pkv = ()
for nb_pkv in range(nb_pkv):
if nb_pkv % 2 == 0:
new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
else:
new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
empty_tensor = torch.empty(size=new_shape)
if self.model_dtype is not None:
empty_tensor = empty_tensor.to(self.model_dtype)
pkv = pkv + (empty_tensor,)
if past_key_values is None:
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))
shape = (batch_size, num_attention_heads, 0, d_k)
pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device)
past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

outputs = self.model(**inputs)

if isinstance(outputs, (list, tuple)):
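
The empty-cache layout now depends on the architecture: bloom flattens batch and head dimensions and transposes the key, multi-query attention models get one fused key/value tensor per layer, and llama/mistral size the cache by num_key_value_heads. A shape-only sketch with placeholder dimensions (not taken from any real config):

import torch

batch_size, num_layers, d_k = 2, 24, 64
num_attention_heads, num_key_value_heads = 16, 4

# bloom: keys are (batch * heads, d_k, 0), values are (batch * heads, 0, d_k)
key = torch.empty(batch_size * num_attention_heads, d_k, 0)
value = torch.empty(batch_size * num_attention_heads, 0, d_k)
bloom_cache = tuple((key, value) for _ in range(num_layers))

# multi-query attention models (e.g. gpt_bigcode): one fused (batch, 0, 2 * d_k) tensor per layer
mqa_cache = tuple(torch.empty(batch_size, 0, d_k * 2) for _ in range(num_layers))

# default layout; llama/mistral use num_key_value_heads here instead of num_attention_heads
kv = torch.empty(batch_size, num_key_value_heads, 0, d_k)
default_cache = tuple((kv, kv) for _ in range(num_layers))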
@@ -389,7 +389,7 @@ def _from_transformers(
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
**kwargs,
):
if is_torch_version("<", "2.0.0"):
if is_torch_version("<", "2.1.0"):
raise ImportError("`torch>=2.0.0` is needed to trace your model")

task = cls.export_feature
@@ -405,12 +405,7 @@ def _from_transformers(
}

model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)

if model.config.model_type == "bloom":
model.transformer._prepare_attn_mask = _prepare_attn_mask

if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
model = patch_decoder_attention_mask(model)

traced_model = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
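
The bloom- and llama-specific attention-mask patches are replaced by a single patch_decoder_attention_mask call; a hedged sketch of the equivalent standalone usage, with a placeholder checkpoint name:

from transformers import AutoModelForCausalLM
from optimum.intel.utils.modeling_utils import patch_decoder_attention_mask

model = AutoModelForCausalLM.from_pretrained("gpt2")
# patches the model's attention-mask preparation in place for architectures that need it,
# and leaves other models untouched
model = patch_decoder_attention_mask(model)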
55 changes: 55 additions & 0 deletions optimum/intel/neural_compressor/quantization.py
@@ -30,9 +30,18 @@
from neural_compressor.quantization import fit
from torch.utils.data import DataLoader, RandomSampler
from transformers import (
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
DataCollator,
PretrainedConfig,
PreTrainedModel,
XLNetLMHeadModel,
default_data_collator,
)

@@ -528,3 +537,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> torch.nn.Module:
q_model = convert(q_model, mapping=q_mapping, inplace=True)

return q_model


class IncQuantizedModel(INCModel):
@classmethod
def from_pretrained(cls, *args, **kwargs):
warnings.warn(
f"The class `{cls.__name__}` has been depreciated and will be removed in optimum-intel v1.12, please use "
f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead."
)
return super().from_pretrained(*args, **kwargs)


class IncQuantizedModelForQuestionAnswering(IncQuantizedModel):
auto_model_class = AutoModelForQuestionAnswering


class IncQuantizedModelForSequenceClassification(IncQuantizedModel):
auto_model_class = AutoModelForSequenceClassification


class IncQuantizedModelForTokenClassification(IncQuantizedModel):
auto_model_class = AutoModelForTokenClassification


class IncQuantizedModelForMultipleChoice(IncQuantizedModel):
auto_model_class = AutoModelForMultipleChoice


class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel):
auto_model_class = AutoModelForSeq2SeqLM


class IncQuantizedModelForCausalLM(IncQuantizedModel):
auto_model_class = AutoModelForCausalLM


class IncQuantizedModelForMaskedLM(IncQuantizedModel):
auto_model_class = AutoModelForMaskedLM


class IncQuantizedModelForXLNetLM(IncQuantizedModel):
auto_model_class = XLNetLMHeadModel


class IncQuantizedModelForVision2Seq(IncQuantizedModel):
auto_model_class = AutoModelForVision2Seq
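
The re-added IncQuantizedModel* classes are thin deprecation shims over their INCModel equivalents; a hypothetical call illustrating the behaviour, with a placeholder checkpoint name:

from optimum.intel.neural_compressor.quantization import IncQuantizedModelForSequenceClassification

# warns that the class is deprecated in favour of INCModelForSequenceClassification,
# then delegates to INCModel.from_pretrained
model = IncQuantizedModelForSequenceClassification.from_pretrained("my-quantized-checkpoint")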
9 changes: 2 additions & 7 deletions optimum/intel/openvino/modeling_base.py
@@ -43,17 +43,12 @@
logger = logging.getLogger(__name__)


# workaround to enable compatibility between openvino models and transformers pipelines
class PreTrainedModel(OptimizedModel):
pass


@add_start_docstrings(
"""
Base OVModel class.
""",
)
class OVBaseModel(PreTrainedModel):
class OVBaseModel(OptimizedModel):
auto_model_class = None
export_feature = None

@@ -302,7 +297,7 @@ def _from_transformers(
@classmethod
def _to_load(
cls,
model: PreTrainedModel,
model: "PreTrainedModel",
config: PretrainedConfig,
onnx_config: OnnxConfig,
use_auth_token: Optional[Union[bool, str]] = None,