Commit: Use contextmanager
mvafin committed Dec 11, 2024
1 parent 19328ee commit e50f5f4
Showing 2 changed files with 34 additions and 21 deletions.
optimum/exporters/openvino/__main__.py (41 changes: 20 additions & 21 deletions)
@@ -16,7 +16,7 @@
 import logging
 import operator
 import warnings
-from functools import reduce, partial
+from functools import reduce
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

@@ -46,6 +46,7 @@
     MULTI_MODAL_TEXT_GENERATION_MODELS,
     clear_class_registry,
     deduce_diffusers_dtype,
+    patch_not_check_trace,
 )


@@ -234,7 +235,7 @@ def main_export(
     do_gptq_patching = False
     custom_architecture = False
     patch_16bit = False
-    orig_trace = None
+    patch_trace = False
     loading_kwargs = model_loading_kwargs or {}
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
@@ -351,8 +352,7 @@ class StoreAttr(object):
         if dtype in [torch.float16, torch.bfloat16]:
             loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True
-            orig_trace = torch.jit.trace
-            torch.jit.trace = partial(torch.jit.trace, check_trace=False)
+            patch_trace = True

     if library_name == "open_clip":
         model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
@@ -420,21 +420,22 @@ class StoreAttr(object):
         model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
     )

-    submodel_paths = export_from_model(
-        model=model,
-        output=output,
-        task=task,
-        ov_config=ov_config,
-        stateful=stateful,
-        model_kwargs=model_kwargs,
-        custom_export_configs=custom_export_configs,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        device=device,
-        trust_remote_code=trust_remote_code,
-        patch_16bit_model=patch_16bit,
-        **kwargs_shapes,
-    )
+    with patch_not_check_trace(patch_trace):
+        submodel_paths = export_from_model(
+            model=model,
+            output=output,
+            task=task,
+            ov_config=ov_config,
+            stateful=stateful,
+            model_kwargs=model_kwargs,
+            custom_export_configs=custom_export_configs,
+            fn_get_submodels=fn_get_submodels,
+            preprocessors=preprocessors,
+            device=device,
+            trust_remote_code=trust_remote_code,
+            patch_16bit_model=patch_16bit,
+            **kwargs_shapes,
+        )

     if convert_tokenizer:
         maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
@@ -494,8 +495,6 @@ class StoreAttr(object):
     if do_gptq_patching:
         torch.cuda.is_available = orig_cuda_check
         GPTQQuantizer.post_init_model = orig_post_init_model
-    if orig_trace is not None:
-        torch.jit.trace = orig_trace


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
optimum/exporters/openvino/utils.py (14 changes: 14 additions & 0 deletions)
@@ -17,6 +17,8 @@
 from collections import namedtuple
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from contextlib import contextmanager
+from functools import partial

 from transformers import PretrainedConfig
 from transformers.utils import is_torch_available
@@ -288,3 +290,15 @@ def save_preprocessors(
                 logger.error(f"Saving {type(processor)} failed with {ex}")
     else:
         maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
+
+
+@contextmanager
+def patch_not_check_trace(to_patch):
+    original_trace = torch.jit.trace
+    if to_patch:
+        torch.jit.trace = partial(torch.jit.trace, check_trace=False)
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.jit.trace = original_trace
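
For reference, a minimal usage sketch of the new context manager. The Doubler module and example input are illustrative and not part of the commit; only patch_not_check_trace comes from this change.

import torch

from optimum.exporters.openvino.utils import patch_not_check_trace


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2


model = Doubler().eval()
example_input = torch.randn(1, 3)

# Inside the block, torch.jit.trace is wrapped with check_trace=False, so the
# post-trace output-consistency check is skipped (the behavior main_export
# enables for float16/bfloat16 models via patch_trace).
with patch_not_check_trace(True):
    traced = torch.jit.trace(model, example_input)

# On exit, the finally clause restores the original torch.jit.trace even if
# tracing raised; passing False makes the context manager a no-op.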
