Skip to content

Commit

Permalink
Fix causal lm export (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Nov 8, 2023
1 parent b42698d commit f248835
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
18 changes: 16 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
from optimum.utils import is_diffusers_available

from ...intel.utils.import_utils import is_nncf_available
Expand Down Expand Up @@ -297,14 +298,21 @@ def export_pytorch(
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
input_info = get_input_shapes(dummy_inputs, inputs)
custom_patcher = type(config).patch_model_for_export != OnnxConfig.patch_model_for_export
patch_model_forward = False
orig_forward = model.forward
try:
# TorchScript is used behind the OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model input/output.
# To handle this, an additional wrapper is applied on top of the patcher forward.
# model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
if custom_patcher or dict_inputs:
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
patched_forward = patcher.patched_forward
# DecoderModelPatcher does not override model forward
if isinstance(patcher, DecoderModelPatcher) or patcher.orig_forward_name != "forward":
patch_model_forward = True
patched_forward = model.forward
else:
patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
Expand All @@ -317,14 +325,20 @@ def ts_patched_forward(*args, **kwargs):
outputs = patched_forward(*args, **kwargs)
return tuple(outputs.values())

patcher.patched_forward = ts_patched_forward
if not patch_model_forward:
patcher.patched_forward = ts_patched_forward
else:
model.forward = ts_patched_forward
with patcher:
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
else:
model.config.torchscript = True
model.config.retun_dict = False
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
if patch_model_forward:
model.forward = orig_forward
return export_pytorch_via_onnx(
model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8
)
Expand Down
2 changes: 0 additions & 2 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,4 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
expected = expected_int8[i]
if task == "text-generation":
expected -= 1
self.assertEqual(expected, num_int8)
2 changes: 2 additions & 0 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def preprocess_function(examples, tokenizer):
def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
task = model_cls.export_feature
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
if "gpt2" in model_name:
expected_int8 -= 1

def preprocess_function(examples, tokenizer):
return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
"albert": (42,),
"vit": (31,),
"blenderbot": (35,),
"gpt2": (23,),
"gpt2": (22,),
"wav2vec2": (15,),
"distilbert": (33,),
"t5": (32, 52, 42),
Expand Down

0 comments on commit f248835

Please sign in to comment.