Skip to content

Commit

Permalink
Check neural-compressor version for export function
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng, Penghui <[email protected]>
  • Loading branch information
PenghuiCheng committed Jun 24, 2024
1 parent a141b3b commit 1d16103
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion optimum/intel/neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from neural_compressor.model.onnx_model import ONNXModel
from neural_compressor.model.torch_model import IPEXModel, PyTorchModel
from neural_compressor.quantization import fit
from neural_compressor.utils.export import torch_to_int8_onnx
from packaging.version import parse
from torch.utils.data import DataLoader, RandomSampler
from transformers import (
Expand Down Expand Up @@ -80,6 +79,12 @@
)


if is_neural_compressor_version("<", "2.6"):
from neural_compressor.experimental.export import torch_to_int8_onnx
else:
from neural_compressor.utils.export import torch_to_int8_onnx


if is_itrex_available():
if is_itrex_version("<", ITREX_MINIMUM_VERSION):
raise ImportError(
Expand Down
6 changes: 5 additions & 1 deletion optimum/intel/neural_compressor/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from neural_compressor import training
from neural_compressor.compression import DistillationCallbacks
from neural_compressor.conf.pythonic_config import _BaseQuantizationConfig
from neural_compressor.utils.export import torch_to_fp32_onnx, torch_to_int8_onnx
from packaging import version
from torch import nn
from torch.utils.data import Dataset, RandomSampler
Expand Down Expand Up @@ -107,6 +106,11 @@
if TYPE_CHECKING:
from optimum.exporters.onnx import OnnxConfig

if is_neural_compressor_version("<", "2.6"):
from neural_compressor.experimental.export import torch_to_fp32_onnx, torch_to_int8_onnx
else:
from neural_compressor.utils.export import torch_to_fp32_onnx, torch_to_int8_onnx


__version__ = "4.22.2"

Expand Down

0 comments on commit 1d16103

Please sign in to comment.