Do not check trace for diffusers, saving memory and time
mvafin committed Dec 11, 2024
1 parent bb51139 commit 19328ee
Showing 1 changed file with 6 additions and 1 deletion.
optimum/exporters/openvino/__main__.py (6 additions, 1 deletion)
@@ -16,7 +16,7 @@
 import logging
 import operator
 import warnings
-from functools import reduce
+from functools import reduce, partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
@@ -234,6 +234,7 @@ def main_export(
     do_gptq_patching = False
     custom_architecture = False
     patch_16bit = False
+    orig_trace = None
     loading_kwargs = model_loading_kwargs or {}
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
@@ -350,6 +351,8 @@ class StoreAttr(object):
         if dtype in [torch.float16, torch.bfloat16]:
             loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True
+            orig_trace = torch.jit.trace
+            torch.jit.trace = partial(torch.jit.trace, check_trace=False)

     if library_name == "open_clip":
         model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
@@ -491,6 +494,8 @@ class StoreAttr(object):
     if do_gptq_patching:
         torch.cuda.is_available = orig_cuda_check
         GPTQQuantizer.post_init_model = orig_post_init_model
+    if orig_trace is not None:
+        torch.jit.trace = orig_trace


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
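
For reference, `torch.jit.trace` accepts a `check_trace` keyword: with its default of `True`, TorchScript re-runs the freshly traced graph and compares its outputs against the eager model, which costs an extra forward pass and extra memory, the overhead the commit title refers to for diffusers exports. The commit avoids this by globally rebinding `torch.jit.trace` during export and restoring it afterwards. Below is a minimal, self-contained sketch of that patch-and-restore pattern; `TinyModel` and the standalone trace call are illustrative assumptions, not part of the commit.

from functools import partial

import torch


class TinyModel(torch.nn.Module):
    # Illustrative module; any model traced while the patch is active
    # skips the output-verification pass.
    def forward(self, x):
        return x * 2


# Keep a handle to the original function, mirroring `orig_trace` in the commit.
orig_trace = torch.jit.trace
# Every subsequent torch.jit.trace call now implies check_trace=False.
torch.jit.trace = partial(torch.jit.trace, check_trace=False)

try:
    traced = torch.jit.trace(TinyModel(), torch.ones(1, 3))
finally:
    # Restore the original function, as the commit does once conversion finishes.
    torch.jit.trace = orig_trace

In the commit itself, the patch is installed only when the model is loaded in float16/bfloat16 (the diffusers export path named in the commit title), and the `if orig_trace is not None:` guard makes the restore a no-op when the patch was never applied.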
