From c3ad58541fb0b63799e6862257579442c92bbdc5 Mon Sep 17 00:00:00 2001
From: XinyuYe-Intel
Date: Mon, 4 Sep 2023 15:34:04 +0800
Subject: [PATCH] Added QLoRA support in NeuralChat finetuning and refined
 NeuralChat optimization API. (#174)

---
 .github/workflows/unit-test-engine.yml       |   2 +-
 .github/workflows/unit-test-optimize.yml     |   2 +-
 .../llm/finetuning/data_utils.py             | 411 ++++++++++--------
 .../llm/finetuning/eval_utils.py             |  94 ++++
 .../llm/finetuning/finetuning.py             | 193 ++++++--
 .../llm/inference/inference.py               |  39 +-
 .../llm/quantization/optimization.py         |  24 +-
 .../neural_chat/__init__.py                  |   1 -
 .../neural_chat/chatbot.py                   |   3 +-
 .../neural_chat/config.py                    |  59 ++-
 .../tests/finetuning/test_finetuning.py      |   4 +-
 .../tests/optimization/test_optimization.py  |  21 +-
 12 files changed, 572 insertions(+), 281 deletions(-)
 create mode 100644 intel_extension_for_transformers/llm/finetuning/eval_utils.py

diff --git a/.github/workflows/unit-test-engine.yml b/.github/workflows/unit-test-engine.yml
index 2a6c2d767b0..b99d912bbb9 100644
--- a/.github/workflows/unit-test-engine.yml
+++ b/.github/workflows/unit-test-engine.yml
@@ -34,7 +34,7 @@ jobs:
       include:
         - test_branch: ${{ github.ref }}
           test_name: "PR-test"
-        - test_branch: "ut_parallal"
+        - test_branch: "main"
           test_name: "baseline"
     steps:
       - name: Docker Clean Up
diff --git a/.github/workflows/unit-test-optimize.yml b/.github/workflows/unit-test-optimize.yml
index f6c1362b710..860c887fbc7 100644
--- a/.github/workflows/unit-test-optimize.yml
+++ b/.github/workflows/unit-test-optimize.yml
@@ -34,7 +34,7 @@ jobs:
       include:
         - test_branch: ${{ github.ref }}
           test_name: "PR-test"
-        - test_branch: "ut_parallal"
+        - test_branch: "main"
           test_name: "baseline"
     steps:
       - name: Docker Clean Up
diff --git a/intel_extension_for_transformers/llm/finetuning/data_utils.py b/intel_extension_for_transformers/llm/finetuning/data_utils.py
index ad3b54a9c1c..8d5a8d53ab9 100644
--- a/intel_extension_for_transformers/llm/finetuning/data_utils.py
+++ b/intel_extension_for_transformers/llm/finetuning/data_utils.py
@@ -35,93 +35,6 @@
     ),
 }
 
-conv_header = """<|im_start|>system
-- You are a helpful assistant chatbot trained by Intel.
-- You answer questions.
-- You are excited to be able to help the user, \
-but will refuse to do anything that could be considered harmful to the user.
\ -- You are more than just an information source, you are also able to write poetry, \ -short stories, and make jokes.<|im_end|>\n""" - -user = "<|im_start|>user\n" -assistant = "<|im_start|>assistant\n" -end = "<|im_end|>" - -summarization_suffix_template = "\nSummarize the highlights of this article.\n" - -def create_alpaca(examples): - prompts = {} - prompts["source"] = [] - prompts["target"] = [] - for example in examples: - prompt_template = ( - ALPACA_PROMPT_DICT["prompt_with_input"] - if example.get("input") is not None and example.get("input") != "" - else ALPACA_PROMPT_DICT["prompt_without_input"] - ) - source = prompt_template.format_map(example) - prompts["source"].append(source) - prompts["target"].append(example["output"]) - return prompts - - -def tokenize_alpaca(tokenizer, data_args, finetune_args): - def tokenize(prompt, add_eos_token=True): - results = tokenizer( - prompt, - truncation=True, - max_length=data_args.max_seq_length, - padding=False, - return_tensors=None,) - for i in range(len(results["input_ids"])): - if (results["input_ids"][i][-1] != tokenizer.eos_token_id \ - and len(results["input_ids"][i]) < data_args.max_seq_length \ - and add_eos_token \ - ): - results["input_ids"][i].append(tokenizer.eos_token_id) - results["attention_mask"][i].append(1) - results["labels"] = copy.deepcopy(results["input_ids"]) - results["input_id_len"] = [len(result) for result in results["input_ids"]] - return results - - def preprocess_function(examples): - st = [s + t for s, t in zip(examples["prompt_sources"], examples["prompt_targets"])] - examples_tokenized = tokenize(st) - input_ids = examples_tokenized["input_ids"] - labels = examples_tokenized["labels"] - if not finetune_args.train_on_inputs: - sources_tokenized = tokenize(examples["prompt_sources"], add_eos_token=False) - for label, source_len in zip(labels, sources_tokenized["input_id_len"]): - label[:source_len] = [IGNORE_INDEX] * source_len - return dict( - input_ids=input_ids, - labels=labels, - attention_mask=examples_tokenized["attention_mask"], - ) - - return preprocess_function - - -def create_oasst(examples): - prompts = {} - prompts["prompt_sources"] = [] - prompts["prompt_targets"] = [] - - for conv in examples: - conv = conv["messages"] - prompt = conv_header - - for j in range(0, len(conv) - 1, 2): - u = conv[j]["content"] - ass = conv[j+1]["content"] - prompt = prompt + user + u + end + '\n' + assistant - response = ass + end - prompts["prompt_sources"].append(prompt) - prompts["prompt_targets"].append(response) - - prompt += response + '\n' - return prompts - def truncate_sequences(sequences, max_length): words_to_cut = sum(list(map(len, sequences))) - max_length if words_to_cut <= 0: @@ -133,140 +46,269 @@ def truncate_sequences(sequences, max_length): return sequences -def tokenize_oasst(tokenizer, data_args, finetune_args): +class CompletionDataPreprocess: + prompt_template = ALPACA_PROMPT_DICT + + def create_data(self, examples): + prompts = {} + prompts["source"] = [] + prompts["target"] = [] + for example in examples: + prompt_template = ( + self.prompt_template["prompt_with_input"] + if example.get("input") is not None and example.get("input") != "" + else self.prompt_template["prompt_without_input"] + ) + source = prompt_template.format_map(example) + prompts["source"].append(source) + prompts["target"].append(example["output"]) + return prompts + + @staticmethod + def tokenize_func(tokenizer, data_args, finetune_args): + def tokenize(prompt, add_eos_token=True): + results = tokenizer( + prompt, + 
truncation=True, + max_length=data_args.max_seq_length, + padding=False, + return_tensors=None,) + for i in range(len(results["input_ids"])): + if (results["input_ids"][i][-1] != tokenizer.eos_token_id \ + and len(results["input_ids"][i]) < data_args.max_seq_length \ + and add_eos_token \ + ): + results["input_ids"][i].append(tokenizer.eos_token_id) + results["attention_mask"][i].append(1) + results["labels"] = copy.deepcopy(results["input_ids"]) + results["input_id_len"] = [len(result) for result in results["input_ids"]] + return results + + def preprocess_function(examples): + st = [s + t for s, t in zip(examples["prompt_sources"], examples["prompt_targets"])] + examples_tokenized = tokenize(st) + input_ids = examples_tokenized["input_ids"] + labels = examples_tokenized["labels"] + if not finetune_args.train_on_inputs: + sources_tokenized = tokenize(examples["prompt_sources"], add_eos_token=False) + for label, source_len in zip(labels, sources_tokenized["input_id_len"]): + label[:source_len] = [IGNORE_INDEX] * source_len + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=examples_tokenized["attention_mask"], + ) - # special tokens - assistant_tokens = tokenizer.tokenize(assistant) + return preprocess_function - def preprocess_function(examples): - instructions = [q.strip() for q in examples["prompt_sources"]] - responses = [q.strip() for q in examples["prompt_targets"]] +class ChatDataPreprocess: + base_template = """### System: + - You are a helpful assistant chatbot trained by Intel. + - You answer questions. + - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. + - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.{eos_token}\n""" # pylint: disable=C0301 - examples["input_ids"] = [] - examples["labels"] = [] - examples["attention_mask"] = [] + def __init__(self, eos_token): + self.prompt_template = self.base_template.format_map({"eos_token": eos_token}) + self.user = "### User:\n" + self.assistant = "### Assistant:\n" + self.end = eos_token - for instruction, response in zip(instructions, responses): - header = re.findall("\<\|im_start\|\>system.*?\<\|im_end\|\>", instruction, re.DOTALL)[0] - convs = re.findall("\<\|im_start\|\>.*?\<\|im_end\|\>", instruction, re.DOTALL)[1:] + def create_data(self, examples): + prompts = {} + prompts["prompt_sources"] = [] + prompts["prompt_targets"] = [] - convs_tokens = [ - tokenizer.tokenize(conv) + tokenizer.tokenize("\n") - for conv in convs - ] - header_tokens = tokenizer.tokenize(header) + tokenizer.tokenize("\n") + for conv in examples: + conv = conv["messages"] + prompt = self.prompt_template - max_input = data_args.max_source_length - len(header_tokens) - len(assistant_tokens) + for j in range(0, len(conv) - 1, 2): + u = conv[j]["content"] + ass = conv[j+1]["content"] + prompt = prompt + self.user + u + self.end + '\n' + self.assistant + response = ass + self.end + prompts["prompt_sources"].append(prompt) + prompts["prompt_targets"].append(response) - truncated_convs = truncate_sequences(convs_tokens, - max_input) + prompt += response + '\n' + return prompts - if len(truncated_convs) == 0: - truncated_convs = [convs_tokens[-1][:max_input - 3] + convs_tokens[-1][-3:]] + def tokenize_func(self, tokenizer, data_args, finetune_args): - prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] - prompt_ids = [tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in 
prompt_tokens] - prompt_ids = list(chain(*prompt_ids)) + # special tokens + assistant_tokens = tokenizer.tokenize(self.assistant) - resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(response.strip())) - # keep last and eos_id - max_resp = data_args.max_seq_length - len(prompt_ids) - 1 - if len(resp_ids) > max_resp: - resp_ids = resp_ids[:max_resp - 1] + resp_ids[-1:] + def preprocess_function(examples): - input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id] - if not finetune_args.train_on_inputs: - labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id] - else: - labels = prompt_ids + resp_ids + [tokenizer.eos_token_id] + instructions = [q.strip() for q in examples["prompt_sources"]] + responses = [q.strip() for q in examples["prompt_targets"]] - # padding - input_len = len(input_ids) - pad_len = data_args.max_seq_length - input_len - input_ids = input_ids + [tokenizer.eos_token_id] * pad_len - labels = labels + [-100] * pad_len - attention_mask = [1] * input_len + [0] * pad_len + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] - assert len(input_ids) == data_args.max_seq_length - assert len(prompt_ids) <= data_args.max_source_length - assert len(labels) == len(input_ids) == len(attention_mask) + for instruction, response in zip(instructions, responses): + header = re.findall(r"### System.*?{}".format(self.end), instruction, re.DOTALL)[0] + convs = re.findall(r"### User.*?{0}|### Assistant.*?{0}".format(self.end), instruction, re.DOTALL) + convs_tokens = [ + tokenizer.tokenize(conv) + tokenizer.tokenize("\n") + for conv in convs + ] + header_tokens = tokenizer.tokenize(header) + tokenizer.tokenize("\n") - examples["input_ids"].append(input_ids) - examples["labels"].append(labels) - examples["attention_mask"].append(attention_mask) + max_input = data_args.max_source_length - len(header_tokens) - len(assistant_tokens) - return examples + truncated_convs = truncate_sequences(convs_tokens, + max_input) - return preprocess_function + if len(truncated_convs) == 0: + truncated_convs = [convs_tokens[-1][:max_input - 3] + convs_tokens[-1][-3:]] -def tokenize_cnn(tokenizer, data_args, finetune_args): - template_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summarization_suffix_template)) + prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] + prompt_ids = [tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens] + prompt_ids = list(chain(*prompt_ids)) + + resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(response.strip())) + # keep last and eos_id + max_resp = data_args.max_seq_length - len(prompt_ids) - 1 + if len(resp_ids) > max_resp: + resp_ids = resp_ids[:max_resp - 1] + resp_ids[-1:] - def preprocess_function(examples): + input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id] + if not finetune_args.train_on_inputs: + labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id] + else: + labels = prompt_ids + resp_ids + [tokenizer.eos_token_id] - articles = [q.strip() for q in examples["article"]] - highlights = [q.strip() for q in examples["highlights"]] + # padding + input_len = len(input_ids) + pad_len = data_args.max_seq_length - input_len + input_ids = input_ids + [tokenizer.eos_token_id] * pad_len + labels = labels + [IGNORE_INDEX] * pad_len + attention_mask = [1] * input_len + [0] * pad_len - examples["input_ids"] = [] - examples["labels"] = [] - examples["attention_mask"] = [] + assert len(input_ids) == 
data_args.max_seq_length + assert len(prompt_ids) <= data_args.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) - for article, highlight in zip(articles, highlights): - max_input = data_args.max_source_length - len(template_ids) + examples["input_ids"].append(input_ids) + examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) - article_tokens = tokenizer.tokenize(article)[:max_input] - prompt_ids = tokenizer.convert_tokens_to_ids(article_tokens) + template_ids + return examples - max_resp = data_args.max_seq_length - len(prompt_ids) - 1 - resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight))[:max_resp] + return preprocess_function - input_ids = prompt_ids + resp_ids + [tokenizer.eos_token_id] - if not finetune_args.train_on_inputs: - labels = [-100] * len(prompt_ids) + resp_ids + [tokenizer.eos_token_id] - else: - labels = prompt_ids + resp_ids + [tokenizer.eos_token_id] - # padding - input_len = len(input_ids) - pad_len = data_args.max_seq_length - input_len - input_ids = input_ids + [tokenizer.eos_token_id] * pad_len - labels = labels + [-100] * pad_len - attention_mask = [1] * input_len + [0] * pad_len +class SummarizationDataPreprocess: + prompt_template = "\nSummarize the highlights of this article.\n" + + def tokenize_func(self, tokenizer, data_args, finetune_args): + template_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(self.prompt_template)) + + def preprocess_function(examples): + + articles = [q.strip() for q in examples["article"]] + highlights = [q.strip() for q in examples["highlights"]] + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + examples["decoder_input_ids"] = [] + examples["decoder_attention_mask"] = [] + examples["decoder_labels"] = [] + + for article, highlight in zip(articles, highlights): + max_input = data_args.max_source_length - len(template_ids) + + article_tokens = tokenizer.tokenize(article)[:max_input] + prompt_ids = tokenizer.convert_tokens_to_ids(article_tokens) + template_ids + + # for inference + decoder_input_ids = copy.deepcopy(prompt_ids) + + max_resp = data_args.max_seq_length - len(prompt_ids) - 1 + resp_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight))[:max_resp] + \ + [tokenizer.eos_token_id] - assert len(input_ids) == data_args.max_seq_length - assert len(prompt_ids) <= data_args.max_source_length - assert len(labels) == len(input_ids) == len(attention_mask) + # for inference + max_decoder_labels_len = data_args.max_seq_length - data_args.max_source_length - 1 + decoder_labels = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(highlight) + )[:max_decoder_labels_len] + [tokenizer.eos_token_id] - examples["input_ids"].append(input_ids) - examples["labels"].append(labels) - examples["attention_mask"].append(attention_mask) + input_ids = prompt_ids + resp_ids + if not finetune_args.train_on_inputs: + labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + else: + labels = prompt_ids + resp_ids - return examples + # padding + input_len = len(input_ids) + pad_len = data_args.max_seq_length - input_len + input_ids = input_ids + [tokenizer.eos_token_id] * pad_len + labels = labels + [IGNORE_INDEX] * pad_len + attention_mask = [1] * input_len + [0] * pad_len - return preprocess_function + assert len(input_ids) == data_args.max_seq_length + assert len(prompt_ids) <= data_args.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) + + examples["input_ids"].append(input_ids) + 
examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) + + # left padding for inference + input_len = len(decoder_input_ids) + pad_len = data_args.max_source_length - input_len + decoder_input_ids = [tokenizer.eos_token_id] * pad_len + decoder_input_ids + decoder_attention_mask = [0] * pad_len + [1] * input_len + + input_len = len(decoder_labels) + pad_len = data_args.max_seq_length - data_args.max_source_length - input_len + decoder_labels = decoder_labels + [IGNORE_INDEX] * pad_len + examples["decoder_input_ids"].append(decoder_input_ids) + examples["decoder_labels"].append(decoder_labels) + examples["decoder_attention_mask"].append(decoder_attention_mask) + + + return examples + + return preprocess_function def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args): - dataset_name = data_args.dataset_name if data_args.dataset_name is not None else data_args.train_file - if "oasst" in dataset_name: + if finetune_args.task == "chat": + preprocess = ChatDataPreprocess(tokenizer.eos_token) new_datasets = datasets.DatasetDict() - for key in ["train_ift"]: - prompts = create_oasst(raw_datasets[key]) - new_datasets["train"] = datasets.Dataset.from_dict(prompts) + for key in raw_datasets: + prompts = preprocess.create_data(raw_datasets[key]) - preprocess_fn = tokenize_oasst(tokenizer, data_args, finetune_args) + # deal irregular column name + if "train" in key: + new_key = "train" + if "val" in key: + new_key = "validation" + if "test" in key: + new_key = "test" + + new_datasets[new_key] = datasets.Dataset.from_dict(prompts) + + preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args) return new_datasets, preprocess_fn - elif "cnn" in dataset_name: - preprocess_fn = tokenize_cnn(tokenizer, data_args, finetune_args) - return raw_datasets, preprocess_fn - else: - # default use alpaca instruction template + elif finetune_args.task == "summarization": + preprocess = SummarizationDataPreprocess() + preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args) + + elif finetune_args.task == "completion": + # default use alpaca template + preprocess = CompletionDataPreprocess() for key in raw_datasets: - prompts = create_alpaca(raw_datasets[key]) + prompts = preprocess.create_data(raw_datasets[key]) columns_to_be_removed = list(raw_datasets[key].features.keys()) raw_datasets[key] = raw_datasets[key].add_column( "prompt_sources", prompts["source"] @@ -276,6 +318,9 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args): ) raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed) - preprocess_fn = tokenize_alpaca(tokenizer, data_args, finetune_args) + preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args) + + else: + raise NotImplementedError(f'finetune task data preprocessing is not support currently.') - return raw_datasets, preprocess_fn + return raw_datasets, preprocess_fn diff --git a/intel_extension_for_transformers/llm/finetuning/eval_utils.py b/intel_extension_for_transformers/llm/finetuning/eval_utils.py new file mode 100644 index 00000000000..a866ae43efe --- /dev/null +++ b/intel_extension_for_transformers/llm/finetuning/eval_utils.py @@ -0,0 +1,94 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import evaluate +import nltk +import numpy as np +import torch +from torch.utils.data import DataLoader + +@torch.no_grad() +def compute_rouge_metric(model, tokenizer, eval_dataset, training_args, gen_kwargs): + model.eval() + model.config.bos_token_id = tokenizer.bos_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + # Metric + metric = evaluate.load("rouge") + + def collate_fn(batch): + input_ids = [torch.tensor(ins["decoder_input_ids"]) for ins in batch] + labels = [torch.tensor(ins["decoder_labels"]) for ins in batch] + attention_mask = [torch.tensor(ins["decoder_attention_mask"]) for ins in batch] + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=tokenizer.eos_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) + attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) + + # TODO: support batch_size >1 + eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, + batch_size=1) + + + def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + + return preds, labels + + for step, batch in enumerate(eval_dataloader): + preds = model.generate( + input_ids=batch["input_ids"].to(model.device), + attention_mask=batch["attention_mask"].to(model.device), + **gen_kwargs, + ) + labels = batch["labels"] + labels = labels.cpu().numpy() + + preds = preds.cpu().numpy() + + # Replace -100s used for padding as we can't decode them + preds = np.where(preds != -100, preds, tokenizer.pad_token_id).tolist() + # only pred + preds = [pred[batch["input_ids"].shape[1]:] for pred in preds] + + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + labels = np.where(labels != -100, labels, tokenizer.pad_token_id).tolist() + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Some simple post-processing + decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) + + metric.add_batch( + predictions=decoded_preds, + references=decoded_labels, + ) + + + result = metric.compute(use_stemmer=True) + result = {k: round(v * 100, 4) for k, v in result.items()} + return result \ No newline at end of file diff --git a/intel_extension_for_transformers/llm/finetuning/finetuning.py b/intel_extension_for_transformers/llm/finetuning/finetuning.py index 1d2cf59ceb1..de58b0ff965 100644 --- a/intel_extension_for_transformers/llm/finetuning/finetuning.py +++ b/intel_extension_for_transformers/llm/finetuning/finetuning.py @@ -44,6 +44,7 @@ Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForSeq2SeqLM, + BitsAndBytesConfig, set_seed ) from transformers.trainer_utils import is_main_process, 
get_last_checkpoint @@ -52,10 +53,12 @@ import evaluate import torch import importlib.util -from transformers.utils.import_utils import is_optimum_available +from transformers.utils.import_utils import is_optimum_available, is_bitsandbytes_available from .data_utils import preprocess_dataset, ALPACA_PROMPT_DICT from intel_extension_for_transformers.neural_chat.config import BaseFinetuningConfig +if is_bitsandbytes_available(): + import bitsandbytes as bnb # pylint: disable=E0401 def is_optimum_habana_available(): return is_optimum_available() and importlib.util.find_spec("optimum.habana") != None @@ -69,6 +72,11 @@ def __init__(self, finetuning_config: BaseFinetuningConfig): finetuning_config.training_args, finetuning_config.finetune_args ) + if finetuning_config.finetune_args.device == "auto": + if torch.cuda.is_available(): + finetuning_config.finetune_args.device = "cuda" + else: + finetuning_config.finetune_args.device = "cpu" if finetuning_config.finetune_args.device == "cpu": finetuning_config.training_args.no_cuda = True Arguments = type(finetuning_config.training_args) @@ -235,11 +243,40 @@ def load_tokenizer(self, model_args): return tokenizer def finetune(self): + model_args, data_args, training_args, finetune_args = \ + self.model_args, self.data_args, self.training_args, self.finetune_args + if not (is_bitsandbytes_available() and torch.cuda.is_available() and training_args.device.type == "cuda"): + finetune_args.qlora = False + if finetune_args.qlora: + # finetune_args.lora_all_linear = True + finetune_args.peft = "lora" + compute_dtype = ( + torch.float16 if training_args.fp16 else + (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + self.device_map = "auto" + self.bitsandbytes_quant_config = BitsAndBytesConfig( + load_in_4bit=finetune_args.bits == 4, + load_in_8bit=finetune_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=finetune_args.double_quant, + bnb_4bit_quant_type=finetune_args.quant_type, + ) + if finetune_args.bits not in [4, 8]: + raise NotImplementedError( + f"Unsupported bits {finetune_args.bits}, only support 4 and 8 now." 
+ ) + else: + self.device_map = None + self.bitsandbytes_quant_config = None + config = self.load_model_config(self.model_args) if config.architectures[0].endswith("ForCausalLM"): - self.finetune_clm() + self.finetune_clm(model_args, data_args, training_args, finetune_args) elif config.architectures[0].endswith("ForConditionalGeneration"): - self.finetune_seq2seq() + self.finetune_seq2seq(model_args, data_args, training_args, finetune_args) else: raise NotImplementedError( "Unsupported architecture {}, only support CausalLM (CLM) \ @@ -248,22 +285,25 @@ def finetune(self): ) ) - def finetune_clm(self): - model_args, data_args, training_args, finetune_args = \ - self.model_args, self.data_args, self.training_args, self.finetune_args - - def find_all_linear_names(model): - cls = torch.nn.Linear - lora_module_names = set() - for name, module in model.named_modules(): - if isinstance(module, cls): - names = name.split('.') - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') - return list(lora_module_names) - + def find_all_linear_names(self, model): + cls = torch.nn.Linear + if self.finetune_args.qlora: + if self.finetune_args.bits == 8: + cls = bnb.nn.Linear8bitLt + elif self.finetune_args.bits == 4: + cls = bnb.nn.Linear4bit + + lora_module_names = set() + for name, module in model.named_modules(): + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + def finetune_clm(self, model_args, data_args, training_args, finetune_args): if finetune_args.device == 'habana': if not is_optimum_habana_available(): raise ImportError( @@ -287,18 +327,40 @@ def find_all_linear_names(model): # Load model if model_args.model_name_or_path: - model_dtype = torch.bfloat16 if training_args.bf16 else None - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - trust_remote_code=True if model_args.trust_remote_code else None, - torch_dtype=model_dtype, - low_cpu_mem_usage=True, + model_dtype = ( + torch.float16 if training_args.fp16 else + (torch.bfloat16 if training_args.bf16 else torch.float32) ) + if (re.search("mpt", model_args.model_name_or_path, re.IGNORECASE) or + re.search("neural-chat-7b-v1", model_args.model_name_or_path, re.IGNORECASE)): + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + device_map=self.device_map, + quantization_config=self.bitsandbytes_quant_config, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=True if model_args.trust_remote_code else None, + torch_dtype=model_dtype, + low_cpu_mem_usage=True, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + device_map=self.device_map, + quantization_config=self.bitsandbytes_quant_config, + revision=model_args.model_revision, + use_auth_token=True 
if model_args.use_auth_token else None, + trust_remote_code=True if model_args.trust_remote_code else None, + torch_dtype=model_dtype, + low_cpu_mem_usage=True, + ) + tokenizer.padding_side = "left" # allow batched inference, while mpt series don't support else: raise ValueError( "Must provide model_name_or_path to load a pretrained CausalLM model." @@ -340,14 +402,15 @@ def find_all_linear_names(model): if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = "left" # Allow batched inference raw_datasets, preprocess_function = preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args) + column_names = list(raw_datasets["train"].features) with training_args.main_process_first(desc="dataset map pre-processing"): tokenized_datasets = raw_datasets.map( preprocess_function, batched=True, + remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) @@ -365,13 +428,20 @@ def concatenate_data(dataset, max_seq_length): concatenated_dataset[column] = reshaped_data return datasets.Dataset.from_dict(concatenated_dataset) - tokenized_datasets_ = tokenized_datasets["train"].remove_columns( - ["prompt_sources", "prompt_targets"] - ) tokenized_datasets["train"] = concatenate_data( - tokenized_datasets_, data_args.max_seq_length + tokenized_datasets["train"], data_args.max_seq_length ) + if training_args.do_eval: + if "test" not in tokenized_datasets: + self.logger.info('Splitting train dataset in train and validation according to `eval_dataset_size`') + tokenized_datasets = tokenized_datasets["train"].train_test_split( + test_size=data_args.eval_dataset_size, shuffle=True, seed=42 + ) + eval_dataset = tokenized_datasets["test"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + if training_args.do_train: if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") @@ -379,13 +449,6 @@ def concatenate_data(dataset, max_seq_length): if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) - if training_args.do_eval: - if "test" not in tokenized_datasets: - raise ValueError("--do_eval requires a test dataset") - eval_dataset = tokenized_datasets["test"] - if data_args.max_eval_samples is not None: - eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) - # Data collator # This one will take care of randomly masking the tokens. 
data_collator = DataCollatorForSeq2Seq( @@ -399,7 +462,7 @@ def concatenate_data(dataset, max_seq_length): # PEFT settings if finetune_args.peft == "lora": if finetune_args.lora_all_linear: - target_modules = find_all_linear_names(model) + target_modules = self.find_all_linear_names(model) else: target_modules = finetune_args.lora_target_modules @@ -474,10 +537,35 @@ def concatenate_data(dataset, max_seq_length): training_args.output_dir, state_dict=unwrapped_model.state_dict() ) - def finetune_seq2seq(self): - model_args, data_args, training_args, finetune_args = \ - self.model_args, self.data_args, self.training_args, self.finetune_args + if finetune_args.do_lm_eval and finetune_args.task != "summarization": + unwrapped_model.eval() + from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate + with training_args.main_process_first(desc="lm_eval"): + if is_main_process(training_args.local_rank): + with torch.no_grad(): + results = evaluate( + model="hf-causal", + model_args='pretrained='+model_args.model_name_or_path+\ + ',tokenizer='+model_args.model_name_or_path+',dtype=float16', + user_model=unwrapped_model, + device=unwrapped_model.device.type, + batch_size=training_args.per_device_eval_batch_size, + tasks=finetune_args.lm_eval_tasks,) + self.logger.info(results) + + if finetune_args.task == "summarization": + from .eval_utils import compute_rouge_metric + gen_kwargs = { + "num_beams": data_args.num_beams, + "max_new_tokens": data_args.max_target_length, + } + with training_args.main_process_first(desc="summarization eval"): + if is_main_process(training_args.local_rank): + results = compute_rouge_metric(unwrapped_model, tokenizer, eval_dataset, + training_args, gen_kwargs) + self.logger.info(results) + def finetune_seq2seq(self, model_args, data_args, training_args, finetune_args): # Detecting last checkpoint. 
last_checkpoint = None if os.path.isdir(training_args.output_dir) \ @@ -627,13 +715,20 @@ def preprocess_logits_for_metrics(logits, labels): # Load model if model_args.model_name_or_path: + model_dtype = ( + torch.float16 if training_args.fp16 else + (torch.bfloat16 if training_args.bf16 else torch.float32) + ) model = AutoModelForSeq2SeqLM.from_pretrained( model_args.model_name_or_path, from_tf=bool(".ckpt" in model_args.model_name_or_path), config=config, cache_dir=model_args.cache_dir, + device_map=self.device_map, + quantization_config=self.bitsandbytes_quant_config, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, + torch_dtype=model_dtype, ) model.resize_token_embeddings(len(tokenizer)) else: @@ -641,11 +736,15 @@ def preprocess_logits_for_metrics(logits, labels): # PEFT settings if finetune_args.peft == "lora": + if finetune_args.lora_all_linear: + target_modules = self.find_all_linear_names(model) + else: + target_modules = finetune_args.lora_target_modules peft_config = LoraConfig( r=finetune_args.lora_rank, lora_alpha=finetune_args.lora_alpha, lora_dropout=finetune_args.lora_dropout, - target_modules=finetune_args.lora_target_modules, + target_modules=target_modules, bias="none", task_type=TaskType.SEQ_2_SEQ_LM, ) diff --git a/intel_extension_for_transformers/llm/inference/inference.py b/intel_extension_for_transformers/llm/inference/inference.py index e09bf517a2d..1faa1722a6f 100644 --- a/intel_extension_for_transformers/llm/inference/inference.py +++ b/intel_extension_for_transformers/llm/inference/inference.py @@ -34,6 +34,11 @@ StoppingCriteriaList, StoppingCriteria, ) +from intel_extension_for_transformers.neural_chat.config import ( + AMPConfig, + WeightOnlyQuantizationConfig, + BitsAndBytesConfig +) # Set necessary env variables os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0") @@ -389,26 +394,27 @@ def load_model( elif device == "cpu": set_cpu_running_env() - if optimization_config: - if optimization_config.amp_config: - dtype = optimization_config.amp_config.dtype - else: - dtype = "float32" - if optimization_config.bitsandbytes_config: - if device == "cuda" and is_bitsandbytes_available() and torch.cuda.is_available(): - bitsandbytes_quant_config = optimization_config.bitsandbytes_config - else: - logger.warning( - "CUDA device or bitsandbytes is not available, please make sure CUDA device and bitsandbytes" \ - + " library is available, ignoring bitsandbytes config now." - ) + if isinstance(optimization_config, AMPConfig): + dtype = optimization_config.dtype + else: + dtype = "float32" + + bitsandbytes_quant_config = None + if isinstance(optimization_config, BitsAndBytesConfig): + if device == "cuda" and is_bitsandbytes_available() and torch.cuda.is_available(): + bitsandbytes_quant_config = optimization_config else: - bitsandbytes_quant_config = None + logger.warning( + "CUDA device or bitsandbytes is not available, please make sure CUDA device and bitsandbytes" \ + + " library is available, ignoring bitsandbytes config now." 
+ ) if dtype == "bfloat16": torch_dtype = torch.bfloat16 elif dtype == "float16": torch_dtype = torch.float16 + elif dtype == "float32": + torch_dtype = torch.float32 else: logger.warning(f"Unsupported dtype {dtype}, using float32 now.") torch_dtype = torch.float32 @@ -431,7 +437,6 @@ def load_model( ) elif (re.search("mpt", model_name, re.IGNORECASE) or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)): - from transformers import AutoModelForCausalLM with smart_context_manager(use_deepspeed=use_deepspeed): model = AutoModelForCausalLM.from_pretrained( @@ -492,7 +497,7 @@ def load_model( if model.generation_config.eos_token_id is None: model.generation_config.eos_token_id = tokenizer.eos_token_id - if optimization_config: + if isinstance(optimization_config, WeightOnlyQuantizationConfig): from intel_extension_for_transformers.neural_chat.chatbot import optimize_model model = optimize_model(model, optimization_config) @@ -535,8 +540,6 @@ def load_model( ) if cpu_jit and (re.search("mpt-7b", model_name, re.IGNORECASE) or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)): - from transformers import AutoModelForCausalLM - # TDDO # model = jit_trace_mpt_7b(model) config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, diff --git a/intel_extension_for_transformers/llm/quantization/optimization.py b/intel_extension_for_transformers/llm/quantization/optimization.py index a73ef9e1651..b2cafe4d0fa 100644 --- a/intel_extension_for_transformers/llm/quantization/optimization.py +++ b/intel_extension_for_transformers/llm/quantization/optimization.py @@ -15,29 +15,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from intel_extension_for_transformers.neural_chat.config import OptimizationConfig +from typing import Union +from intel_extension_for_transformers.neural_chat.config import ( + AMPConfig, + WeightOnlyQuantizationConfig, + BitsAndBytesConfig +) class Optimization: - def __init__(self, optimization_config: OptimizationConfig): + def __init__( + self, + optimization_config: Union[AMPConfig, WeightOnlyQuantizationConfig, BitsAndBytesConfig] + ): self.optimization_config = optimization_config def optimize(self, model): optimized_model = model config = self.optimization_config - if config.weight_only_quant_config: + if isinstance(config, WeightOnlyQuantizationConfig): print("Applying Weight Only Quantization.") from neural_compressor import PostTrainingQuantConfig, quantization op_type_dict = { '.*':{ # re.match "weight": { - 'bits': config.weight_only_quant_config.bits, # 1-8 bits - 'group_size': config.weight_only_quant_config.group_size, # -1 (per-channel) - 'scheme': config.weight_only_quant_config.scheme, # sym/asym - 'algorithm': config.weight_only_quant_config.algorithm, # RTN/AWQ/TEQ + 'bits': config.bits, # 1-8 bits + 'group_size': config.group_size, # -1 (per-channel) + 'scheme': config.scheme, # sym/asym + 'algorithm': config.algorithm, # RTN/AWQ/TEQ }, }, } - recipes = {"rtn_args": {"sym_full_range": config.weight_only_quant_config.sym_full_range}} + recipes = {"rtn_args": {"sym_full_range": config.sym_full_range}} conf = PostTrainingQuantConfig( approach='weight_only', op_type_dict=op_type_dict, diff --git a/intel_extension_for_transformers/neural_chat/__init__.py b/intel_extension_for_transformers/neural_chat/__init__.py index e3222e0b79d..585d3a3b433 100644 --- a/intel_extension_for_transformers/neural_chat/__init__.py +++ b/intel_extension_for_transformers/neural_chat/__init__.py @@ -23,7 
+23,6 @@ CodeGenerationFinetuningConfig, TTSFinetuningConfig ) -from .config import OptimizationConfig from .chatbot import build_chatbot from .chatbot import finetune_model from .chatbot import optimize_model diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py index 40463d7b354..f3448b1c26e 100644 --- a/intel_extension_for_transformers/neural_chat/chatbot.py +++ b/intel_extension_for_transformers/neural_chat/chatbot.py @@ -19,7 +19,6 @@ from intel_extension_for_transformers.llm.finetuning.finetuning import Finetuning from intel_extension_for_transformers.llm.quantization.optimization import Optimization from .config import PipelineConfig -from .config import OptimizationConfig from .config import BaseFinetuningConfig from .plugins import is_plugin_enabled, get_plugin_instance, get_registered_plugins from .config import DeviceOptions @@ -99,7 +98,7 @@ def finetune_model(config: BaseFinetuningConfig): finetuning = Finetuning(config) finetuning.finetune() -def optimize_model(model, config: OptimizationConfig): +def optimize_model(model, config): """Optimize the model based on the provided configuration. Args: diff --git a/intel_extension_for_transformers/neural_chat/config.py b/intel_extension_for_transformers/neural_chat/config.py index fac098e2b5c..fc26f8a6e4f 100644 --- a/intel_extension_for_transformers/neural_chat/config.py +++ b/intel_extension_for_transformers/neural_chat/config.py @@ -209,7 +209,7 @@ class DataArguments: metadata={"help": "The list of special tokens to add in tokenizer."} ) max_source_length: Optional[int] = field( - default=512, + default=384, metadata={ "help": ( "The maximum total input sequence length after tokenization. Sequences longer " @@ -218,7 +218,7 @@ class DataArguments: }, ) max_target_length: Optional[int] = field( - default=256, + default=128, metadata={ "help": ( "The maximum total sequence length for target text after tokenization. Sequences longer " @@ -226,6 +226,18 @@ class DataArguments: ) }, ) + num_beams: Optional[int] = field( + default=4, + metadata={ + "help": ( + "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." 
+ ) + }, + ) + eval_dataset_size: int = field( + default=500, metadata={"help": "Size of validation dataset."} + ) streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) preprocessing_num_workers: Optional[int] = field( default=None, @@ -300,7 +312,7 @@ class FinetuningArguments: metadata={"help": "if False, masks out inputs in loss"}, ) device: str = field( - default="cpu", + default="auto", metadata={ "help": "What device to use for finetuning.", "choices": ["cpu", "cuda", "habana", "auto"], @@ -310,6 +322,36 @@ class FinetuningArguments: default=False, metadata={"help": "if True, will add adaptor for all linear for lora finetuning"}, ) + task: Optional[str] = field( + default="completion", + metadata={"help": "task name, different task means different data format.", + "choices": ["completion", "chat", "summarization"] + }, + ) + do_lm_eval: bool = field( + default=False, + metadata={"help": "whether to run the LM evaluation with EleutherAI/lm-evaluation-harness"}, + ) + lm_eval_tasks: Optional[List[str]] = field( + default_factory=lambda: ["truthfulqa_mc"], + metadata={"help": "tasks list for accuracy validation with EleutherAI/lm-evaluation-harness."}, + ) + qlora: bool = field( + default=False, + metadata={"help": "whether use qlora for finetuning"}, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=4, + metadata={"help": "How many bits to use."} + ) @dataclass class TTSDatasetArguments: @@ -385,12 +427,6 @@ class WeightOnlyQuantizationConfig: class AMPConfig: dtype: str = 'bfloat16' -@dataclass -class OptimizationConfig: - amp_config: AMPConfig = AMPConfig() - weight_only_quant_config: WeightOnlyQuantizationConfig = None - bitsandbytes_config: BitsAndBytesConfig = None - class PipelineConfig: def __init__(self, model_name_or_path="meta-llama/Llama-2-7b-hf", @@ -404,7 +440,10 @@ def __init__(self, self.device = device self.plugins = plugins self.loading_config = loading_config if loading_config is not None else LoadingModelConfig() - self.optimization_config = optimization_config if optimization_config is not None else OptimizationConfig() + self.optimization_config = optimization_config if optimization_config is not None else AMPConfig() + assert type(self.optimization_config) in [AMPConfig, WeightOnlyQuantizationConfig, BitsAndBytesConfig], \ + f"Expect optimization_config be an object of AMPConfig, WeightOnlyQuantizationConfig" + \ + " or BitsAndBytesConfig,got {type(self.optimization_config)}." 
for plugin_name, plugin_value in self.plugins.items(): if plugin_value['enable']: print(f"create {plugin_name} plugin instance...") diff --git a/intel_extension_for_transformers/neural_chat/tests/finetuning/test_finetuning.py b/intel_extension_for_transformers/neural_chat/tests/finetuning/test_finetuning.py index 70bcf0d8ef7..24c0bd8f783 100644 --- a/intel_extension_for_transformers/neural_chat/tests/finetuning/test_finetuning.py +++ b/intel_extension_for_transformers/neural_chat/tests/finetuning/test_finetuning.py @@ -56,7 +56,7 @@ def test_finetune_clm(self): max_steps=3, overwrite_output_dir=True ) - finetune_args = FinetuningArguments() + finetune_args = FinetuningArguments(device='cpu') finetune_cfg = TextGenerationFinetuningConfig( model_args=model_args, data_args=data_args, @@ -74,7 +74,7 @@ def test_finetune_seq2seq(self): max_steps=3, overwrite_output_dir=True ) - finetune_args = FinetuningArguments() + finetune_args = FinetuningArguments(device='cpu') finetune_cfg = TextGenerationFinetuningConfig( model_args=model_args, data_args=data_args, diff --git a/intel_extension_for_transformers/neural_chat/tests/optimization/test_optimization.py b/intel_extension_for_transformers/neural_chat/tests/optimization/test_optimization.py index 43cf3af75ee..9503c3ef1c6 100644 --- a/intel_extension_for_transformers/neural_chat/tests/optimization/test_optimization.py +++ b/intel_extension_for_transformers/neural_chat/tests/optimization/test_optimization.py @@ -19,8 +19,9 @@ import torch from transformers import BitsAndBytesConfig from transformers.utils.bitsandbytes import is_bitsandbytes_available -from intel_extension_for_transformers.neural_chat.chatbot import build_chatbot -from intel_extension_for_transformers.neural_chat.config import PipelineConfig, OptimizationConfig, WeightOnlyQuantizationConfig +from intel_extension_for_transformers.neural_chat import build_chatbot +from intel_extension_for_transformers.neural_chat.config import PipelineConfig, WeightOnlyQuantizationConfig + class TestChatbotBuilder(unittest.TestCase): def setUp(self): @@ -29,11 +30,17 @@ def setUp(self): def tearDown(self) -> None: return super().tearDown() + def test_build_chatbot_with_AMP(self): + config = PipelineConfig() + chatbot = build_chatbot(config) + self.assertIsNotNone(chatbot) + response = chatbot.predict(query="Tell me about Intel Xeon Scalable Processors.") + print(response) + self.assertIsNotNone(response) + def test_build_chatbot_with_weight_only_quant(self): config = PipelineConfig( - optimization_config=OptimizationConfig( - weight_only_quant_config=WeightOnlyQuantizationConfig() - ) + optimization_config=WeightOnlyQuantizationConfig() ) chatbot = build_chatbot(config) self.assertIsNotNone(chatbot) @@ -45,13 +52,11 @@ def test_build_chatbot_with_bitsandbytes_quant(self): if is_bitsandbytes_available() and torch.cuda.is_available(): config = PipelineConfig( device='cuda', - optimization_config=OptimizationConfig( - bitsandbytes_config=BitsAndBytesConfig( + optimization_config=BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype="bfloat16" - ) ) ) chatbot = build_chatbot(config)
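
A minimal usage sketch of the QLoRA finetuning path introduced by this patch, following the pattern of the updated test_finetuning.py. The model name and train_file are placeholders, and the import locations of ModelArguments and TextGenerationFinetuningConfig are assumed to match the package's existing config module; when CUDA or bitsandbytes is unavailable, finetune() resets qlora to False and falls back to plain LoRA, as implemented above.

# Hypothetical sketch: drive the new QLoRA path end to end.
# Placeholders: model path and train_file. ModelArguments and
# TextGenerationFinetuningConfig import locations are assumed to mirror the
# tests touched in this patch.
from transformers import TrainingArguments
from intel_extension_for_transformers.neural_chat import finetune_model
from intel_extension_for_transformers.neural_chat.config import (
    ModelArguments,
    DataArguments,
    FinetuningArguments,
    TextGenerationFinetuningConfig,
)

model_args = ModelArguments(model_name_or_path="facebook/opt-125m")  # placeholder model
data_args = DataArguments(train_file="alpaca_data.json",             # placeholder dataset
                          max_seq_length=512)
training_args = TrainingArguments(
    output_dir="./qlora_clm",
    do_train=True,
    max_steps=3,
    bf16=True,                 # picked up as bnb_4bit_compute_dtype in finetune()
    overwrite_output_dir=True,
)
finetune_args = FinetuningArguments(
    device="auto",             # resolved to "cuda" if available, otherwise "cpu"
    peft="lora",               # forced to "lora" anyway when qlora=True
    qlora=True,                # load the base model 4-bit via BitsAndBytesConfig
    bits=4,
    quant_type="nf4",
    double_quant=True,
    lora_all_linear=True,      # let find_all_linear_names() pick target modules
    task="completion",         # or "chat" / "summarization"
)

finetune_model(TextGenerationFinetuningConfig(
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    finetune_args=finetune_args,
))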
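
Likewise, a short sketch of the refined optimization API: build_chatbot now takes an AMPConfig, WeightOnlyQuantizationConfig, or BitsAndBytesConfig directly as PipelineConfig.optimization_config instead of a wrapper OptimizationConfig, mirroring the updated test_optimization.py. When no optimization_config is given, PipelineConfig defaults to AMPConfig (bfloat16).

# Sketch mirroring the updated tests: pass the optimization config object directly.
from intel_extension_for_transformers.neural_chat import build_chatbot
from intel_extension_for_transformers.neural_chat.config import (
    PipelineConfig,
    WeightOnlyQuantizationConfig,
)

# Weight-only quantization on the default model configured by PipelineConfig.
chatbot = build_chatbot(PipelineConfig(
    optimization_config=WeightOnlyQuantizationConfig()
))
print(chatbot.predict(query="Tell me about Intel Xeon Scalable Processors."))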