Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix transformers v4.35.0 compatibility #471

Merged
merged 9 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test_inc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[neural-compressor,ipex,diffusers,tests]
pip install .[neural-compressor,diffusers,tests]
pip install intel-extension-for-pytorch
- name: Test with Pytest
run: |
pytest tests/neural_compressor/
60 changes: 58 additions & 2 deletions optimum/intel/neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
import inspect
import logging
import warnings
from enum import Enum
from itertools import chain
from pathlib import Path
Expand All @@ -30,16 +31,25 @@
from neural_compressor.quantization import fit
from torch.utils.data import DataLoader, RandomSampler
from transformers import (
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
DataCollator,
PretrainedConfig,
PreTrainedModel,
XLNetLMHeadModel,
default_data_collator,
)

from optimum.exporters import TasksManager
from optimum.exporters.onnx import OnnxConfig
from optimum.onnxruntime import ORTModel
from optimum.onnxruntime.modeling_decoder import ORTModelDecoder
from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM
from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration
from optimum.onnxruntime.utils import ONNX_DECODER_NAME
from optimum.quantization_base import OptimumQuantizer
Expand Down Expand Up @@ -256,7 +266,7 @@ def quantize(
if isinstance(self._original_model, ORTModelForConditionalGeneration):
raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")

if isinstance(self._original_model, ORTModelDecoder):
if isinstance(self._original_model, ORTModelForCausalLM):
model_or_path = self._original_model.onnx_paths
if len(model_or_path) > 1:
raise RuntimeError(
Expand Down Expand Up @@ -528,3 +538,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> t
q_model = convert(q_model, mapping=q_mapping, inplace=True)

return q_model


class IncQuantizedModel(INCModel):
    """Deprecated base class kept for backward compatibility.

    Loading through any ``IncQuantized*`` class still works but emits a
    deprecation warning; callers should migrate to the corresponding
    ``INC*`` class (e.g. ``IncQuantizedModelForCausalLM`` -> ``INCModelForCausalLM``).
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a model via the parent ``INCModel`` after warning about the rename.

        All positional and keyword arguments are forwarded unchanged to
        ``INCModel.from_pretrained``.
        """
        # Fix: the original message misspelled "deprecated" as "depreciated".
        # stacklevel=2 makes the warning point at the caller's line rather
        # than at this library frame.
        warnings.warn(
            f"The class `{cls.__name__}` has been deprecated and will be removed in optimum-intel v1.12, please use "
            f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead.",
            FutureWarning,
            stacklevel=2,
        )
        return super().from_pretrained(*args, **kwargs)


class IncQuantizedModelForQuestionAnswering(IncQuantizedModel):
    """Deprecated wrapper for question-answering models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForQuestionAnswering


class IncQuantizedModelForSequenceClassification(IncQuantizedModel):
    """Deprecated wrapper for sequence-classification models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForSequenceClassification


class IncQuantizedModelForTokenClassification(IncQuantizedModel):
    """Deprecated wrapper for token-classification models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForTokenClassification


class IncQuantizedModelForMultipleChoice(IncQuantizedModel):
    """Deprecated wrapper for multiple-choice models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForMultipleChoice


class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel):
    """Deprecated wrapper for sequence-to-sequence LM models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForSeq2SeqLM


class IncQuantizedModelForCausalLM(IncQuantizedModel):
    """Deprecated wrapper for causal LM models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForCausalLM


class IncQuantizedModelForMaskedLM(IncQuantizedModel):
    """Deprecated wrapper for masked LM models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForMaskedLM


class IncQuantizedModelForXLNetLM(IncQuantizedModel):
    """Deprecated wrapper for XLNet LM-head models; inherits the deprecation warning from `IncQuantizedModel`."""

    # XLNet has no dedicated AutoModel class, so the concrete head model is used.
    auto_model_class = XLNetLMHeadModel


class IncQuantizedModelForVision2Seq(IncQuantizedModel):
    """Deprecated wrapper for vision-to-text models; inherits the deprecation warning from `IncQuantizedModel`."""

    # transformers auto-class used to instantiate the underlying model.
    auto_model_class = AutoModelForVision2Seq
Loading