Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into jiwaszki/fix_sharin…
Browse files Browse the repository at this point in the history
…g_warnings
  • Loading branch information
jiwaszki committed Sep 19, 2023
2 parents 1eb70cc + 4b8ed24 commit 63317aa
Show file tree
Hide file tree
Showing 23 changed files with 1,235 additions and 122 deletions.
9 changes: 6 additions & 3 deletions examples/neural_compressor/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.

The scripts [`run_clm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_clm.py)
and [`run_mlm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_mlm.py)
allow us to apply different quantization approaches (such as dynamic, static and aware-training quantization) as well as pruning
allow us to apply different quantization approaches (such as dynamic, static, weight-only and aware-training quantization) as well as pruning
using the [Intel Neural Compressor ](https://github.com/intel/neural-compressor) library for language modeling tasks.

The SmoothQuant methodology is also available for post-training quantization.
Expand Down Expand Up @@ -67,6 +67,7 @@ python run_clm.py \
--do_eval \
--verify_loading \
--output_dir /tmp/clm_output
```

### RoBERTa/BERT/DistilBERT and masked language modeling

Expand All @@ -91,7 +92,9 @@ python run_mlm.py \
--output_dir /tmp/mlm_output
```

In order to apply dynamic, static or aware-training quantization, `quantization_approach` must be set to
respectively `dynamic`, `static` or `aware_training`.
In order to apply dynamic, static, weight-only or aware-training quantization, `quantization_approach` must be set to
respectively `dynamic`, `static`, `weight_only` or `aware_training`.

The flag `--verify_loading` can be passed along to verify that the resulting quantized model can be loaded correctly.

> **_Note:_** `weight_only` quantization_approach requires neural-compressor >= 2.3
63 changes: 56 additions & 7 deletions examples/neural_compressor/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,28 @@ class OptimizationArguments:
default=False,
metadata={"help": "Whether or not to verify the loading of the quantized model."},
)
bits: int = field(
default=8,
metadata={"help": "Bits for weight only quantization, 1-8 bits."},
)
group_size: int = field(
default=-1,
metadata={
"help": "Group size for weight only quantization. Group_size=[1-N] indicates "
"splitting the input channel elements per group_size. -1 indicates "
"the per-channel quantization per output channel."
},
)
weight_only_scheme: str = field(
default="sym",
metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
)
quantization_methodology: str = field(
default="RTN",
metadata={
"help": "Quantization methodology for weight only quantization. Choose from 'RTN', 'AWQ' and 'GPTQ'."
},
)


@dataclass
Expand Down Expand Up @@ -539,7 +561,9 @@ def group_texts(examples):
desc=f"Grouping texts in chunks of {block_size}",
)

if training_args.do_train or (optim_args.apply_quantization and optim_args.quantization_approach == "static"):
if training_args.do_train or (
optim_args.apply_quantization and optim_args.quantization_approach in ["static", "weight_only"]
):
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
Expand Down Expand Up @@ -587,7 +611,7 @@ def compute_metrics(eval_preds):
raise ValueError("`do_train` must be set to True.")

if optim_args.apply_quantization:
supported_approach = {"static", "dynamic", "aware_training"}
supported_approach = {"static", "dynamic", "aware_training", "weight_only"}
if optim_args.quantization_approach not in supported_approach:
raise ValueError(
f"Unknown quantization approach. Supported approach are {supported_approach}."
Expand All @@ -600,7 +624,27 @@ def compute_metrics(eval_preds):
recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": optim_args.smooth_quant_alpha}}
else:
recipes = {}
quantization_config = PostTrainingQuantConfig(approach=optim_args.quantization_approach, recipes=recipes)
if optim_args.quantization_approach == "weight_only":
op_type_dict = {
".*": {
"weight": {
"bits": optim_args.bits,
"group_size": optim_args.group_size,
"scheme": optim_args.weight_only_scheme,
"algorithm": optim_args.quantization_methodology,
},
},
}
if optim_args.quantization_methodology == "GPTQ":
gptq_args = {
"pad_max_length": block_size,
}
recipes.update({"gptq_args": gptq_args})
else:
op_type_dict = {}
quantization_config = PostTrainingQuantConfig(
approach=optim_args.quantization_approach, op_type_dict=op_type_dict, recipes=recipes
)

if optim_args.apply_pruning:
if optim_args.end_step is None:
Expand Down Expand Up @@ -677,19 +721,24 @@ def compute_metrics(eval_preds):
trainer.save_metrics("train", metrics)
trainer.save_state()

if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic"}:
if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic", "weight_only"}:
model = trainer.model if isinstance(trainer.model, PreTrainedModel) else trainer.model._model
quantizer = INCQuantizer.from_pretrained(model)
if optim_args.quantization_approach == "static":
if optim_args.quantization_approach in ["static", "weight_only"]:
num_calibration_samples = min(len(train_dataset), optim_args.num_calibration_samples)
train_dataset = train_dataset.select(range(num_calibration_samples))
quantization_config.calibration_sampling_size = num_calibration_samples

quantizer.quantize(
quantization_config=quantization_config,
save_directory=training_args.output_dir,
calibration_dataset=train_dataset if optim_args.quantization_approach == "static" else None,
batch_size=training_args.per_device_train_batch_size,
calibration_dataset=train_dataset
if optim_args.quantization_approach in ["static", "weight_only"]
else None,
batch_size=1 # batch_size > 1 for GPTQ is WIP
if optim_args.quantization_approach == "weight_only" and optim_args.quantization_methodology == "GPTQ"
else training_args.per_device_train_batch_size,
weight_only=True if optim_args.quantization_approach == "weight_only" else False,
)
trainer.model = quantizer._quantized_model
if optim_args.apply_quantization and optim_args.verify_loading:
Expand Down
5 changes: 5 additions & 0 deletions optimum/exporters/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx


__all__ = ["main_export", "export", "export_models"]
Loading

0 comments on commit 63317aa

Please sign in to comment.