diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml
index 3a15214f99..f3398858a7 100644
--- a/.github/workflows/test_inc.yml
+++ b/.github/workflows/test_inc.yml
@@ -30,8 +30,11 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          pip install cmake>=3.16
+          pip install py-cpuinfo
+          pip install torch==2.1.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[neural-compressor,diffusers,tests]
-          pip install intel-extension-for-pytorch
+          pip install intel-extension-for-pytorch==2.1.100
       - name: Test with Pytest
         run: |
           pytest tests/neural_compressor/
diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py
index c41d3e4b32..5a6256d6b1 100644
--- a/examples/neural_compressor/language-modeling/run_clm.py
+++ b/examples/neural_compressor/language-modeling/run_clm.py
@@ -63,6 +63,8 @@ if is_intel_extension_for_transformers_available():
     from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
 
+    from optimum.intel.neural_compressor import ITREXAutoModelForCausalLM
+
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -147,7 +149,9 @@ class OptimizationArguments:
     )
     quantization_approach: str = field(
        default="dynamic",
-        metadata={"help": "Quantization approach. Supported approach are static, dynamic and aware_training."},
+        metadata={
+            "help": "Quantization approach. Supported approaches are static, dynamic, aware_training and weight_only."
+        },
     )
     smooth_quant: bool = field(
         default=False,
@@ -200,8 +204,12 @@ class OptimizationArguments:
         default=False,
         metadata={"help": "Whether or not to verify the loading of the quantized model."},
     )
+    bits: str = field(
+        default="4",
+        metadata={"help": "Number of bits to use for weight-only quantization (1~8 bits)."},
+    )
     weight_dtype: str = field(
-        default="int8",
+        default="int4_clip",
         metadata={"help": "weight dtype for weight only quantization."},
     )
     group_size: int = field(
@@ -218,9 +226,24 @@ class OptimizationArguments:
     )
     quantization_methodology: str = field(
         default="RTN",
-        metadata={
-            "help": "Quantization methodology for weight only quantization. Choose from 'RTN', 'AWQ' and 'GPTQ'."
-        },
+        metadata={"help": "Quantization methodology for weight only quantization. Choose from 'RTN' and 'GPTQ'."},
+    )
+    gptq_percdamp: float = field(
+        default=0.01,
+        metadata={"help": "Percent of the average Hessian diagonal to use for dampening."},
+    )
+    gptq_block_size: int = field(
+        default=128,
+        metadata={"help": "Block size: the size of the sub weight matrix on which GPTQ is run."},
+    )
+    gptq_nsamples: int = field(default=128, metadata={"help": "Number of calibration data samples."})
+    gptq_use_max_length: bool = field(
+        default=False,
+        metadata={"help": "Set all sequence lengths to be the same as args.gptq_pad_max_length."},
+    )
+    gptq_pad_max_length: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length of the calibration dataset; this should align with your model config."},
     )
@@ -636,11 +659,21 @@ def compute_metrics(eval_preds):
                 )
             if optim_args.apply_pruning or optim_args.apply_distillation:
                 raise ValueError("Weight only quantization and pruning or distillation cannot be combined.")
+            if optim_args.quantization_methodology == "GPTQ":
+                algorithm_args = {
+                    "act_order": False,
+                    "percdamp": optim_args.gptq_percdamp,
+                    "block_size": optim_args.gptq_block_size,
+                    "nsamples": optim_args.gptq_nsamples,
+                    "use_max_length": optim_args.gptq_use_max_length,
+                    "pad_max_length": optim_args.gptq_pad_max_length,
+                }
             quantization_config = WeightOnlyQuantConfig(
                 weight_dtype=optim_args.weight_dtype,
                 group_size=optim_args.group_size,
                 scheme=optim_args.weight_only_scheme,
                 algorithm=optim_args.quantization_methodology,
+                algorithm_args=algorithm_args if optim_args.quantization_methodology == "GPTQ" else None,
             )
         else:
             quantization_config = PostTrainingQuantConfig(
@@ -733,17 +766,20 @@ def compute_metrics(eval_preds):
         quantizer.quantize(
             quantization_config=quantization_config,
             save_directory=training_args.output_dir,
-            calibration_dataset=train_dataset
-            if optim_args.quantization_approach in ["static", "weight_only"]
-            else None,
-            batch_size=1
-            if optim_args.quantization_approach == "weight_only"
-            else training_args.per_device_train_batch_size,
+            calibration_dataset=(
+                train_dataset if optim_args.quantization_approach in ["static", "weight_only"] else None
+            ),
+            batch_size=(
+                1 if optim_args.quantization_approach == "weight_only" else training_args.per_device_train_batch_size
+            ),
         )
         trainer.model = quantizer._quantized_model
 
     if optim_args.apply_quantization and optim_args.verify_loading:
-        loaded_model = INCModelForCausalLM.from_pretrained(training_args.output_dir)
+        if optim_args.quantization_approach == "weight_only":
+            loaded_model = ITREXAutoModelForCausalLM.from_pretrained(training_args.output_dir)
+        else:
+            loaded_model = INCModelForCausalLM.from_pretrained(training_args.output_dir)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
         with torch.no_grad():
             original_model_outputs = trainer.model(**tokens)
diff --git a/examples/neural_compressor/text-generation/run_generation.py b/examples/neural_compressor/text-generation/run_generation.py
index e06bba4102..8b1adbd3f8 100755
--- a/examples/neural_compressor/text-generation/run_generation.py
+++ b/examples/neural_compressor/text-generation/run_generation.py
@@ -368,9 +368,9 @@ def calibration_fn(p_model):
 
     args.length = adjust_length_to_model(
         args.length,
-        max_sequence_length=model.config.max_position_embeddings
-        if hasattr(model.config, "max_position_embeddings")
-        else 0,
+        max_sequence_length=(
+            model.config.max_position_embeddings if hasattr(model.config, "max_position_embeddings") else 0
+        ),
     )
     logger.info(args)
diff --git a/optimum/intel/neural_compressor/__init__.py b/optimum/intel/neural_compressor/__init__.py
index a7170120b7..f3a7bffe69 100644
--- a/optimum/intel/neural_compressor/__init__.py
+++ b/optimum/intel/neural_compressor/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils.import_utils import is_diffusers_available
+from ..utils.import_utils import is_diffusers_available, is_intel_extension_for_transformers_available
 from .configuration import INCConfig
 from .modeling_base import (
     INCModel,
@@ -32,3 +32,7 @@
 
 if is_diffusers_available():
     from .modeling_diffusion import INCStableDiffusionPipeline
+
+
+if is_intel_extension_for_transformers_available():
+    from .modeling_base import ITREXAutoModelForCausalLM
diff --git a/optimum/intel/neural_compressor/configuration.py b/optimum/intel/neural_compressor/configuration.py
index 0abdc29cd2..7f5370e5ee 100644
--- a/optimum/intel/neural_compressor/configuration.py
+++ b/optimum/intel/neural_compressor/configuration.py
@@ -35,7 +35,7 @@ class INCConfig(BaseConfig):
 
     def __init__(
         self,
-        quantization: Optional[Union[Dict, _BaseQuantizationConfig, "WeightOnlyQuantConfig"]] = None,
+        quantization: Optional[Union[Dict, _BaseQuantizationConfig]] = None,
         pruning: Optional[Union[Dict, _BaseQuantizationConfig]] = None,
         distillation: Optional[Union[Dict, _BaseQuantizationConfig]] = None,
         save_onnx_model: bool = False,
@@ -50,7 +50,7 @@ def __init__(
         self.save_onnx_model = save_onnx_model
 
     @staticmethod
-    def _create_quantization_config(config):
+    def _create_quantization_config(config: Union[Dict, _BaseQuantizationConfig]):
         # TODO : add activations_dtype and weights_dtype
         if isinstance(config, _BaseQuantizationConfig):
             approach = _quantization_model[config.approach]
diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index 72646a9f94..0226855d64 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -43,7 +43,7 @@
 from optimum.intel.generation import BaseModelForCausalLM
 
 from ...modeling_base import OptimizedModel
-from ..utils.import_utils import _torch_version, is_torch_version
+from ..utils.import_utils import _torch_version, is_intel_extension_for_transformers_available, is_torch_version
 from .configuration import INCConfig
 from .utils import WEIGHTS_NAME
 
@@ -63,6 +63,14 @@
 """
 
 
+if is_intel_extension_for_transformers_available():
+    from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM as ITREX_WOQ_MODEL
+
+    class ITREXAutoModelForCausalLM(ITREX_WOQ_MODEL):
+        auto_model_class = AutoModelForCausalLM
+        export_feature = "text-generation"
+
+
 class INCModel(OptimizedModel):
     auto_model_class = AutoModel
     export_feature = "feature-extraction"
diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index 3207ff43dd..7b294a55ec 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -15,6 +15,7 @@
 import copy
 import inspect
 import logging
+import types
 import warnings
 from enum import Enum
 from itertools import chain
@@ -79,6 +80,7 @@
 
 if is_intel_extension_for_transformers_available():
     from intel_extension_for_transformers.llm.quantization.utils import convert_to_quantized_model
+    from intel_extension_for_transformers.transformers.modeling.modeling_auto import save_low_bit
     from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
 
     Config = Union[PostTrainingQuantConfig, WeightOnlyQuantConfig]
@@ -185,6 +187,9 @@ def quantize(
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
         save_onnx_model = kwargs.pop("save_onnx_model", False)
kwargs.pop("device", "cpu") + use_cpu = True if device == torch.device("cpu") or device == "cpu" else False + use_xpu = True if (isinstance(device, torch.device) and device.type == "xpu") or device == "xpu" else False if save_onnx_model and (isinstance(self._original_model, ORTModel) or weight_only): save_onnx_model = False @@ -217,7 +222,10 @@ def quantize( f"For weight-only quantization, `quantization_config` should be an instance of `WeightOnlyQuantConfig`, but got: {type(quantization_config)} instead." ) - if calibration_dataset is None and ("GPTQ" in algo or "AWQ" in algo): + if algo not in ["RTN", "GPTQ"]: + raise ValueError("Weight-only quantization is only support RTN and GPTQ algorithm now!") + + if calibration_dataset is None and quantization_config.tokenizer is None and ("GPTQ" in algo): raise ValueError( "Weight-only quantization needs a calibration dataset for both GPTQ and AWQ methodologies." ) @@ -278,10 +286,24 @@ def quantize( ) if not isinstance(quantization_config, PostTrainingQuantConfig): - self._quantized_model = convert_to_quantized_model(self._original_model, quantization_config) + if use_cpu: + # will remove after intel-extension-for-transformers 1.3.3 released + quantization_config.device = "cpu" + quantization_config.post_init() + elif use_xpu: + # will remove after intel-extension-for-transformers 1.3.3 released + quantization_config.device = "xpu" + quantization_config.post_init_xpu() + self._quantized_model = convert_to_quantized_model( + self._original_model, quantization_config, device=quantization_config.device + ) + # will remove after intel-extension-for-transformers 1.3.3 released + if hasattr(quantization_config, "calib_dataloader"): + quantization_config.calib_dataloader = None + self._quantized_model.quantization_config = quantization_config + self._quantized_model.save_pretrained = types.MethodType(save_low_bit, self._quantized_model) # Save the quantized model - output_path = save_directory.joinpath(file_name or default_name) - self._quantized_model.save_pretrained(output_path) + self._quantized_model.save_pretrained(save_directory) else: if isinstance(self._original_model.config, PretrainedConfig): self._original_model.config.backend = quantization_config.backend diff --git a/setup.py b/setup.py index 49b7a92673..a59721450f 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ "rjieba", "timm", "invisible-watermark>=0.2.0", - "cmake>=3.16", + # Will remove after intel-extension-for-transformers 1.3.3 released. 
"intel-extension-for-transformers>=1.3", "peft", "auto-gptq", diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index 0272892096..260cb97270 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -45,7 +45,7 @@ set_seed, ) from utils_tests import SEED, INCTestMixin, _generate_dataset -from optimum.intel.utils.import_utils import is_torch_version +from optimum.intel.utils.import_utils import is_torch_version, is_intel_extension_for_transformers_available from optimum.intel import ( @@ -60,11 +60,13 @@ INCSeq2SeqTrainer, INCStableDiffusionPipeline, ) -from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification from optimum.pipelines import ORT_SUPPORTED_TASKS +if is_intel_extension_for_transformers_available(): + from optimum.intel.neural_compressor import ITREXAutoModelForCausalLM + from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig os.environ["CUDA_VISIBLE_DEVICES"] = "" set_seed(SEED) @@ -200,20 +202,24 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_name, expec load_ipex_model=True, ) + @unittest.skipIf( + not is_intel_extension_for_transformers_available(), reason="Intel-extension-for-transformers not available!" + ) def test_weight_only_quantization(self): model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2) with tempfile.TemporaryDirectory() as tmp_dir: quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation") + calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2) quantization_config = WeightOnlyQuantConfig(weight_dtype="int8") q_model = quantizer.quantize( quantization_config=quantization_config, save_directory=tmp_dir, ) + q_model = ITREXAutoModelForCausalLM.from_pretrained(tmp_dir) inp = torch.tensor([calibration_dataset[0]["input_ids"]]) out = model(inp)[0] q_out = q_model(inp)[0] @@ -221,8 +227,14 @@ def test_weight_only_quantization(self): with tempfile.TemporaryDirectory() as tmp_dir: quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation") + calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2) quantization_config = WeightOnlyQuantConfig( algorithm="GPTQ", + algorithm_args={ + "percdamp": 0.01, + "act_order": False, + "scheme": "sym", + }, weight_dtype="int4_clip", ) q_model = quantizer.quantize( @@ -230,6 +242,7 @@ def test_weight_only_quantization(self): calibration_dataset=calibration_dataset, save_directory=tmp_dir, ) + q_model = ITREXAutoModelForCausalLM.from_pretrained(tmp_dir) inp = torch.tensor([calibration_dataset[0]["input_ids"]]) out = model(inp)[0] q_out = q_model(inp)[0] @@ -237,26 +250,12 @@ def test_weight_only_quantization(self): with tempfile.TemporaryDirectory() as tmp_dir: quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation") - quantization_config = WeightOnlyQuantConfig( - algorithm="AWQ", - weight_dtype="int4_clip", - ) - q_model = quantizer.quantize( - quantization_config=quantization_config, - 
-                calibration_dataset=calibration_dataset,
-                save_directory=tmp_dir,
-            )
-            inp = torch.tensor([calibration_dataset[0]["input_ids"]])
-            out = model(inp)[0]
-            q_out = q_model(inp)[0]
-            self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1)))
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation")
+            calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2)
             q_model = quantizer.quantize(
                 weight_only=True,  # use RTN quantization method and NF4 weight data type is default.
                 save_directory=tmp_dir,
             )
+            q_model = ITREXAutoModelForCausalLM.from_pretrained(tmp_dir)
             inp = torch.tensor([calibration_dataset[0]["input_ids"]])
             out = model(inp)[0]
             q_out = q_model(inp)[0]
diff --git a/tests/openvino/test_modeling_basic.py b/tests/openvino/test_modeling_basic.py
index a443c5fea7..9423ce5683 100644
--- a/tests/openvino/test_modeling_basic.py
+++ b/tests/openvino/test_modeling_basic.py
@@ -7,6 +7,7 @@
 This test is meant to run quickly with tiny test models. More extensive tests are in test_modeling.py.
 """
+
 # ruff: noqa
 
 import gc
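For reference, a minimal sketch of the weight-only quantization flow this change enables, mirroring the updated test above. The tiny model name, the `woq_model` output directory and the `int8` weight dtype are illustrative placeholders, and the snippet assumes intel-extension-for-transformers and neural-compressor are installed as pinned in the workflow:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
from optimum.intel import INCQuantizer
from optimum.intel.neural_compressor import ITREXAutoModelForCausalLM

model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Quantize the model weights only (RTN, int8) and save the result to disk.
quantizer = INCQuantizer.from_pretrained(model, task="text-generation")
quantizer.quantize(
    quantization_config=WeightOnlyQuantConfig(weight_dtype="int8"),
    save_directory="woq_model",
)

# Reload through the ITREX wrapper introduced by this change and run inference.
q_model = ITREXAutoModelForCausalLM.from_pretrained("woq_model")
inputs = tokenizer("This is a sample input", return_tensors="pt")
with torch.no_grad():
    outputs = q_model(**inputs)
```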