diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index be4f41b66a..dcd2827eab 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -114,6 +114,7 @@ else:
     _import_structure["neural_compressor"] = [
         "INCConfig",
+        "INCModel",
         "INCModelForCausalLM",
         "INCModelForMaskedLM",
         "INCModelForMultipleChoice",
diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index 19c06c8c4c..e6ae0f2595 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -31,6 +31,7 @@
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     AutoModelForVision2Seq,
+    GenerationMixin,
     PretrainedConfig,
     XLNetLMHeadModel,
 )
@@ -39,11 +40,8 @@
 from transformers.utils import is_ipex_available
 from transformers.utils.generic import ContextManagers
 
-from ...exporters import TasksManager
 from ...modeling_base import OptimizedModel
-from ..generation.modeling import jit_trace
 from ..utils.import_utils import _torch_version, is_torch_version
-from ..utils.modeling_utils import patch_decoder_attention_mask
 from .configuration import INCConfig
 from .utils import WEIGHTS_NAME
@@ -65,6 +63,7 @@ class INCModel(OptimizedModel):
     auto_model_class = AutoModel
+    export_feature = "feature-extraction"
     base_model_prefix = "inc_model"
 
@@ -76,12 +75,13 @@ def __init__(
         self,
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         q_config: Dict = None,
         inc_config: Dict = None,
         **kwargs,
     ):
-        super().__init__(model=model, config=config)
-
+        super().__init__(model=model, config=config, **kwargs)
         self.inc_config = inc_config
         self._q_config = q_config
         self.model_save_dir = model_save_dir
-        self.is_quantized = q_config is not None
+        self._device = getattr(self.model, "device", None) or torch.device(
+            "cuda:0" if torch.cuda.is_available() else "cpu"
+        )
 
         if getattr(self.config, "backend", None) == "ipex":
             if not is_ipex_available():
@@ -109,9 +109,10 @@ def _from_pretrained(
         revision: Optional[Union[str, None]] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
-        file_name: Optional[str] = WEIGHTS_NAME,
+        file_name: str = WEIGHTS_NAME,
         local_files_only: bool = False,
         subfolder: str = "",
+        trust_remote_code: bool = False,
         **kwargs,
     ):
         model_name_or_path = kwargs.pop("model_name_or_path", None)
@@ -143,8 +144,10 @@
             if not is_torch_version("==", inc_config.torch_version):
                 msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
                 logger.warning(f"{msg}")
-        except Exception:
-            logger.info("Couldn't verify torch version.")
+        except EnvironmentError:
+            msg = (
+                f"Please check that the torch version used to quantize the model is compatible with {_torch_version}."
+            )
 
         if getattr(config, "backend", None) == "ipex" or getattr(config, "torchscript", False):
             # NOTE: Will improve to use load function when Intel Neural Compressor next 2.1 release.
@@ -195,63 +198,26 @@ def forward(self, *args, **kwargs):
 
     def eval(self):
         self.model.eval()
+        return self
 
-    @classmethod
-    def _from_transformers(
-        cls,
-        model_id: str,
-        config: PretrainedConfig,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        use_cache: bool = True,
-        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-        **kwargs,
-    ):
-        if is_torch_version("<", "2.0.0"):
-            raise ImportError("`torch>=2.0.0` is needed to trace your model")
-
-        task = cls.export_feature
-        kwargs.get("file_name", None)
-
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "torch_dtype": torch_dtype,
-        }
-
-        if config.torch_dtype == "int8" or config.torch_dtype == torch.int8:
-            raise ValueError("quantized model cannot be exported")
-
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-
-        if task == "text-generation":
-            model = patch_decoder_attention_mask(model)
-
-        traced_model = jit_trace(model, task, use_cache)
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-        torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
-        config.torchscript = True
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            use_cache=use_cache,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            local_files_only=local_files_only,
-            **kwargs,
-        )
+    @property
+    def device(self) -> torch.device:
+        return self._device
+
+    def to(self, device: Union[torch.device, str]):
+        self._device = device if isinstance(device, torch.device) else torch.device(device)
+        self.model.to(self._device)
+        return self
+
+    def can_generate(self):
+        return isinstance(self.model, GenerationMixin)
+
+    def generate(self, *args, **kwargs):
+        if not self.can_generate():
+            raise TypeError(
+                f"The current model class {self.model.__class__} is not compatible with `.generate()`, as it doesn't have a language model head."
+            )
+        return self.model.generate(*args, **kwargs)
 
 
 class INCModelForQuestionAnswering(INCModel):
diff --git a/optimum/intel/neural_compressor/modeling_decoder.py b/optimum/intel/neural_compressor/modeling_decoder.py
index 8d633f8dd1..e284ce4c3e 100644
--- a/optimum/intel/neural_compressor/modeling_decoder.py
+++ b/optimum/intel/neural_compressor/modeling_decoder.py
@@ -15,7 +15,7 @@
 import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings
@@ -39,15 +39,25 @@ class INCModelForCausalLM(INCModel, BaseModelForCausalLM):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
     forward = BaseModelForCausalLM.forward
+    generate = BaseModelForCausalLM.generate
+    can_generate = BaseModelForCausalLM.can_generate
 
     def __init__(
         self,
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        q_config: Dict = None,
+        inc_config: Dict = None,
         use_cache: bool = True,
         **kwargs,
     ):
         super(INCModelForCausalLM, self).__init__(
-            model=model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs
+            model=model,
+            config=config,
+            model_save_dir=model_save_dir,
+            q_config=q_config,
+            inc_config=inc_config,
+            use_cache=use_cache,
+            **kwargs,
         )
diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index cbaf39258f..c844fdcdec 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -52,7 +52,18 @@
     is_ipex_version,
     is_neural_compressor_version,
 )
-from .configuration import INCConfig, WeightOnlyQuantConfig
+from .configuration import INCConfig
+from .modeling_base import (  # noqa
+    INCModel,
+    INCModelForMaskedLM,
+    INCModelForMultipleChoice,
+    INCModelForQuestionAnswering,
+    INCModelForSeq2SeqLM,
+    INCModelForSequenceClassification,
+    INCModelForTokenClassification,
+    INCModelForVision2Seq,
+    INCModelForXLNetLM,
+)
 from .utils import INCDataLoader, _cfgs_to_fx_cfgs
diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py
index fa21122595..3e36065195 100644
--- a/optimum/intel/neural_compressor/utils.py
+++ b/optimum/intel/neural_compressor/utils.py
@@ -41,6 +41,7 @@
     "question-answering": "INCModelForQuestionAnswering",
     "multiple-choice": "INCModelForMultipleChoice",
     "stable-diffusion": "INCStableDiffusionPipeline",
+    "feature-extraction": "INCModel",
 }
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 7b2eeb540e..d781f1f513 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -205,7 +205,10 @@ def is_torch_version(operation: str, version: str):
     """
     if not _torch_available:
         return False
-    return compare_versions(parse(_torch_version), operation, version)
+
+    import torch
+
+    return compare_versions(parse(parse(torch.__version__).base_version), operation, version)
 
 
 def is_ipex_version(operation: str, version: str):
diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py
index 51ae535920..fc2a310595 100644
--- a/tests/neural_compressor/test_modeling.py
+++ b/tests/neural_compressor/test_modeling.py
@@ -15,17 +15,20 @@
 import os
 import tempfile
+import time
 import unittest
 
 import torch
 from parameterized import parameterized
-from transformers import set_seed
+from transformers import AutoTokenizer, pipeline, set_seed
 
 from optimum.exporters import TasksManager
 from optimum.intel import (  # noqa
     INCConfig,
+    INCModel,
     INCModelForCausalLM,
     INCModelForMaskedLM,
+    INCModelForMultipleChoice,
     INCModelForQuestionAnswering,
     INCModelForSeq2SeqLM,
     INCModelForSequenceClassification,
@@ -35,7 +38,7 @@
     INCStableDiffusionPipeline,
     INCTrainer,
 )
-from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS
+from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME
 
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -44,22 +47,40 @@
 
 QUANTIZED_MODEL_NAMES_TO_TASK = (
     ("echarlaix/distilbert-base-uncased-finetuned-sst-2-english-int8-dynamic", "text-classification"),
-    ("echarlaix/distilbert-sst2-inc-dynamic-quantization-magnitude-pruning-0.1", "text-classification"),
     ("Intel/distilbert-base-uncased-distilled-squad-int8-static", "question-answering"),
     ("Intel/t5-small-xsum-int8-dynamic", "text2text-generation"),
-    # ("echarlaix/stable-diffusion-v1-5-inc-int8-dynamic", "stable-diffusion")
 )
 
 MODEL_NAMES_TO_TASK = (
     ("hf-internal-testing/tiny-random-gpt2", "text-generation"),
-    ("hf-internal-testing/tiny-random-bert", "fill-mask"),
+    ("hf-internal-testing/tiny-random-BertForMaskedLM", "fill-mask"),
+    ("hf-internal-testing/tiny-random-DistilBertForSequenceClassification", "text-classification"),
+    ("hf-internal-testing/tiny-random-DebertaV2Model", "feature-extraction"),
+    ("hf-internal-testing/tiny-random-MobileBertForQuestionAnswering", "question-answering"),
+    ("hf-internal-testing/tiny-random-BartForConditionalGeneration", "text2text-generation"),
+    ("hf-internal-testing/tiny-random-RobertaForTokenClassification", "token-classification"),
+    ("hf-internal-testing/tiny-random-BertForMultipleChoice", "multiple-choice"),
 )
 
+DIFFUSERS_MODEL_NAMES_TO_TASK = (("echarlaix/stable-diffusion-v1-5-inc-int8-dynamic", "stable-diffusion"),)
+
+
+class Timer(object):
+    def __enter__(self):
+        self.elapsed = time.perf_counter()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.elapsed = (time.perf_counter() - self.elapsed) * 1e3
+
 
 class INCModelingTest(unittest.TestCase):
+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.1
+
     @parameterized.expand(MODEL_NAMES_TO_TASK + QUANTIZED_MODEL_NAMES_TO_TASK)
-    def test_modeling(self, model_id, task):
+    def test_compare_to_transformers(self, model_id, task):
         model_class = eval(_HEAD_TO_AUTOMODELS[task])
         inc_model = model_class.from_pretrained(model_id)
         model_type = inc_model.config.model_type.replace("_", "-")
@@ -73,38 +94,71 @@ def test_modeling(self, model_id, task):
         config = config_class(inc_model.config)
         model_inputs = config.generate_dummy_inputs(framework="pt")
         outputs = inc_model(**model_inputs)
-
         with tempfile.TemporaryDirectory() as tmpdirname:
             inc_model.save_pretrained(tmpdirname)
-            loaded_model = model_class.from_pretrained(tmpdirname)
+            loaded_model = model_class.from_pretrained(tmpdirname, file_name=WEIGHTS_NAME)
             outputs_loaded = loaded_model(**model_inputs)
 
-        output_name = "end_logits" if task == "question-answering" else "logits"
-        self.assertTrue(torch.equal(outputs_loaded[output_name], outputs[output_name]))
+        if task == "feature-extraction":
+            output_name = "last_hidden_state"
+        elif task == "question-answering":
+            output_name = "end_logits"
+        else:
+            output_name = "logits"
 
-    @parameterized.expand(MODEL_NAMES_TO_TASK)
-    def test_export_modeling(self, model_id, task):
-        model_class = eval(_HEAD_TO_AUTOMODELS[task])
-        inc_model = model_class.from_pretrained(model_id)
-        model_type = inc_model.config.model_type.replace("_", "-")
-        config_class = TasksManager.get_exporter_config_constructor(
-            exporter="onnx",
-            model=inc_model,
-            task=task,
-            model_name=model_id,
-            model_type=model_type,
-        )
-        config = config_class(inc_model.config)
-        model_inputs = config.generate_dummy_inputs(framework="pt")
-        outputs = inc_model(**model_inputs)
-        transformers_model = model_class.auto_model_class.from_pretrained(model_id)
-        transformers_outputs = transformers_model(**model_inputs)
+        # Compare to saved and loaded model
+        self.assertTrue(torch.equal(outputs_loaded[output_name], outputs[output_name]))
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            inc_model.save_pretrained(tmpdirname)
-            loaded_model = model_class.from_pretrained(tmpdirname, export=True)
-            outputs_loaded = loaded_model(**model_inputs)
+        if inc_model._q_config is None:
+            transformers_model = model_class.auto_model_class.from_pretrained(model_id)
+            transformers_outputs = transformers_model(**model_inputs)
+            # Compare to original transformers model
+            self.assertTrue(torch.equal(transformers_outputs[output_name], outputs[output_name]))
 
-        output_name = "end_logits" if task == "question-answering" else "logits"
-        self.assertTrue(torch.equal(outputs_loaded[output_name], outputs[output_name]))
-        self.assertTrue(torch.equal(transformers_outputs[output_name], outputs[output_name]))
+    @parameterized.expand(MODEL_NAMES_TO_TASK + QUANTIZED_MODEL_NAMES_TO_TASK)
+    def test_pipeline(self, model_id, task):
+        if task == "multiple-choice":
+            self.skipTest("No pipeline for multiple choice")
+
+        model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(model_id)
+        model.to("cpu")
+        model.eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline(task, model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.device, model.device)
+
+        inputs = ["This is a simple input" + (f"{tokenizer.mask_token}" if task == "fill-mask" else "")]
+        if task == "question-answering":
+            inputs *= 2
+
+        pipe(*inputs)
+
+    def test_compare_with_and_without_past_key_values(self):
+        model_id = "echarlaix/tiny-random-gpt2-torchscript"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
+
+        model_with_pkv = INCModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv")
+        # Warmup
+        model_with_pkv.generate(**tokens)
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+        model_without_pkv = INCModelForCausalLM.from_pretrained(
+            model_id, use_cache=False, subfolder="model_without_pkv"
+        )
+        # Warmup
+        model_without_pkv.generate(**tokens)
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+        self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )
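
Taken together, the patch exposes INCModel from optimum.intel, maps it to the "feature-extraction" task, and gives every INC model a device property plus to(), can_generate() and generate() passthroughs. The snippet below is a usage sketch of that surface and is not part of the patch: it reuses the tiny test checkpoints referenced in tests/neural_compressor/test_modeling.py, and the exact model IDs and printed outputs are illustrative only.

# Illustrative only: assumes the tiny test checkpoints above are still published in this form.
from transformers import AutoTokenizer, pipeline

from optimum.intel import INCModel, INCModelForCausalLM

# Plain encoder loaded through the newly exported INCModel class (feature-extraction task).
encoder = INCModel.from_pretrained("hf-internal-testing/tiny-random-DebertaV2Model")
encoder.to("cpu")  # to() now updates both the wrapped module and the device property

# Causal LM: generate()/can_generate() are forwarded to the underlying transformers model.
model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm = INCModelForCausalLM.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
print(lm.can_generate())  # True: the wrapped model is a GenerationMixin
print(tokenizer.decode(lm.generate(**tokens, max_length=20)[0]))

# The device property keeps transformers pipelines consistent with the model, as the new test_pipeline checks.
pipe = pipeline("text-generation", model=lm, tokenizer=tokenizer)
print(pipe("This is a simple input")[0]["generated_text"])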