Integrate weight-only quantization of INC #417

Merged · 10 commits · Sep 15, 2023

9 changes: 6 additions & 3 deletions examples/neural_compressor/language-modeling/README.md
@@ -18,7 +18,7 @@ limitations under the License.

The scripts [`run_clm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_clm.py)
and [`run_mlm.py`](https://github.com/huggingface/optimum-intel/blob/main/examples/neural_compressor/language-modeling/run_mlm.py)
allow us to apply different quantization approaches (such as dynamic, static and aware-training quantization) as well as pruning
allow us to apply different quantization approaches (such as dynamic, static, weight-only and aware-training quantization) as well as pruning
using the [Intel Neural Compressor](https://github.com/intel/neural-compressor) library for language modeling tasks.

The SmoothQuant methodology is also available for post-training quantization.
@@ -67,6 +67,7 @@ python run_clm.py \
--do_eval \
--verify_loading \
--output_dir /tmp/clm_output
```

### RoBERTa/BERT/DistilBERT and masked language modeling

@@ -91,7 +92,9 @@ python run_mlm.py \
--output_dir /tmp/mlm_output
```

In order to apply dynamic, static or aware-training quantization, `quantization_approach` must be set to
respectively `dynamic`, `static` or `aware_training`.
In order to apply dynamic, static, weight-only or aware-training quantization, `quantization_approach` must be set to
`dynamic`, `static`, `weight_only` or `aware_training` respectively.

The flag `--verify_loading` can be passed along to verify that the resulting quantized model can be loaded correctly.

> **_Note:_** The `weight_only` quantization approach requires `neural-compressor` >= 2.3
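
For reference, the new `--bits`, `--group_size`, `--weight_only_scheme` and `--quantization_methodology` arguments introduced in `run_clm.py` by this PR are translated into an Intel Neural Compressor `PostTrainingQuantConfig`. Below is a minimal sketch of that mapping, using the defaults declared in `OptimizationArguments` (the exact construction lives in `run_clm.py`):

```python
from neural_compressor import PostTrainingQuantConfig

# Sketch: weight-only config built from the new run_clm.py defaults
# (bits=8, group_size=-1, weight_only_scheme="sym", quantization_methodology="RTN").
op_type_dict = {
    ".*": {  # apply to every matching op type
        "weight": {
            "bits": 8,           # --bits
            "group_size": -1,    # --group_size, -1 => per-channel over output channels
            "scheme": "sym",     # --weight_only_scheme
            "algorithm": "RTN",  # --quantization_methodology
        },
    },
}
quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict)
```

Passing `--quantization_methodology GPTQ` additionally attaches a `gptq_args` recipe with `pad_max_length` set to the block size, as the `run_clm.py` diff below shows.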
63 changes: 56 additions & 7 deletions examples/neural_compressor/language-modeling/run_clm.py
@@ -196,6 +196,28 @@ class OptimizationArguments:
default=False,
metadata={"help": "Whether or not to verify the loading of the quantized model."},
)
bits: int = field(
default=8,
metadata={"help": "Bits for weight only quantization, 1-8 bits."},
)
group_size: int = field(
default=-1,
metadata={
"help": "Group size for weight only quantization. Group_size=[1-N] indicates "
"splitting the input channel elements per group_size. -1 indicates "
"the per-channel quantization per output channel."
},
)
weight_only_scheme: str = field(
default="sym",
metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
)
quantization_methodology: str = field(
default="RTN",
metadata={
"help": "Quantization methodology for weight only quantization. Choose from 'RTN', 'AWQ' and 'GPTQ'."
},
)


@dataclass
@@ -539,7 +561,9 @@ def group_texts(examples):
desc=f"Grouping texts in chunks of {block_size}",
)

if training_args.do_train or (optim_args.apply_quantization and optim_args.quantization_approach == "static"):
if training_args.do_train or (
optim_args.apply_quantization and optim_args.quantization_approach in ["static", "weight_only"]
):
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
@@ -587,7 +611,7 @@ def compute_metrics(eval_preds):
raise ValueError("`do_train` must be set to True.")

if optim_args.apply_quantization:
supported_approach = {"static", "dynamic", "aware_training"}
supported_approach = {"static", "dynamic", "aware_training", "weight_only"}
if optim_args.quantization_approach not in supported_approach:
raise ValueError(
f"Unknown quantization approach. Supported approach are {supported_approach}."
@@ -600,7 +624,27 @@ def compute_metrics(eval_preds):
recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": optim_args.smooth_quant_alpha}}
else:
recipes = {}
quantization_config = PostTrainingQuantConfig(approach=optim_args.quantization_approach, recipes=recipes)
if optim_args.quantization_approach == "weight_only":
op_type_dict = {
".*": {
"weight": {
"bits": optim_args.bits,
"group_size": optim_args.group_size,
"scheme": optim_args.weight_only_scheme,
"algorithm": optim_args.quantization_methodology,
},
},
}
if optim_args.quantization_methodology == "GPTQ":
gptq_args = {
"pad_max_length": block_size,
}
recipes.update({"gptq_args": gptq_args})
else:
op_type_dict = {}
quantization_config = PostTrainingQuantConfig(
approach=optim_args.quantization_approach, op_type_dict=op_type_dict, recipes=recipes
)

if optim_args.apply_pruning:
if optim_args.end_step is None:
@@ -677,19 +721,24 @@ def compute_metrics(eval_preds):
trainer.save_metrics("train", metrics)
trainer.save_state()

if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic"}:
if optim_args.apply_quantization and optim_args.quantization_approach in {"static", "dynamic", "weight_only"}:
model = trainer.model if isinstance(trainer.model, PreTrainedModel) else trainer.model._model
quantizer = INCQuantizer.from_pretrained(model)
if optim_args.quantization_approach == "static":
if optim_args.quantization_approach in ["static", "weight_only"]:
num_calibration_samples = min(len(train_dataset), optim_args.num_calibration_samples)
train_dataset = train_dataset.select(range(num_calibration_samples))
quantization_config.calibration_sampling_size = num_calibration_samples

quantizer.quantize(
quantization_config=quantization_config,
save_directory=training_args.output_dir,
calibration_dataset=train_dataset if optim_args.quantization_approach == "static" else None,
batch_size=training_args.per_device_train_batch_size,
calibration_dataset=train_dataset
if optim_args.quantization_approach in ["static", "weight_only"]
else None,
batch_size=1 # batch_size > 1 for GPTQ is WIP
if optim_args.quantization_approach == "weight_only" and optim_args.quantization_methodology == "GPTQ"
else training_args.per_device_train_batch_size,
weight_only=True if optim_args.quantization_approach == "weight_only" else False,
)
trainer.model = quantizer._quantized_model
if optim_args.apply_quantization and optim_args.verify_loading:
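
Putting the `run_clm.py` changes together: the weight-only path reuses the tokenized train split as a calibration set, attaches a `gptq_args` recipe when the methodology is GPTQ, and calls `INCQuantizer.quantize` with `weight_only=True` and a calibration batch size of 1 for GPTQ. A condensed sketch of that flow using the tiny model from this PR's test suite (the toy calibration texts and output directory are illustrative, not part of the script):

```python
from datasets import Dataset
from neural_compressor import PostTrainingQuantConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import INCQuantizer

# Tiny model used in this PR's tests; any causal LM works the same way.
model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Toy calibration data standing in for run_clm.py's tokenized train split.
texts = ["Weight-only quantization keeps activations in floating point."] * 8
calibration_dataset = Dataset.from_dict(dict(tokenizer(texts, truncation=True, max_length=64)))

block_size = 64  # run_clm.py derives this from --block_size / the tokenizer's model_max_length
op_type_dict = {".*": {"weight": {"bits": 8, "group_size": -1, "scheme": "sym", "algorithm": "GPTQ"}}}
recipes = {"gptq_args": {"pad_max_length": block_size}}  # only added for GPTQ
quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict, recipes=recipes)

quantizer = INCQuantizer.from_pretrained(model, task="text-generation")
quantizer.quantize(
    quantization_config=quantization_config,
    calibration_dataset=calibration_dataset,
    save_directory="clm_weight_only_gptq",
    batch_size=1,  # batch_size > 1 for GPTQ is still work in progress
    weight_only=True,
)
```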
1 change: 1 addition & 0 deletions optimum/intel/neural_compressor/configuration.py
@@ -25,6 +25,7 @@
"post_training_dynamic_quant": "dynamic",
"post_training_static_quant": "static",
"quant_aware_training": "aware_training",
"post_training_weight_only": "weight_only",
}


44 changes: 42 additions & 2 deletions optimum/intel/neural_compressor/quantization.py
@@ -74,6 +74,7 @@
logger = logging.getLogger(__name__)

NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0"
NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0"
IPEX_MINIMUM_VERSION = "2.1.0"

if is_neural_compressor_version("<", NEURAL_COMPRESSOR_MINIMUM_VERSION):
@@ -87,6 +88,7 @@ class INCQuantizationMode(Enum):
DYNAMIC = "post_training_dynamic_quant"
STATIC = "post_training_static_quant"
AWARE_TRAINING = "quant_aware_training"
WEIGHT_ONLY = "post_training_weight_only"


SUPPORTED_QUANT_MODE = {approach.value for approach in INCQuantizationMode}
@@ -142,6 +144,7 @@ def quantize(
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
file_name: str = None,
weight_only: bool = False,
**kwargs,
):
"""
@@ -160,6 +163,9 @@
The function to use to form a batch from a list of elements of the calibration dataset.
remove_unused_columns (`bool`, defaults to `True`):
Whether or not to remove the columns unused by the model forward method.
weight_only (`bool`, defaults to `False`):
Whether to compress weights to integer precision (4-bit by default) while keeping activations in
floating-point. Best suited for reducing the memory footprint and accelerating inference of LLMs.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
@@ -168,7 +174,40 @@
calibration_dataloader = None
self._set_task()

if INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC:
if weight_only:
# check neural-compressor version
if is_neural_compressor_version("<", NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION):
raise ImportError(
f"Found an incompatible version of neural-compressor. Found version {_neural_compressor_version}, "
f"but only version {NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION} or higher supports weight-only quantization."
)

# If op_type_dict of quantization_config is not defined, it will use default values for weight-only quantization:
# {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "RTN"}
if isinstance(quantization_config.op_type_dict, dict) and len(quantization_config.op_type_dict) > 0:
algo = []
for _, val in quantization_config.op_type_dict.items():
algo += val.get("weight", {}).get("algorithm", ["RTN"])
else:
algo = ["RTN"]

if calibration_dataset is None and ("GPTQ" in algo or "AWQ" in algo):
raise ValueError(
"Weight-only quantization needs a calibration dataset for both GPTQ and AWQ methodologies."
)

if calibration_dataset is None:
calibration_dataloader = None
else:
calibration_dataloader = self._get_calibration_dataloader(
calibration_dataset=calibration_dataset,
batch_size=batch_size,
remove_unused_columns=remove_unused_columns,
data_collator=data_collator,
use_label=False if "GPTQ" in algo else True,
)

elif INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC:
# Since PyTorch fx trace does not really require an example_inputs, only need calibration_dataset or calibration_fn here.
if calibration_dataset is None and self.calibration_fn is None:
raise ValueError(
@@ -378,6 +417,7 @@ def _get_calibration_dataloader(
batch_size: int,
remove_unused_columns: bool,
data_collator: Optional[DataCollator] = None,
use_label: Optional[bool] = True,
) -> INCDataLoader:
data_collator = data_collator if data_collator is not None else default_data_collator
if remove_unused_columns:
Expand All @@ -394,7 +434,7 @@ def _get_calibration_dataloader(
drop_last=False,
)

return INCDataLoader.from_pytorch_dataloader(calibration_dataloader)
return INCDataLoader.from_pytorch_dataloader(calibration_dataloader, use_label)

def _remove_unused_columns(self, dataset: Dataset):
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
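
With the new `weight_only` branch above, a calibration dataset is only mandatory for the GPTQ and AWQ algorithms; when no `op_type_dict` is provided, the defaults noted in the code comment (`bits=4`, `group_size=32`, `scheme="sym"`, `algorithm="RTN"`) apply and no calibration data is needed. A minimal sketch of that default RTN path, assuming neural-compressor >= 2.3 is installed (the model name is just the tiny model from the tests, the output directory is illustrative):

```python
from neural_compressor import PostTrainingQuantConfig
from transformers import AutoModelForCausalLM
from optimum.intel import INCQuantizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTNeoForCausalLM")
quantizer = INCQuantizer.from_pretrained(model, task="text-generation")

# No op_type_dict: the weight-only defaults (4-bit, group_size=32, sym, RTN) are used,
# so no calibration_dataset has to be passed.
quantizer.quantize(
    quantization_config=PostTrainingQuantConfig(approach="weight_only"),
    save_directory="gpt_neo_weight_only_rtn",
    weight_only=True,
)
```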
10 changes: 8 additions & 2 deletions optimum/intel/neural_compressor/utils.py
@@ -49,11 +49,14 @@


class INCDataLoader(DataLoader):
use_label = True

@classmethod
def from_pytorch_dataloader(cls, dataloader: DataLoader):
def from_pytorch_dataloader(cls, dataloader: DataLoader, use_label: bool = True):
if not isinstance(dataloader, DataLoader):
raise TypeError(f"Expected a PyTorch DataLoader, got: {type(dataloader)}.")
inc_dataloader = cls(dataloader.dataset)
cls.use_label = use_label
for key, value in dataloader.__dict__.items():
inc_dataloader.__dict__[key] = value
return inc_dataloader
@@ -63,7 +66,10 @@ def __iter__(self):
if not isinstance(input, (dict, tuple, list, UserDict)):
raise TypeError(f"Model calibration cannot use input of type {type(input)}.")
label = input.get("labels") if isinstance(input, dict) else None
yield input, label
if self.use_label:
yield input, label
else:
yield input


def _cfgs_to_fx_cfgs(op_cfgs: Dict, observer_type: str = "post_training_static_quant") -> Dict:
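
The `use_label` switch exists because GPTQ calibration consumes raw input batches rather than `(input, label)` pairs, which is what the loader yielded before this PR. A small sketch of the two iteration modes (the toy dataset is purely illustrative):

```python
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from optimum.intel.neural_compressor.utils import INCDataLoader

# Toy calibration batch with labels, just to show what each mode yields.
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})
loader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

for inputs, labels in INCDataLoader.from_pytorch_dataloader(loader, use_label=True):
    print(inputs["input_ids"], labels)  # (input dict, labels) pairs, as before this PR

for inputs in INCDataLoader.from_pytorch_dataloader(loader, use_label=False):
    print(inputs["input_ids"])  # input dicts only, as GPTQ calibration expects
```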
85 changes: 85 additions & 0 deletions tests/neural_compressor/test_optimization.py
@@ -168,6 +168,91 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_name, expec
num_samples=num_samples,
)

def test_weight_only_quantization(self):
model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
op_type_dict = {
".*": {
"weight": {
"bits": 8,
"group_size": -1,
"scheme": "sym",
"algorithm": "RTN",
},
},
}
quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
quantizer = INCQuantizer.from_pretrained(model, task="text-generation")
calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2)

with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
save_directory=tmp_dir,
weight_only=True,
)
q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
inp = torch.tensor([calibration_dataset[0]["input_ids"]])
out = model(inp)[0]
q_out = q_model(inp)[0]
self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1)))

op_type_dict = {
".*": {
"weight": {
"bits": 8,
"group_size": -1,
"scheme": "sym",
"algorithm": "AWQ",
},
},
}
quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict)

with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
save_directory=tmp_dir,
weight_only=True,
)
q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
inp = torch.tensor([calibration_dataset[0]["input_ids"]])
out = model(inp)[0]
q_out = q_model(inp)[0]
self.assertTrue(torch.all(torch.isclose(out, q_out, atol=6e-1)))

op_type_dict = {
".*": {
"weight": {
"bits": 8,
"group_size": -1,
"scheme": "sym",
"algorithm": "GPTQ",
},
},
}
recipes = {"gptq_args": {"pad_max_length": len(calibration_dataset[0]["input_ids"])}}
quantization_config = PostTrainingQuantConfig(
approach="weight_only", op_type_dict=op_type_dict, recipes=recipes
)

with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
save_directory=tmp_dir,
weight_only=True,
)
q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
inp = torch.tensor([calibration_dataset[0]["input_ids"]])
out = model(inp)[0]
q_out = q_model(inp)[0]
self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1)))

def test_dynamic_accuracy_strategy_quantization(self):
model_name = "distilbert-base-cased-distilled-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)