Introduce support for NF4 data type for OV weight compression (#988)
* Add NF4 weight format

* remove test

* Update optimum/intel/openvino/configuration.py

Co-authored-by: Nikita Savelyev <[email protected]>

* Update optimum/intel/openvino/configuration.py

Co-authored-by: Nikita Savelyev <[email protected]>

* Add extra checks

* apply black

---------

Co-authored-by: Nikita Savelyev <[email protected]>
l-bat and nikita-savelyevv authored Nov 13, 2024
1 parent 5c879b9 commit 41637d0
Showing 6 changed files with 25 additions and 10 deletions.
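
For orientation, the change wires the new format through the existing weight-compression path. A minimal usage sketch (not part of this commit; it assumes the standard optimum-intel export API and an arbitrary model id) looks roughly like this:

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# Sketch only: compress weights to NF4 while exporting a causal LM to OpenVINO.
# The config values mirror the new test cases added in this commit (bits=4, group_size=32).
quantization_config = OVWeightQuantizationConfig(bits=4, weight_format="nf4", group_size=32)
model = OVModelForCausalLM.from_pretrained(
    "gpt2",  # "gpt2" is the model exercised by the new tests; any causal LM id should work in principle
    export=True,
    quantization_config=quantization_config,
)
model.save_pretrained("gpt2-ov-nf4")  # hypothetical output directory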
2 changes: 1 addition & 1 deletion optimum/commands/export/openvino.py
@@ -71,7 +71,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--weight-format",
         type=str,
-        choices=["fp32", "fp16", "int8", "int4", "mxfp4"],
+        choices=["fp32", "fp16", "int8", "int4", "mxfp4", "nf4"],
         default=None,
         help="The weight format of the exported model.",
     )
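
With "nf4" added to the accepted choices, NF4 compression can be requested from the export CLI the same way as the other 4-bit formats; the new test case below drives it with a command of the form optimum-cli export openvino --model <model_id> --task <task> --weight-format nf4 <output_dir>, where the placeholders stand in for the concrete values used in the tests.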
18 changes: 10 additions & 8 deletions optimum/intel/openvino/configuration.py
@@ -347,7 +347,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
             Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and
             compressed layers. Providing a dataset is required to run scale estimation.
         weight_format (`str`, defaults to 'int'):
-            Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4'].
+            Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4'].
         qptq (`bool`, *optional*):
             Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
             difference between activations of a compressed and original layer. Dataset is required to run GPTQ.
@@ -455,20 +455,22 @@ def post_init(self):
 
         if self.weight_format is None:
             self.weight_format = "int4" if self.bits == 4 else "int8"
-        if self.weight_format not in ["int4", "int8", "mxfp4"]:
+        if self.weight_format not in ["int4", "int8", "mxfp4", "nf4"]:
             raise ValueError(
-                f"Weight format must be one of the following: ['int4', 'int8', 'mxfp4'], but found: {self.weight_format}."
+                f"Weight format must be one of the following: ['int4', 'int8', 'mxfp4', 'nf4'], but found: {self.weight_format}."
             )
-        if self.weight_format == "mxfp4":
+        if self.weight_format in ["mxfp4", "nf4"]:
             if self.bits != 4:
                 raise ValueError(
-                    f"When applying weight compression with 'mxfp4' weight format the `bits` parameters must be set to 4, but found {self.bits}"
+                    f"When applying weight compression with '{self.weight_format}' weight format, the `bits` parameter must be set to 4, but found {self.bits}"
                 )
             if self.quant_method == OVQuantizationMethod.AWQ:
-                raise ValueError("The AWQ algorithm is not supported for 'mxfp4' weight format")
+                raise ValueError(f"The AWQ algorithm is not supported for '{self.weight_format}' weight format")
             if self.scale_estimation:
-                raise ValueError("The Scale Estimation algorithm is not supported for 'mxfp4' weight format")
-            if self.gptq:
+                raise ValueError(
+                    f"The Scale Estimation algorithm is not supported for '{self.weight_format}' weight format"
+                )
+            if self.weight_format == "mxfp4" and self.gptq:
                 raise ValueError("The GPTQ algorithm is not supported for 'mxfp4' weight format")
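
To make the effect of these checks concrete, here is a hedged sketch (not taken from the diff) of configurations the updated post_init accepts or rejects:

from optimum.intel import OVWeightQuantizationConfig

# Accepted: NF4 is a 4-bit format, so bits must be 4.
config = OVWeightQuantizationConfig(bits=4, weight_format="nf4", group_size=32)

# Each of the following would raise ValueError under the checks above (shown commented out):
# OVWeightQuantizationConfig(bits=8, weight_format="nf4")                           # bits must be 4 for 'nf4'
# OVWeightQuantizationConfig(bits=4, weight_format="nf4", scale_estimation=True)    # scale estimation not supported for 'nf4'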


2 changes: 2 additions & 0 deletions optimum/intel/openvino/quantization.py
@@ -930,6 +930,8 @@ def _weight_only_quantization(
 
     if config.weight_format == "mxfp4":
         mode = CompressWeightsMode.E2M1
+    elif config.weight_format == "nf4":
+        mode = CompressWeightsMode.NF4
     else:
         if config.bits == 8:
             mode = CompressWeightsMode.INT8_SYM if config.sym else CompressWeightsMode.INT8_ASYM
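
The mode resolved here is what ultimately gets handed to NNCF's weight-compression entry point. A simplified, assumption-laden sketch of that flow (the real _weight_only_quantization threads through many more parameters; the int4 fallback below is inferred from the existing int8 pattern):

import nncf
from nncf import CompressWeightsMode

def resolve_mode(weight_format: str, bits: int, sym: bool) -> CompressWeightsMode:
    # Mirrors the branch above: NF4 now maps to NNCF's dedicated NF4 mode.
    if weight_format == "mxfp4":
        return CompressWeightsMode.E2M1
    if weight_format == "nf4":
        return CompressWeightsMode.NF4
    if bits == 8:
        return CompressWeightsMode.INT8_SYM if sym else CompressWeightsMode.INT8_ASYM
    return CompressWeightsMode.INT4_SYM if sym else CompressWeightsMode.INT4_ASYM

# Illustrative call (ov_model is an openvino.Model):
# compressed = nncf.compress_weights(ov_model, mode=resolve_mode("nf4", 4, False), group_size=32)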
3 changes: 2 additions & 1 deletion tests/openvino/test_exporters_cli.py
@@ -108,6 +108,7 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", {"int8": 4, "int4": 72}),
         ("text-generation-with-past", "opt125m", "int4 --group-size 64", {"int8": 4, "int4": 144}),
         ("text-generation-with-past", "opt125m", "mxfp4", {"int8": 4, "f4e2m1": 72, "f8e8m0": 72}),
+        ("text-generation-with-past", "opt125m", "nf4", {"int8": 4, "nf4": 72}),
         ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 8 --all-layers", {"int4": 16}),
         (
             "text-generation-with-past",
@@ -267,7 +268,7 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
         self.assertEqual(exp_num_fq, num_fq)
 
     @parameterized.expand(TEST_4BIT_CONFIGURATIONS)
-    def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_num_weight_nodes: dict):
+    def test_exporters_cli_4bit(self, task: str, model_type: str, option: str, expected_num_weight_nodes: dict):
         with TemporaryDirectory() as tmpdir:
             result = subprocess.run(
                 f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
7 changes: 7 additions & 0 deletions tests/openvino/test_quantization.py
@@ -206,6 +206,13 @@ class OVWeightCompressionTest(unittest.TestCase):
             dict(bits=4, weight_format="mxfp4", group_size=32),
             {"f4e2m1": 20, "f8e8m0": 20, "int8": 4},
         ),
+        (
+            OVModelForCausalLM,
+            "gpt2",
+            False,
+            dict(bits=4, weight_format="nf4", group_size=32),
+            {"nf4": 20, "int8": 4},
+        ),
         (
             OVModelForCausalLM,
             "gpt2",
3 changes: 3 additions & 0 deletions tests/openvino/utils_tests.py
@@ -195,6 +195,7 @@ def get_num_quantized_nodes(model):
         "int4": 0,
         "f4e2m1": 0,
         "f8e8m0": 0,
+        "nf4": 0,
     }
     ov_model = model if isinstance(model, ov.Model) else model.model
     for elem in ov_model.get_ops():
@@ -210,4 +211,6 @@ def get_num_quantized_nodes(model):
                 num_weight_nodes["f4e2m1"] += 1
             if type_name == "f8e8m0":
                 num_weight_nodes["f8e8m0"] += 1
+            if type_name == "nf4":
+                num_weight_nodes["nf4"] += 1
     return num_fake_quantize, num_weight_nodes
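
A short illustrative check of how the extended counter might be used to validate an NF4-compressed model (not part of the commit; the exact counts are model-dependent and here simply echo the gpt2 test case added above):

# ov_model is an NF4-compressed OpenVINO model, e.g. the gpt2 case from test_quantization.py
_, num_weight_nodes = get_num_quantized_nodes(ov_model)
assert num_weight_nodes["nf4"] == 20   # weight constants packed as NF4
assert num_weight_nodes["int8"] == 4   # layers left at 8-bit by the compression backend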
