Skip to content

Commit

Permalink
Merge pull request #1 from eaidova/ea/awq_fix
Browse files Browse the repository at this point in the history
enable AWQ export only if OV supports it
  • Loading branch information
eaidova authored Dec 17, 2024
2 parents 9fb1da4 + 04d0cf9 commit b51cdee
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 4 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ def main_export(
trust_remote_code=trust_remote_code,
)
quantization_config = getattr(config, "quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] in ["gptq", "awq"]
supported_quant_methods = ["gptq"]
if is_openvino_version(">=", "2024.6.0"):
supported_quant_methods.append("awq")
do_gptq_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
Expand Down
8 changes: 6 additions & 2 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"minicpm",
"mistral",
"mixtral",
"mixtral_awq",
"mpt",
"opt",
"opt_gptq",
Expand Down Expand Up @@ -918,6 +917,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"minicpm3",
)

if is_openvino_version(">=", "2024.6.0"):
SUPPORTED_ARCHITECTURES += ("mixtral_awq",)

GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = (
"chatglm",
Expand Down Expand Up @@ -1034,7 +1036,9 @@ def test_compare_to_transformers(self, model_arch):

additional_inputs = {"past_key_values": DynamicCache()}
with patch_awq_for_inference("awq" in model_arch):
transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
transformers_outputs = transformers_model.generate(
**tokens, generation_config=gen_config, **additional_inputs
)
print(f"ov_outputs: {ov_outputs}")
print(f"transformers_outputs: {transformers_outputs}")
self.assertTrue(
Expand Down

0 comments on commit b51cdee

Please sign in to comment.