Skip to content

Commit

Permalink
fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 4, 2024
1 parent c26a450 commit ce7789f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
8 changes: 4 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3386,8 +3386,8 @@ class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: OnnxConfig,
model: PreTrainedModel | TFPreTrainedModel,
model_kwargs: Dict[str, Any] | None = None,
model: Union[PreTrainedModel, TFPreTrainedModel],
model_kwargs: Dict[str, Any] = None,
):

model.__orig_forward = model.forward
Expand Down Expand Up @@ -3426,8 +3426,8 @@ class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
def __init__(
self,
config: OnnxConfig,
model: PreTrainedModel | TFPreTrainedModel,
model_kwargs: Dict[str, Any] | None = None,
model: Union[PreTrainedModel, TFPreTrainedModel],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward

Expand Down
10 changes: 9 additions & 1 deletion optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,15 @@ def get_submodels(model):
return custom_export, fn_get_submodels


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v", "qwen2-vl"]
MULTI_MODAL_TEXT_GENERATION_MODELS = [
"llava",
"llava-next",
"llava-qwen2",
"internvl-chat",
"minicpmv",
"phi3-v",
"qwen2-vl",
]


def save_config(config, save_dir):
Expand Down
12 changes: 6 additions & 6 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def __init__(
for idx, key in enumerate(model.inputs):
names = tuple(key.get_names())
input_names[next((name for name in names if "/" not in name), names[0])] = idx
input_dtypes[
next((name for name in names if "/" not in name), names[0])
] = key.get_element_type().get_type_name()
input_dtypes[next((name for name in names if "/" not in name), names[0])] = (
key.get_element_type().get_type_name()
)
self.input_names = input_names
self.input_dtypes = input_dtypes

Expand All @@ -122,9 +122,9 @@ def __init__(
for idx, key in enumerate(model.outputs):
names = tuple(key.get_names())
output_names[next((name for name in names if "/" not in name), names[0])] = idx
output_dtypes[
next((name for name in names if "/" not in name), names[0])
] = key.get_element_type().get_type_name()
output_dtypes[next((name for name in names if "/" not in name), names[0])] = (
key.get_element_type().get_type_name()
)

self.output_names = output_names
self.output_dtypes = output_dtypes
Expand Down

0 comments on commit ce7789f

Please sign in to comment.