From e50f5f432b3f7089ad55c502dd978a46f910bf09 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 11 Dec 2024 18:25:23 +0100
Subject: [PATCH] Use contextmanager

---
 optimum/exporters/openvino/__main__.py | 41 +++++++++++++-------------
 optimum/exporters/openvino/utils.py    | 14 +++++++++
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 62142fd22..65648b275 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -16,7 +16,7 @@
 import logging
 import operator
 import warnings
-from functools import reduce, partial
+from functools import reduce
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
@@ -46,6 +46,7 @@
     MULTI_MODAL_TEXT_GENERATION_MODELS,
     clear_class_registry,
     deduce_diffusers_dtype,
+    patch_not_check_trace,
 )
 
 
@@ -234,7 +235,7 @@ def main_export(
     do_gptq_patching = False
     custom_architecture = False
     patch_16bit = False
-    orig_trace = None
+    patch_trace = False
     loading_kwargs = model_loading_kwargs or {}
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
@@ -351,8 +352,7 @@ class StoreAttr(object):
             if dtype in [torch.float16, torch.bfloat16]:
                 loading_kwargs["torch_dtype"] = dtype
                 patch_16bit = True
-                orig_trace = torch.jit.trace
-                torch.jit.trace = partial(torch.jit.trace, check_trace=False)
+                patch_trace = True
 
     if library_name == "open_clip":
         model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
@@ -420,21 +420,22 @@ class StoreAttr(object):
             model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
         )
 
-    submodel_paths = export_from_model(
-        model=model,
-        output=output,
-        task=task,
-        ov_config=ov_config,
-        stateful=stateful,
-        model_kwargs=model_kwargs,
-        custom_export_configs=custom_export_configs,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        device=device,
-        trust_remote_code=trust_remote_code,
-        patch_16bit_model=patch_16bit,
-        **kwargs_shapes,
-    )
+    with patch_not_check_trace(patch_trace):
+        submodel_paths = export_from_model(
+            model=model,
+            output=output,
+            task=task,
+            ov_config=ov_config,
+            stateful=stateful,
+            model_kwargs=model_kwargs,
+            custom_export_configs=custom_export_configs,
+            fn_get_submodels=fn_get_submodels,
+            preprocessors=preprocessors,
+            device=device,
+            trust_remote_code=trust_remote_code,
+            patch_16bit_model=patch_16bit,
+            **kwargs_shapes,
+        )
 
     if convert_tokenizer:
         maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
@@ -494,8 +495,6 @@ class StoreAttr(object):
     if do_gptq_patching:
         torch.cuda.is_available = orig_cuda_check
         GPTQQuantizer.post_init_model = orig_post_init_model
-    if orig_trace is not None:
-        torch.jit.trace = orig_trace
 
 
 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index db4df6d0d..68629bc10 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -17,6 +17,8 @@
 from collections import namedtuple
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from contextlib import contextmanager
+from functools import partial
 
 from transformers import PretrainedConfig
 from transformers.utils import is_torch_available
@@ -288,3 +290,15 @@ def save_preprocessors(
                 logger.error(f"Saving {type(processor)} failed with {ex}")
     else:
         maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
+
+
+@contextmanager
+def patch_not_check_trace(to_patch):
+    original_trace = torch.jit.trace
+    if to_patch:
+        torch.jit.trace = partial(torch.jit.trace, check_trace=False)
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.jit.trace = original_trace