diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index 29e9bb2249..cc3953007c 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -75,11 +75,11 @@ def __init__(
         inc_config: Dict = None,
         **kwargs,
     ):
-        super().__init__(model=model, config=config)
-
+        super().__init__(model=model, config=config, **kwargs)
         self.inc_config = inc_config
         self._q_config = q_config
         self.model_save_dir = model_save_dir
+        self._device = getattr(self.model, "device", None) or torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
         if getattr(self.config, "backend", None) == "ipex":
             if not is_ipex_available():
@@ -107,7 +107,7 @@ def _from_pretrained(
         revision: Optional[Union[str, None]] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
-        file_name: Optional[str] = WEIGHTS_NAME,
+        file_name: str = WEIGHTS_NAME,
         local_files_only: bool = False,
         subfolder: str = "",
         **kwargs,
@@ -176,8 +176,8 @@ def _from_pretrained(
             model, config=config, model_save_dir=model_save_dir, q_config=q_config, inc_config=inc_config, **kwargs
         )
 
-    def _save_pretrained(self, save_directory: Union[str, Path]):
-        output_path = os.path.join(save_directory, WEIGHTS_NAME)
+    def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = WEIGHTS_NAME):
+        output_path = os.path.join(save_directory, file_name)
 
         if isinstance(self.model, torch.nn.Module):
             state_dict = self.model.state_dict()
@@ -198,11 +198,12 @@ def eval(self):
         return self
 
     @property
-    def device(self):
-        return self.model.device
+    def device(self) -> torch.device:
+        return self._device
 
-    def to(self, device: str):
-        self.model.to(device)
+    def to(self, device: Union[torch.device, str]):
+        self._device = device if isinstance(device, torch.device) else torch.device(device)
+        self.model.to(self._device)
         return self
 
     def can_generate(self):
diff --git a/optimum/intel/neural_compressor/modeling_decoder.py b/optimum/intel/neural_compressor/modeling_decoder.py
index 7d2a2db931..80851b1f43 100644
--- a/optimum/intel/neural_compressor/modeling_decoder.py
+++ b/optimum/intel/neural_compressor/modeling_decoder.py
@@ -15,7 +15,7 @@
 import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings
@@ -47,9 +47,11 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        q_config: Dict = None,
+        inc_config: Dict = None,
         use_cache: bool = True,
         **kwargs,
     ):
         super(INCModelForCausalLM, self).__init__(
-            model=model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs
+            model=model, config=config, model_save_dir=model_save_dir, q_config=q_config, inc_config=inc_config, use_cache=use_cache, **kwargs
         )
diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py
index 7470429489..93e849d04a 100644
--- a/tests/neural_compressor/test_modeling.py
+++ b/tests/neural_compressor/test_modeling.py
@@ -16,6 +16,7 @@
 import os
 import tempfile
 import unittest
+import time
 
 import torch
 from parameterized import parameterized
@@ -65,7 +66,21 @@
 DIFFUSERS_MODEL_NAMES_TO_TASK = (("echarlaix/stable-diffusion-v1-5-inc-int8-dynamic", "stable-diffusion"),)
 
 
+class Timer(object):
+    def __enter__(self):
+        self.elapsed = time.perf_counter()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.elapsed = (time.perf_counter() - self.elapsed) * 1e3
+
+
 class INCModelingTest(unittest.TestCase):
+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.1
+
     @parameterized.expand(MODEL_NAMES_TO_TASK + QUANTIZED_MODEL_NAMES_TO_TASK)
     def test_compare_to_transformers(self, model_id, task):
         model_class = eval(_HEAD_TO_AUTOMODELS[task])
@@ -120,3 +135,31 @@ def test_pipeline(self, model_id, task):
         inputs *= 2
 
         pipe(*inputs)
+
+    def test_compare_with_and_without_past_key_values(self):
+        model_id = "echarlaix/tiny-random-gpt2-torchscript"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
+
+        model_with_pkv = INCModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv")
+        # Warmup
+        model_with_pkv.generate(**tokens)
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+        model_without_pkv = INCModelForCausalLM.from_pretrained(model_id, use_cache=False, subfolder="model_without_pkv")
+        # Warmup
+        model_without_pkv.generate(**tokens)
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+        self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )
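
Not part of the patch, but for context: a rough usage sketch of the behaviour the changes above introduce. The model id and subfolder are taken from the new test; the public import path is assumed to be the usual optimum.intel entry point.

# Rough usage sketch (not part of the diff): exercises the patched `to()` / `device`
# handling and the cached-generation path covered by the new test. Assumes the test
# model "echarlaix/tiny-random-gpt2-torchscript" used in the diff is reachable.
import torch
from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM  # assumed public import path

model_id = "echarlaix/tiny-random-gpt2-torchscript"
model = INCModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv")

# The patched `to()` records the target device and moves the wrapped model;
# `device` now returns the cached torch.device instead of querying the model.
model.to("cpu")
assert model.device == torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model.generate(**tokens, max_length=100, num_beams=1)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))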