Fixed load issue for woq model and update docs
Signed-off-by: Cheng, Penghui <[email protected]>
PenghuiCheng committed Apr 17, 2024
1 parent 0c44e0b commit 31a3e53
Showing 5 changed files with 155 additions and 17 deletions.
27 changes: 27 additions & 0 deletions docs/source/optimization_inc.mdx
@@ -126,6 +126,33 @@ mpirun -np <number_of_processes> <RUN_CMD>

Please refer to INC [documentation](https://github.com/intel/neural-compressor/blob/master/docs/source/tuning_strategies.md#distributed-tuning) and [text-classification](https://github.com/huggingface/optimum-intel/tree/main/examples/neural_compressor/text-classification) example for more details.

## Weight-only quantization
As large language models (LLMs) become more prevalent, there is a growing need for quantization methods that can meet the computational demands of these architectures while maintaining accuracy. Compared to regular quantization such as W8A8, weight-only quantization is often a better trade-off between performance and accuracy. Currently, the "GPTQ" and "RTN" methods are supported.

```python
from optimum.intel.neural_compressor import INCQuantizer
from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig

# for the GPTQ method
quantization_config = GPTQConfig(
    damp_percent=0.01,
    weight_dtype="int4_clip",
)

# for the RTN method
quantization_config = RtnConfig(
    weight_dtype="int4_clip",
)

quantizer = INCQuantizer.from_pretrained(model)
quantizer.quantize(
    quantization_config=quantization_config,
    save_directory="output_dir",
    # GPTQ requires a calibration dataset, RTN does not
    calibration_dataset=(
        train_dataset if quantization_config.quant_method == "gptq" else None
    ),
)
q_model = quantizer._quantized_model
```
Please refer to [example](https://github.com/huggingface/optimum-intel/tree/main/examples/neural_compressor/text-generation).
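
The quantized model written to `save_directory` can then be reloaded for inference. A minimal sketch, assuming the model was saved to `"output_dir"` as in the snippet above:

```python
from optimum.intel.neural_compressor import INCModelForCausalLM

# Reload the weight-only quantized model from the directory written by quantizer.quantize()
loaded_model = INCModelForCausalLM.from_pretrained("output_dir")
```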

## During training optimization

12 changes: 11 additions & 1 deletion examples/neural_compressor/text-generation/README.md
@@ -20,7 +20,7 @@ Based on the script [`run_generation.py`](https://github.com/huggingface/transfo

The original generation task only supported the PyTorch eager model. With the `TSModelForCausalLM` class, TorchScript models are now supported for generation tasks.

This example also allows us to apply different quantization approaches (such as dynamic, static, The example applies post-training static quantization on a gptj model).
This example also allows us to apply different quantization approaches (such as dynamic, static, weight-only and quantization-aware training). The example applies post-training static quantization on a gptj model.

Example usage:
### apply_quantization with post-training static
@@ -45,3 +45,13 @@ python run_generation.py \
--smooth_quant_alpha 0.7 \
--jit
```

### apply_quantization with weight-only quantization
As large language models (LLMs) become more prevalent, there is a growing need for quantization methods that can meet the computational demands of these architectures while maintaining accuracy. Compared to regular quantization such as W8A8, weight-only quantization is often a better trade-off between performance and accuracy. Currently, the "GPTQ" and "RTN" methods are supported.
```bash
python run_generation.py \
--model_type=gptj \
--model_name_or_path=EleutherAI/gpt-j-6b \
--apply_quantization \
--quantization_approach weight_only
```
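
The command above uses the default RTN methodology. A GPTQ run can be selected through the flags added to `run_generation.py`; the values below are illustrative (they simply restate the script defaults) and are not taken from the committed example:
```bash
python run_generation.py \
    --model_type=gptj \
    --model_name_or_path=EleutherAI/gpt-j-6b \
    --apply_quantization \
    --quantization_approach weight_only \
    --quantization_methodology gptq \
    --damp_percent 0.01 \
    --gptq_block_size 128 \
    --num_calibration_samples 128
```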
118 changes: 108 additions & 10 deletions examples/neural_compressor/text-generation/run_generation.py
@@ -45,6 +45,14 @@
)

from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer
from optimum.intel.utils.import_utils import (
INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR,
is_intel_extension_for_transformers_available,
)


if is_intel_extension_for_transformers_available():
from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig


logging.basicConfig(
@@ -281,6 +289,69 @@ def main():
)
parser.add_argument("--dataset_name", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k")
parser.add_argument("--calib_iters", default=100, type=int, help="calibration iters.")
parser.add_argument(
"--bits",
default="4",
type=str,
help="Number of bits used for weight-only quantization (1-8 bits).",
)
parser.add_argument(
"--weight_dtype",
default="int4_clip",
type=str,
help="Weight data type for weight-only quantization.",
)
parser.add_argument(
"--group_size",
default=32,
type=int,
help="Group size for weight-only quantization. A value in [1, N] splits the "
"input channel elements into groups of that size; -1 applies per-channel "
"quantization, with one group per output channel.",
)
parser.add_argument(
"--weight_only_scheme",
default="sym",
type=str,
help="Scheme for weight only quantization. Choose from 'sym' and 'asym'.",
)
parser.add_argument(
"--quantization_methodology",
choices=["rtn", "gptq"],
default="rtn",
type=str,
help="Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'.",
)
parser.add_argument(
"--damp_percent",
default=0.01,
type=float,
help="Fraction of the mean of the Hessian's diagonal values that is added to the Hessian diagonal to improve numerical stability; used for GPTQ quantization",
)
parser.add_argument(
"--gptq_block_size",
default=128,
type=int,
help="Block size, i.e. the size of the sub weight matrix on which GPTQ is run.",
)
parser.add_argument(
"--num_calibration_samples",
default=128,
type=int,
help="Number of examples to use for the GPTQ calibration step."
)
parser.add_argument(
"--use_max_length",
default=False,
type=bool,
help="Set all sequences to the same length, given by args.pad_max_length.",
)
parser.add_argument(
"--pad_max_length",
default=2048,
type=int,
help="Maximum sequence length for the calibration dataset; this should align with your model config",
)
args = parser.parse_args()

args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -313,6 +384,43 @@ def main():
model.to(args.device)

if args.apply_quantization:
supported_approach = {"static", "dynamic", "weight_only"}
if args.quantization_approach not in supported_approach:
raise ValueError(
f"Unknown quantization approach. Supported approaches are {supported_approach}. "
f"{args.quantization_approach} was given."
)
if args.quantization_approach == "weight_only":
if not is_intel_extension_for_transformers_available():
raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("WeightOnly quantization"))

algorithm_args = {
"weight_dtype": args.weight_dtype,
"sym": args.weight_only_scheme == "sym",
"group_size": args.group_size,
}

if args.quantization_methodology == "gptq":
quantization_config = GPTQConfig(
damp_percent=args.damp_percent,
nsamples=args.num_calibration_samples,
blocksize=args.gptq_block_size,
**algorithm_args,
)
else:
quantization_config = RtnConfig(**algorithm_args)

else:
example_inputs = {"input_ids": torch.randint(100, (1, 32)), "attention_mask": torch.ones(1, 32)}
quantization_config = PostTrainingQuantConfig(
approach=args.quantization_approach,
recipes={
"smooth_quant": args.smooth_quant,
"smooth_quant_args": {"alpha": args.smooth_quant_alpha, "folding": True},
},
example_inputs=example_inputs,
)
model.config.return_dict = False
# This is just an example for calibration_fn. If you want to achieve good accuracy,
# you must perform a calibration on your real dataset.
calib_dataset = load_dataset(args.dataset_name, split="train")
@@ -347,16 +455,6 @@ def calibration_fn(p_model):
do_sample=False,
)

example_inputs = {"input_ids": torch.randint(100, (1, 32)), "attention_mask": torch.ones(1, 32)}
quantization_config = PostTrainingQuantConfig(
approach=args.quantization_approach,
recipes={
"smooth_quant": args.smooth_quant,
"smooth_quant_args": {"alpha": args.smooth_quant_alpha, "folding": True},
},
example_inputs=example_inputs,
)
model.config.return_dict = False
quantizer = INCQuantizer.from_pretrained(model, calibration_fn=calibration_fn)
with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
14 changes: 8 additions & 6 deletions optimum/intel/neural_compressor/modeling_base.py
@@ -148,13 +148,15 @@ def _from_pretrained(

return _BaseQBitsAutoModelClass.from_pretrained(
pretrained_model_name_or_path=model_id,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
subfolder=subfolder,
# The following parameters are not supported in the itrex 1.4 version and will be supported in the next version
# use_auth_token=use_auth_token,
# revision=revision,
# force_download=force_download,
# cache_dir=cache_dir,
# local_files_only=local_files_only,
# subfolder=subfolder,
trust_remote_code=trust_remote_code,
use_neural_speed=False,
**kwargs,
)
except EnvironmentError:
1 change: 1 addition & 0 deletions optimum/intel/neural_compressor/quantization.py
@@ -297,6 +297,7 @@ def quantize(
)

self._quantized_model.quantization_config = quantization_config
self._quantized_model.config.quantization_config = quantization_config
self._quantized_model.save_pretrained = types.MethodType(save_low_bit, self._quantized_model)
self._quantized_model.save_pretrained(save_directory)

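The added assignment attaches the quantization config to the model's `config` as well, so it is available when the quantized model is saved and later reloaded. A minimal sketch of the resulting state, assuming `quantizer.quantize(...)` was called as in the documentation snippet above:

```python
# Both attributes now point to the same quantization config object,
# so the saved model carries its weight-only quantization settings.
q_model = quantizer._quantized_model
assert q_model.quantization_config is q_model.config.quantization_config
```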
