Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 27, 2023
1 parent f7feba5 commit 319eac3
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
5 changes: 2 additions & 3 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_save_preprocessors

from ...intel.utils.modeling_utils import patch_decoder_attention_mask
from .convert import export_models

Expand Down Expand Up @@ -214,7 +215,6 @@ def main_export(
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")


if not task.startswith("text-generation"):
onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
model=model,
Expand All @@ -229,9 +229,8 @@ def main_export(
# TODO : ModelPatcher will be added in next optimum release
model = patch_decoder_attention_mask(model)

use_cache = task.endswith("-with-past")
onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
onnx_config = onnx_config_constructor(model.config) # TODO : optimum next release need use_past_in_inputs=use_cache
onnx_config = onnx_config_constructor(model.config)
models_and_onnx_configs = {"model": (model, onnx_config)}

if not is_stable_diffusion:
Expand Down
1 change: 0 additions & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from transformers.file_utils import add_start_docstrings

from optimum.exporters.onnx import OnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.modeling_base import OptimizedModel

from ...exporters.openvino import export, main_export
Expand Down
8 changes: 2 additions & 6 deletions optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
from transformers import PretrainedConfig
from transformers.file_utils import add_start_docstrings

from optimum.exporters import TasksManager
from optimum.exporters.onnx import get_encoder_decoder_models_for_export

from ...exporters.openvino import export_models, main_export
from ...exporters.openvino import main_export
from ..utils.import_utils import is_transformers_version
from .modeling_base import OVBaseModel
from .utils import (
Expand Down Expand Up @@ -267,8 +264,7 @@ def _from_transformers(
)

config.save_pretrained(save_dir_path)
return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)

return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)

def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
shapes = {}
Expand Down
6 changes: 2 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.exporters import TasksManager
from optimum.utils import NormalizedConfigManager

from ...exporters.openvino import export, main_export
from ...exporters.openvino import main_export
from ..utils.import_utils import is_transformers_version
from ..utils.modeling_utils import patch_decoder_attention_mask
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE

Expand Down Expand Up @@ -244,7 +242,7 @@ def _from_transformers(
config.is_decoder = True
config.is_encoder_decoder = False
config.save_pretrained(save_dir_path)
return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)
return cls._from_pretrained(model_id=save_dir_path, config=config, use_cache=use_cache, **kwargs)

def _reshape(
self,
Expand Down

0 comments on commit 319eac3

Please sign in to comment.