diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index db35324a9..755236ebe 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -50,6 +50,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: |
           pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator
+
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.test-pattern == '*modeling*'}}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu

       - if: ${{ matrix.test-pattern == '*modeling*' }}
         name: Uninstall NNCF
diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml
index 8c3d9b2d3..a4e8a046b 100644
--- a/.github/workflows/test_openvino_slow.yml
+++ b/.github/workflows/test_openvino_slow.yml
@@ -49,6 +49,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator

+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 592cd85a4..3015b20e8 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -232,6 +232,7 @@ def main_export(
     )

     do_gptq_patching = False
+    do_quant_patching = False
     custom_architecture = False
     patch_16bit = False
     loading_kwargs = model_loading_kwargs or {}
@@ -247,7 +248,11 @@ def main_export(
             trust_remote_code=trust_remote_code,
         )
         quantization_config = getattr(config, "quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        supported_quant_methods = ["gptq"]
+        if is_openvino_version(">=", "2024.6.0"):
+            supported_quant_methods.append("awq")
+        do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
+        do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
@@ -296,7 +301,6 @@
         if (
             dtype is None
             and framework == "pt"
-            and not do_gptq_patching
             and (
                 task.startswith("text-generation")
                 or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
@@ -315,28 +319,28 @@
                     patch_16bit = True
                 loading_kwargs["torch_dtype"] = dtype
         # Patch the modules to export of GPTQ models w/o GPU
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
+        if do_quant_patching:
             orig_cuda_check = torch.cuda.is_available
             torch.cuda.is_available = lambda: True

-            from optimum.gptq import GPTQQuantizer
+            if do_gptq_patching:
+                from optimum.gptq import GPTQQuantizer

-            orig_post_init_model = GPTQQuantizer.post_init_model
+                orig_post_init_model = GPTQQuantizer.post_init_model

-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
+                def post_init_model(self, model):
+                    from auto_gptq import exllama_set_max_input_length

-                class StoreAttr(object):
-                    pass
+                    class StoreAttr(object):
+                        pass

-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
+                    model.quantize_config = StoreAttr()
+                    model.quantize_config.desc_act = self.desc_act
+                    if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                        model = exllama_set_max_input_length(model, self.max_input_length)
+                    return model

-            GPTQQuantizer.post_init_model = post_init_model
+                GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         dtype = deduce_diffusers_dtype(
             model_name_or_path,
@@ -485,9 +489,10 @@ class StoreAttr(object):
             compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))

     # Unpatch modules after GPTQ export
-    if do_gptq_patching:
+    if do_quant_patching:
         torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
+        if do_gptq_patching:
+            GPTQQuantizer.post_init_model = orig_post_init_model


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index c9e18cff6..0edbe48cd 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -450,7 +450,11 @@ def ts_patched_forward(*args, **kwargs):
             from openvino.frontend.pytorch.patch_model import unpatch_model

             unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-            model.to(torch.float32)
+            for m in model.modules():
+                if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
+                    b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
+                ):
+                    m.float()

             return export_pytorch_via_onnx(
                 model,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 240f4f9e3..8a3e90753 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -15,6 +15,7 @@
 import copy
 import gc
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -62,7 +63,7 @@
 )
 from transformers.onnx.utils import get_preprocessor
 from transformers.testing_utils import slow
-from utils_tests import MODEL_NAMES, TEST_IMAGE_URL
+from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference

 from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
@@ -872,7 +873,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "gpt_neo",
         "gpt_neox",
         "llama",
-        # "llama_gptq",
         "marian",
         "minicpm",
         "mistral",
@@ -917,6 +917,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "minicpm3",
     )

+    # gptq and awq install disabled for windows test environment
+    if platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("opt_gptq",)
+
+    # autoawq install disabled for windows test environment
+    if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
@@ -949,9 +957,6 @@ def test_compare_to_transformers(self, model_arch):
         if is_openvino_version("<", "2024.1"):
             not_stateful.extend(["llama", "gemma", "gpt_bigcode"])

-        if "gptq" in model_arch:
-            self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
-
         set_seed(SEED)

         model_kwargs = {}
@@ -978,20 +983,30 @@
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)

         with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
+            with patch_awq_for_inference("awq" in model_arch):
+                transformers_outputs = transformers_model(**tokens)

         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
+        # quantized models have higher tolerance
+        if "awq" in model_arch:
+            atol = 1e-2
+        elif "gptq" in model_arch:
+            atol = 0.6
         self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))

         # Qwen tokenizer does not support padding
-
         if model_arch in ["qwen"]:
             return

@@ -1025,7 +1040,12 @@
             from transformers.cache_utils import DynamicCache
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
+        print(f"ov_outputs: {ov_outputs}")
+        print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(
             torch.allclose(ov_outputs, transformers_outputs),
             "OV output {ov_outputs}\nTransformers output {transformers_output}",
         )
@@ -1261,8 +1281,13 @@ def test_beam_search(self, model_arch):
         ov_model_stateless = OVModelForCausalLM.from_pretrained(
             model_id, export=True, use_cache=True, stateful=False, **model_kwargs
         )
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

         if model_arch == "arctic":
             transformers_model.to(torch.float32)
@@ -1288,9 +1313,10 @@ def test_beam_search(self, model_arch):

         if model_arch == "gemma2":
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(
-            **tokens, generation_config=gen_config, **additional_inputs
-        )
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
         set_seed(SEED)
         ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
         self.assertTrue(
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index bf509a044..ba37797db 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from contextlib import contextmanager
+
 import numpy as np
 import openvino as ov
 import torch
@@ -77,12 +79,12 @@
     "longt5": "hf-internal-testing/tiny-random-longt5",
     "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
     "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-    "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
     "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "opt125m": "facebook/opt-125m",
+    "opt_gptq": "ybelkada/opt-125m-gptq-4bit",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
@@ -91,6 +93,7 @@
     "mistral": "echarlaix/tiny-random-mistral",
     "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
     "mixtral": "TitanML/tiny-mixtral",
+    "mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
@@ -218,3 +221,58 @@ def get_num_quantized_nodes(model):
             if type_name == "nf4":
                 num_weight_nodes["nf4"] += 1
     return num_fake_quantize, num_weight_nodes
+
+
+@contextmanager
+def mock_torch_cuda_is_available(to_patch):
+    original_is_available = torch.cuda.is_available
+    if to_patch:
+        torch.cuda.is_available = lambda: True
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.cuda.is_available = original_is_available
+
+
+@contextmanager
+def patch_awq_for_inference(to_patch):
+    orig_gemm_forward = None
+    if to_patch:
+        # patch GEMM module to allow inference without CUDA GPU
+        from awq.modules.linear.gemm import WQLinearMMFunction
+        from awq.utils.packing_utils import dequantize_gemm
+
+        def new_forward(
+            ctx,
+            x,
+            qweight,
+            qzeros,
+            scales,
+            w_bit=4,
+            group_size=128,
+            bias=None,
+            out_features=0,
+        ):
+            ctx.out_features = out_features
+
+            out_shape = x.shape[:-1] + (out_features,)
+            x = x.to(torch.float16)
+
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+            out = out + bias if bias is not None else out
+            out = out.reshape(out_shape)
+
+            if len(out.shape) == 2:
+                out = out.unsqueeze(0)
+            return out
+
+        orig_gemm_forward = WQLinearMMFunction.forward
+        WQLinearMMFunction.forward = new_forward
+    try:
+        yield
+    finally:
+        if orig_gemm_forward is not None:
+            WQLinearMMFunction.forward = orig_gemm_forward
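
Reviewer note (illustrative, not part of the patch): the sketch below shows how the two helpers added to tests/openvino/utils_tests.py are meant to be combined when producing reference transformers outputs for a GPTQ or AWQ checkpoint on a CUDA-less runner, mirroring what test_compare_to_transformers does above. The checkpoint id is the opt_gptq entry added to MODEL_NAMES; the prompt string and the top-level import of utils_tests are assumptions (the real tests run from inside tests/openvino).

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # assumes the working directory is tests/openvino, as in the test suite
    from utils_tests import mock_torch_cuda_is_available, patch_awq_for_inference

    model_id = "ybelkada/opt-125m-gptq-4bit"  # MODEL_NAMES["opt_gptq"] from this patch
    is_quantized = True  # in the tests: "awq" in model_arch or "gptq" in model_arch

    # torch.cuda.is_available is patched while loading so the auto-gptq/autoawq
    # integration can initialise on a CPU-only host; inference is then run in FP32.
    with mock_torch_cuda_is_available(is_quantized):
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokens = tokenizer("This is a sample input", return_tensors="pt")  # hypothetical prompt

    # For AWQ checkpoints the GEMM forward is swapped for a dequantize + matmul
    # fallback; for a GPTQ checkpoint (as here) the context manager is a no-op.
    with torch.no_grad(), patch_awq_for_inference(False):
        reference_logits = model(**tokens).logits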