# Efficient Finetuning of Quantized LLMs --- 低资源的大语言模型量化训练/部署方案
[中文](README_zh.md) | English
This is the repo for the `Efficient Finetuning of Quantized LLMs` project, which aims to build and share tuning methods for instruction-following Chinese `baichuan-7b/LLaMA/Pythia/GLM` models: instruction finetuning can run on **a single Nvidia RTX-2080TI**, and a multi-round chatbot with a context length of 2048 can be trained on **a single Nvidia RTX-3090**.
We use [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for quantization, integrated with Hugging Face's [PEFT](https://github.com/huggingface/peft) and [transformers](https://github.com/huggingface/transformers/) libraries.

## News
- [23/07/20] Now we support training the **LLaMA-2** models in this repo. Pass the `--model_name_or_path Llama-2-7b-hf` argument to use a LLaMA-2 model.
As of now, we support the following datasets, most of which are available in the [Hugging Face datasets library](https://huggingface.co/datasets/).
- For supervised fine-tuning:
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
- [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)
- For reward model training:
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

We provide a number of data preprocessing tools in the [data](./chatllms/data) folder:
- [sft_dataset.py](./chatllms/data/sft_dataset.py) : Supervised fine-tuning dataset class and collator
- [conv_dataset.py](./chatllms/data/conv_dataset.py) : Conversation dataset class and collator
## Model Zoo
We provide a number of models in the [Hugging Face model hub](https://huggingface.co/decapoda-research). These models are trained with QLoRA and can be used for inference and finetuning. The following models are available:
- CUDA >= 11.0
- Python 3.8+ and PyTorch 1.13.1+
- 🤗 Transformers, Datasets, Accelerate, PEFT and bitsandbytes
- jieba, rouge_chinese and nltk (used for evaluation)
- gradio (used in gradio_webserver.py)
### Install required packages
To load models in 4-bit with transformers and bitsandbytes, you have to install accelerate and transformers from source and make sure you have the latest version of the bitsandbytes library (0.39.0). You can achieve the above with the following commands:

```bash
pip install -q -U bitsandbytes
pip install -q -U git+https://github.com/huggingface/transformers.git
pip install -q -U git+https://github.com/huggingface/peft.git
pip install -q -U git+https://github.com/huggingface/accelerate.git
```
## Getting Started
| main function | Usage | Scripts |
| -------------------------------- | ------------------------------------------------------------------------ | ------------------------------------------ |
| [train.py](./train.py) | Full finetuning of LLMs on SFT datasets | [full_finetune](./scripts/full_finetune) |
| [train_lora.py](./train_lora.py) | Finetune LLMs with LoRA (Low-Rank Adaptation of Large Language Models) | [lora_finetune](./scripts/lora_finetune) |
| [train_qlora.py](train_qlora.py) | Finetune LLMs with QLoRA (Efficient Finetuning of Quantized LLMs) | [qlora_finetune](./scripts/qlora_finetune) |
### QLoRA int4 Finetune

```bash
python train_qlora.py --model_name_or_path <path_or_name>
```
For models larger than 13B, we recommend adjusting the learning rate:
```bash
python train_qlora.py --learning_rate 0.0001 --model_name_or_path <path_or_name>
```
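A fuller invocation might look like the sketch below; the paths are placeholders and only flags that appear elsewhere in this README are used:

```bash
python train_qlora.py \
    --model_name_or_path <path_or_name> \
    --dataset path/to/your/dataset \
    --learning_rate 0.0001
```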
To find more scripts for finetuning and inference, please refer to the `scripts` folder.
## Quantization
Quantization parameters are controlled via the `BitsAndBytesConfig` ([see the HF documentation](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig)) as follows; a short example is sketched after this list:
- Loading in 4 bits is activated through `load_in_4bit`
- The datatype used for the linear layer computations is set with `bnb_4bit_compute_dtype`
- Nested quantization is activated through `bnb_4bit_use_double_quant`
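
A minimal sketch of a 4-bit load with these options (the model id below is only a placeholder; `bnb_4bit_quant_type` is the standard transformers parameter for choosing the 4-bit data type):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 weights, nested (double) quantization, bfloat16 compute for the linear layers
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b',  # placeholder model id
    quantization_config=bnb_config,
    device_map='auto',
)
```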
## Tutorials and Demonstrations
We provide two Google Colab notebooks to demonstrate the use of 4-bit models in inference and fine-tuning. These notebooks are intended to be a starting point for further research and development.
- [Basic usage Google Colab notebook](https://colab.research.google.com/drive/1ge2F1QSK8Q7h0hn3YKuBCOAS0bK8E0wf?usp=sharing) - This notebook shows how to use 4bit models in inference with all their variants, and how to run GPT-neo-X (a 20B parameter model) on a free Google Colab instance 🤯
- [Fine tuning Google Colab notebook](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing) - This notebook shows how to fine-tune a 4bit model on a downstream task using the Hugging Face ecosystem. We show that it is possible to fine tune GPT-neo-X 20B on a Google Colab instance!
Other examples are found under the examples/ folder.
- Finetune LLaMA-7B (ex1)
- Finetune GPT-neo-X 20B (ex2)
## Using Local Datasets
You can specify the path to your dataset using the `--dataset` argument. If the `--dataset_format` argument is not set, it will default to the Alpaca format. Here are a few examples; a sample Alpaca-format record is shown after them:
- Training with an Alpaca-format dataset:
```bash
python train_qlora.py --dataset="path/to/your/dataset"
```
- Training with a self-instruct format dataset:
```bash
python train_qlora.py --dataset="path/to/your/dataset" --dataset_format="self-instruct"
```
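
For reference, an Alpaca-format file is a JSON list of records with `instruction`, `input`, and `output` fields (following the Stanford Alpaca release); for example:

```json
[
  {
    "instruction": "Summarize the following paragraph.",
    "input": "Quantization stores model weights in fewer bits to reduce memory use.",
    "output": "Quantization shrinks model memory by using low-bit weights."
  }
]
```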
## Multi GPU
Multi-GPU training and inference work out-of-the-box with Hugging Face's Accelerate. Note that the `per_device_train_batch_size` and `per_device_eval_batch_size` arguments are global batch sizes, unlike what their names suggest.
When loading a model for training or inference on multiple GPUs, you should pass something like the following to `AutoModelForCausalLM.from_pretrained()`:
```python
device_map = "auto"
max_memory = {i: '46000MB' for i in range(torch.cuda.device_count())}
```
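
Putting it together, a sketch of loading a model across GPUs (the checkpoint id is a placeholder, and `load_in_4bit` assumes the bitsandbytes setup described above):

```python
import torch
from transformers import AutoModelForCausalLM

device_map = "auto"
max_memory = {i: '46000MB' for i in range(torch.cuda.device_count())}

model = AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b',  # placeholder checkpoint
    device_map=device_map,
    max_memory=max_memory,
    load_in_4bit=True,
)
```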
To serve a finetuned adapter through a Gradio demo, run `gradio_webserver.py` and point `--lora_model_name_or_path` at your trained LoRA weights:

```bash
python gradio_webserver.py \
    --lora_model_name_or_path `path/to/your/model_dir`
```
## Sample Outputs
We provide generations for the models described in the paper for both OA and Vicuna queries in the `eval/generations` folder. These are intended to foster further research on model evaluation and analysis.
Can you distinguish ChatGPT from Guanaco? Give it a try!
You can access [the model response Colab here](https://colab.research.google.com/drive/1kK6xasHiav9nhiRUJjPMZb4fAED4qRHb?usp=sharing) comparing ChatGPT and Guanaco 65B on Vicuna prompts.
## Known Issues and Limitations
Here is a list of known issues and bugs. If your issue is not reported here, please open a new issue and describe the problem.
1. 4-bit inference is slow. Currently, our 4-bit inference implementation is not yet integrated with the 4-bit matrix multiplication
3. Currently, using `bnb_4bit_compute_dtype='fp16'` can lead to instabilities. For 7B LLaMA, only 80% of finetuning runs complete without error. We have solutions, but they are not integrated yet into bitsandbytes.
4. Make sure that `tokenizer.bos_token_id = 1` to avoid generation issues; a quick check is sketched below.
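
A quick sanity check before generation (the checkpoint id is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')  # placeholder checkpoint
if tokenizer.bos_token_id != 1:
    tokenizer.bos_token_id = 1  # see issue 4 above
```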
## License
`Efficient Finetuning of Quantized LLMs` is released under the Apache 2.0 license.
## Acknowledgements
We thank the Hugging Face team, in particular Younes Belkada, for their support integrating QLoRA with the PEFT and transformers libraries.
We appreciate the work by many open-source contributors, especially:
- [Vicuna](https://github.com/lm-sys/FastChat/)
- [xTuring](https://github.com/stochasticai/xTuring)
## Citation
Please cite the repo if you use the data or code in this repo.
diff --git a/README_zh.md b/README_zh.md
index 00b7078..3b2d883 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -13,22 +13,22 @@
👋🤗🤗👋 加入我们 [WeChat](assets/wechat.jpg).
-
+
# Efficient Finetuning of Quantized LLMs --- 低资源的大语言模型量化训练/部署方案
-
[English](README.md) | 中文
+
这里是`Efficient Finetuning of Quantized LLMs`项目的存储库,旨在构建和开源 遵循指令的`baichuan/LLaMA/Pythia/GLM`中文大模型微调训练方法,该方法可以在**单个 Nvidia RTX-2080TI**上进行训练,多轮聊天机器人可以在**单个 Nvidia RTX-3090**上进行上下文长度 2048的模型训练。
我们使用[bitsandbytes](https://github.com/TimDettmers/bitsandbytes)进行量化,并与Huggingface的[PEFT](https://github.com/huggingface/peft)和 [transformers](https://github.com/huggingface/transformers/)库集成。
- 本项目主要内容如下:
+本项目主要内容如下:
- 📗 支持全量参数指令微调、LoRA指令微调(后续将会提供支持), QLoRA低成本高效指令微调。
- 📗 支持绝大部分主流的开源大模型,如百川 baichuan、Ziya、Bloom、LLaMA、Pythia、OPT等。
@@ -88,6 +88,7 @@ QLora 引入了多种创新,旨在在不牺牲性能的情况下减少内存
截至目前,我们支持以下数据集,这些数据集都可以在 [Hugging Face Datasets](https://huggingface.co/datasets) 上找到。我们将在未来添加更多数据集。
- For supervised fine-tuning:
+
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
@@ -103,16 +104,16 @@ QLora 引入了多种创新,旨在在不牺牲性能的情况下减少内存
- [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)
- For reward model training:
+
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-
请参考 [data/README.md](data/README.md) 了解如何使用这些数据集训练自己的 ChatGPT。如果您想探索更多数据集,请参考 [awesome-instruction-datasets](https://github.com/jianzhnie/awesome-instruction-datasets). 默认情况下,我们使用 [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 数据集进行训练和微调。
-
部分数据集需要 huggingface 的账号认证确认才能使用,我们建议使用以下命令登录您的 Hugging Face 账户。
+
```bash
pip install --upgrade huggingface_hub
huggingface-cli login
@@ -126,7 +127,6 @@ huggingface-cli login
- sft_dataset.py:有监督的对话数据集类
- conv_dataset.py:多轮对话数据集类
-
## 模型仓库
我们在 [Hugging Face ](https://huggingface.co/GaussianTech/)提供了许多模型。这些模型经过Self- Instruct 数据集的训练,可用于推理和微调:
diff --git a/assets/guanaco.svg b/assets/guanaco.svg
index 64e341f..c126704 100644
--- a/assets/guanaco.svg
+++ b/assets/guanaco.svg
diff --git a/chatbot.py b/chatbot.py
deleted file mode 100644
index 31021c0..0000000
--- a/chatbot.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import openai
-import gradio as gr
-
-
-if __name__ == "__main__":
- openai.api_key = "Your API key"
-
- messages = [
- {"role": "system", "content": "You are a helpful and kind AI Assistant."},
- ]
-
- def chatbot(input):
- if input:
- messages.append({"role": "user", "content": input})
- chat = openai.ChatCompletion.create(
- model="gpt-3.5-turbo", messages=messages
- )
- reply = chat.choices[0].message.content
- messages.append({"role": "assistant", "content": reply})
- return reply
-
- inputs = gr.inputs.Textbox(lines=7, label="Chat with AI")
- outputs = gr.outputs.Textbox(label="Reply")
-
- gr.Interface(fn=chatbot, inputs=inputs, outputs=outputs, title="AI Chatbot",
- description="Ask anything you want",
- theme="compact").launch(share=True)
\ No newline at end of file
diff --git a/chatllms/configs/gen_args.py b/chatllms/configs/gen_args.py
index ca5df19..8ed84fd 100644
--- a/chatllms/configs/gen_args.py
+++ b/chatllms/configs/gen_args.py
@@ -4,9 +4,7 @@
@dataclass
class GenerationArguments:
- """
- Arguments pertaining to specify the model generation parameters.
- """
+ """Arguments pertaining to specify the model generation parameters."""
# generation parameters
# 是否使用cache
use_cache: Optional[bool] = field(default=True)
diff --git a/chatllms/data/conv_dataset.py b/chatllms/data/conv_dataset.py
index 4e885f3..5f9f3e9 100644
--- a/chatllms/data/conv_dataset.py
+++ b/chatllms/data/conv_dataset.py
@@ -13,22 +13,22 @@
@dataclass
class VicunaDataset(Dataset):
- """
- Dataset for multi-turn conversations using a Transformer model.
+ """Dataset for multi-turn conversations using a Transformer model.
Attributes:
raw_data: The preprocessed dataset dict to load
tokenizer: Pretrained tokenizer to encode text
max_seq_length: Maximum sequence length for model inputs
"""
+
def __init__(
self,
raw_data: datasets.DatasetDict,
tokenizer: PreTrainedTokenizer,
max_seq_length: int = 1024,
):
- """
- Initialize the dataset with conversations, tokenizer, and max sequence length.
+ """Initialize the dataset with conversations, tokenizer, and max
+ sequence length.
Args:
raw_data: The preprocessed dataset dict to load
@@ -51,8 +51,7 @@ def __init__(
def tokenize_conversation(
self,
conversation: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Tokenize a single conversation into input IDs and labels.
+ """Tokenize a single conversation into input IDs and labels.
Args:
conversation: List of turns in the conversation
@@ -105,8 +104,7 @@ def tokenize_conversation(
return torch.tensor(input_ids), torch.tensor(labels)
def _get_human_prefix(self, turn_id: int, role: str) -> str:
- """
- Get the prefix for a human turn.
+ """Get the prefix for a human turn.
Args:
turn_id: Index of the current turn
@@ -126,8 +124,7 @@ def __len__(self) -> int:
return len(self.raw_data)
def __getitem__(self, index: int) -> Dict:
- """
- Get the input IDs and labels for a specific conversation.
+ """Get the input IDs and labels for a specific conversation.
Args:
index: Index of the conversation
@@ -147,23 +144,22 @@ def __getitem__(self, index: int) -> Dict:
@dataclass
class ConversationDataset(Dataset):
- """
- Dataset for multi-turn conversations using Transformer model.
+ """Dataset for multi-turn conversations using Transformer model.
Attributes:
raw_data: The preprocessed dataset dict to load
tokenizer: Pretrained tokenizer
max_seq_length: Maximum length of sequence
"""
+
def __init__(
self,
raw_data: datasets.DatasetDict,
tokenizer: PreTrainedTokenizer,
max_seq_length: int = 1024,
):
- """
- Initialize the dataset with conversations, tokenizer and max sequence length.
- """
+ """Initialize the dataset with conversations, tokenizer and max
+ sequence length."""
self.raw_data = raw_data
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
@@ -174,8 +170,7 @@ def tokenize_conversation(
self,
conversation: List[Dict],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Tokenize a single conversation into input IDs and labels.
+ """Tokenize a single conversation into input IDs and labels.
Args:
conversation: List of turns in the conversation
@@ -218,8 +213,7 @@ def __len__(self) -> int:
return len(self.raw_data)
def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
- """
- Get the input IDs and labels for a specific conversation.
+ """Get the input IDs and labels for a specific conversation.
Args:
index: Index of the conversation
@@ -248,9 +242,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
@dataclass
class ConversationDataCollator(object):
- """
- Collate and pad a batch of conversation examples to prepare for training.
- """
+ """Collate and pad a batch of conversation examples to prepare for
+ training."""
+
def __init__(
self,
tokenizer: PreTrainedTokenizer,
@@ -300,8 +294,7 @@ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
def make_conversation_data_module(tokenizer: PreTrainedTokenizer,
args) -> Dict[str, Dataset]:
- """
- Create dataset and collator for conversation modeling.
+ """Create dataset and collator for conversation modeling.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer object.
@@ -310,7 +303,6 @@ def make_conversation_data_module(tokenizer: PreTrainedTokenizer,
Returns:
dict: A dictionary containing the train_dataset and eval_dataset.
-
"""
# Determine the appropriate dataset class based on dataset_type flag
dataset_cls = (VicunaDataset if args.conversation_template == 'vicuna' else
diff --git a/chatllms/data/data_utils.py b/chatllms/data/data_utils.py
index 01e96a2..7c03cfa 100644
--- a/chatllms/data/data_utils.py
+++ b/chatllms/data/data_utils.py
@@ -87,8 +87,7 @@ def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
- """
- Extracts input from an example in the Alpaca dataset.
+ """Extracts input from an example in the Alpaca dataset.
Args:
example: A dictionary containing a single example from the Alpaca dataset.
@@ -100,7 +99,6 @@ def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
>>> example = {'input': 'example input', 'output': 'example output'}
>>> extract_alpaca_dataset(example)
{'input': 'example input'}
-
"""
if example.get('input', '') != '':
prompt_format = ALPACA_PROMPT_DICT['prompt_input']
@@ -110,8 +108,8 @@ def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
def extract_vicuna_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
- """
- Extracts the input and output portions of a single conversation example from the Vicuña format.
+ """Extracts the input and output portions of a single conversation example
+ from the Vicuña format.
Args:
example (Dict[str, Any]): A single conversation example in the Vicuña format.
@@ -184,8 +182,8 @@ def extract_random_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
def local_dataset(dataset_path: str,
eval_dataset_size: float = 0.1) -> Tuple[Dataset, Dataset]:
- """
- Reads in a dataset from a file and returns it as a split train-test dataset.
+ """Reads in a dataset from a file and returns it as a split train-test
+ dataset.
Args:
dataset_path (str): The name of the dataset file to read in. \
@@ -195,7 +193,6 @@ def local_dataset(dataset_path: str,
A tuple containing two datasets - the training subset and the testing subset.
Raises:
ValueError: If the specified file format is unsupported.
-
"""
# Read in the full dataset from file based on the file format
@@ -221,8 +218,7 @@ def local_dataset(dataset_path: str,
def load_data(
dataset_path: str,
eval_dataset_size: float = 0.1) -> Union[Dict[str, Dataset], None]:
- """
- Load a dataset based on its name.
+ """Load a dataset based on its name.
Args:
dataset_path: A string representing the path to the dataset to be loaded.
@@ -238,7 +234,6 @@ def load_data(
Examples:
>>> load_data('alpaca')
{'train': Dataset(...), 'validation': Dataset(...), 'test': Dataset(...)}
-
"""
if not os.path.exists(dataset_path):
# Download dataset from HuggingFace Datasets
@@ -263,9 +258,7 @@ def formate_instruction_dataset(
dataset_name: str,
dataset_format: str,
instruction_template: str = 'default') -> Optional[Dict[str, Dataset]]:
- """
- Formats a given dataset based on its name and format.
-
+ """Formats a given dataset based on its name and format.
Removes unused columns, renames columns to 'input' and 'output',
and applies dataset-specific formatting based on the dataset_name.
@@ -283,6 +276,7 @@ def formate_instruction_dataset(
specified format.
None if the dataset does not exist or if the format is not recognized.
"""
+
def _format_dolly15k(dataset: Dataset) -> Dataset:
"""Format Dolly-15k dataset."""
dataset = dataset.rename_column('context', 'input')
@@ -376,8 +370,8 @@ def split_train_eval(
do_train: bool = True,
max_train_samples: int = None,
) -> Dict[str, Dataset]:
- """
- Prepare the training and evaluation datasets for a machine learning model.
+ """Prepare the training and evaluation datasets for a machine learning
+ model.
Args:
dataset (DatasetDict): The complete dataset containing train, validation, and test splits.
@@ -435,9 +429,8 @@ def split_train_eval(
def make_data_module(args):
- """
- Make dataset and collator for supervised fine-tuning.
- Datasets are expected to have the following columns: { `input`, `output` }
+ """Make dataset and collator for supervised fine-tuning. Datasets are
+ expected to have the following columns: { `input`, `output` }
Available datasets to be selected with `dataset` argument:
- alpaca, 52002 examples
@@ -456,7 +449,6 @@ def make_data_module(args):
- supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
- flan (FLAN v2), up to 20M examples available
- vicuna
-
"""
train_datasets: List[Dataset] = []
eval_datasets: List[Dataset] = []
diff --git a/chatllms/data/sft_dataset.py b/chatllms/data/sft_dataset.py
index 71383dd..a4c0177 100644
--- a/chatllms/data/sft_dataset.py
+++ b/chatllms/data/sft_dataset.py
@@ -16,8 +16,7 @@
class SFTInstructionDataset(Dataset):
- """
- Dataset for supervised fine-tuning of instruction following models.
+ """Dataset for supervised fine-tuning of instruction following models.
Converts raw dataset containing source/target instructions
into tokenized input/target pairs with truncation and padding.
@@ -26,14 +25,13 @@ class SFTInstructionDataset(Dataset):
dataset: The raw dataset containing source/target examples
tokenizer: Tokenizer to use for encoding text
max_seq_len: Maximum sequence length for truncation
-
"""
+
def __init__(self,
raw_data: DatasetDict,
tokenizer: PreTrainedTokenizer,
max_seq_len: int = 1024):
- """
- Initialize the dataset with the raw data and tokenizer.
+ """Initialize the dataset with the raw data and tokenizer.
Args:
raw_data: Raw dataset containing source/target examples
@@ -45,12 +43,11 @@ def __init__(self,
self.max_seq_len = max_seq_len
def __len__(self) -> int:
- """Return number of examples in dataset"""
+ """Return number of examples in dataset."""
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
- """
- Convert an raw example into tokenized input/target pair.
+ """Convert an raw example into tokenized input/target pair.
Args:
idx: Index of the example in the dataset
@@ -107,14 +104,15 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.
- Args:
- hf_dataset (dataset): The preprocesed dataset to load.
- tokenizer (PreTrainedTokenizer): The tokenizer to use when tokenizing the data.
- source_max_len (int): The maximum length allowed for the source text.
- target_max_len (int): The maximum length allowed for the target text.
- train_on_source (bool): If True, the model will be trained on the source text as well as the target text.
- predict_with_generate (bool): If True, the model will generate predictions instead of training.
+ Args:
+ hf_dataset (dataset): The preprocesed dataset to load.
+ tokenizer (PreTrainedTokenizer): The tokenizer to use when tokenizing the data.
+ source_max_len (int): The maximum length allowed for the source text.
+ target_max_len (int): The maximum length allowed for the target text.
+ train_on_source (bool): If True, the model will be trained on the source text as well as the target text.
+ predict_with_generate (bool): If True, the model will generate predictions instead of training.
"""
+
def __init__(
self,
hf_dataset: datasets.DatasetDict,
@@ -185,9 +183,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
@dataclass
class DataCollatorForSupervisedDataset:
- """
- Collate and pad examples for supervised training.
- """
+ """Collate and pad examples for supervised training."""
tokenizer: PreTrainedTokenizer
predict_with_generate: bool = False
@@ -196,8 +192,7 @@ def __call__(
self,
examples: List[Dict[str,
torch.Tensor]]) -> Dict[str, torch.Tensor]:
- """
- Collate examples into dictionary for supervised training.
+ """Collate examples into dictionary for supervised training.
Args:
examples: List of examples, each containing 'input_ids' and 'labels'
diff --git a/chatllms/data/utils/conversation.py b/chatllms/data/utils/conversation.py
index 1ef229a..bcbf4ea 100644
--- a/chatllms/data/utils/conversation.py
+++ b/chatllms/data/utils/conversation.py
@@ -1,6 +1,4 @@
-"""
-Conversation prompt templates.
-"""
+"""Conversation prompt templates."""
import dataclasses
from enum import Enum, auto
@@ -25,7 +23,8 @@ class SeparatorStyle(Enum):
@dataclasses.dataclass
class Conversation:
- """A class that manages prompt templates and keeps all conversation history."""
+ """A class that manages prompt templates and keeps all conversation
+ history."""
# The name of this template
name: str
@@ -163,8 +162,9 @@ def append_message(self, role: str, message: str):
def update_last_message(self, message: str):
"""Update the last output.
- The last message is typically set to be None when constructing the prompt,
- so we need to update it in-place after getting the response from a model.
+ The last message is typically set to be None when constructing the
+ prompt, so we need to update it in-place after getting the response
+ from a model.
"""
self.messages[-1][1] = message
diff --git a/chatllms/data/utils/convert_alpaca.py b/chatllms/data/utils/convert_alpaca.py
index 8af5520..4513072 100644
--- a/chatllms/data/utils/convert_alpaca.py
+++ b/chatllms/data/utils/convert_alpaca.py
@@ -1,5 +1,4 @@
-"""
-Convert alpaca dataset into sharegpt format.
+"""Convert alpaca dataset into sharegpt format.
Usage: python3 -m chatllms.data.convert_alpaca --in alpaca_data.json
"""
diff --git a/chatllms/data/utils/convert_oasst1.py b/chatllms/data/utils/convert_oasst1.py
index a6f06ae..35a538d 100644
--- a/chatllms/data/utils/convert_oasst1.py
+++ b/chatllms/data/utils/convert_oasst1.py
@@ -167,9 +167,9 @@ def test_add_open_assistant(fixup_personality,
only_personality,
deberta_grading,
save_json=True):
- """
- Flatten tree structure into one row per path from root to leaf
- Also turn into human_bot prompting format:
+ """Flatten tree structure into one row per path from root to leaf Also turn
+ into human_bot prompting format:
+
: question\n: answer : question2\n: answer2 Etc.
Also saves a .json locally as side-effect
returns list of dicts, containing intput, prompt_type and source
diff --git a/chatllms/data/vicuna_dataset.py b/chatllms/data/vicuna_dataset.py
index 392d6fe..f5fcbf1 100644
--- a/chatllms/data/vicuna_dataset.py
+++ b/chatllms/data/vicuna_dataset.py
@@ -49,8 +49,7 @@ def tokenize_conversations(
conversations: List[str],
tokenizer: PreTrainedTokenizer,
) -> torch.Tensor:
- """Tokenize conversations
- """
+ """Tokenize conversations."""
input_ids = tokenizer(
conversations,
@@ -70,7 +69,9 @@ def mask_targets(
tokenizer: PreTrainedTokenizer,
conv: Conversation,
) -> None:
- """Mask targets. Only compute loss on the assistant outputs.
+ """Mask targets.
+
+ Only compute loss on the assistant outputs.
"""
# Mask targets
@@ -108,8 +109,7 @@ def mask_targets(
def preprocess(sources: Sequence[Dict[str, str]],
tokenizer: PreTrainedTokenizer) -> Dict[str, List[int]]:
- """
- Preprocesses the data by tokenizing it.
+ """Preprocesses the data by tokenizing it.
Args:
sources (Sequence[Dict[str, str]]): List of conversation sources.
@@ -131,13 +131,13 @@ def preprocess(sources: Sequence[Dict[str, str]],
class SupervisedDataset(Dataset):
- """
- Dataset for supervised fine-tuning.
+ """Dataset for supervised fine-tuning.
Args:
raw_data (List[Dict]): Raw input data.
tokenizer (PreTrainedTokenizer): Tokenizer for preprocessing the data.
"""
+
def __init__(self, raw_data: List[Dict[str, List[str]]],
tokenizer: PreTrainedTokenizer) -> None:
super().__init__()
@@ -160,8 +160,7 @@ def __len__(self) -> int:
return len(self.input_ids)
def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
- """
- Get an example from the dataset at the specified index.
+ """Get an example from the dataset at the specified index.
Args:
index (int): Index of the example to retrieve.
@@ -177,16 +176,14 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
class LazySupervisedDataset(Dataset):
- """
- Dataset for supervised fine-tuning.
- """
+ """Dataset for supervised fine-tuning."""
+
def __init__(
self,
raw_data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
):
- """
- Initialize the LazySupervisedDataset.
+ """Initialize the LazySupervisedDataset.
Args:
raw_data (List[Dict[str, str]]): The raw input data for the dataset.
@@ -198,8 +195,7 @@ def __init__(
self.cached_data_dict: Dict[int, Dict[str, torch.Tensor]] = {}
def __len__(self) -> int:
- """
- Get the length of the dataset.
+ """Get the length of the dataset.
Returns:
int: The length of the dataset.
@@ -207,8 +203,7 @@ def __len__(self) -> int:
return len(self.raw_data)
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
- """
- Get an item from the dataset at the given index.
+ """Get an item from the dataset at the given index.
Args:
i (int): The index of the item to retrieve.
@@ -240,8 +235,7 @@ def make_conversation_data_module(
lazy_preprocess: bool,
data_path: str,
) -> Dict[str, Dataset]:
- """
- Make dataset and collator for supervised fine-tuning.
+ """Make dataset and collator for supervised fine-tuning.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer object.
@@ -250,7 +244,6 @@ def make_conversation_data_module(
Returns:
dict: A dictionary containing the train_dataset and eval_dataset.
-
"""
# Determine the appropriate dataset class based on lazy_preprocess flag
diff --git a/chatllms/evaluation/evaluate_zh.py b/chatllms/evaluation/evaluate_zh.py
index f23d611..ca2c26e 100644
--- a/chatllms/evaluation/evaluate_zh.py
+++ b/chatllms/evaluation/evaluate_zh.py
@@ -95,8 +95,7 @@ def __init__(
data_path: str = 'ceval/ceval-exam',
output_dir: str = 'ceval_output',
) -> None:
- """
- Initialize the CEval object.
+ """Initialize the CEval object.
Args:
model (PreTrainedModel): Pre-trained model for question answering.
@@ -112,8 +111,7 @@ def __init__(
self.output_dir = output_dir
def run(self, shot: int, split: str) -> None:
- """
- Run the evaluation for all tasks.
+ """Run the evaluation for all tasks.
Args:
shot (int): Number of additional examples to include in the prompt.
@@ -144,8 +142,7 @@ def run(self, shot: int, split: str) -> None:
def run_single_task(self, task_name: str, shot: int,
split: str) -> Tuple[List[Dict[str, str]], float]:
- """
- Run the evaluation for a single task.
+ """Run the evaluation for a single task.
Args:
task_name (str): Name of the task.
@@ -198,8 +195,7 @@ def run_single_task(self, task_name: str, shot: int,
def build_example(self,
data: Dict[str, str],
with_answer: bool = True) -> str:
- """
- Builds an example string based on the given data.
+ """Builds an example string based on the given data.
Args:
data (Dict[str, str]): A dictionary containing the question, choices, and answer.
diff --git a/chatllms/model/compute_metrics.py b/chatllms/model/compute_metrics.py
index 6488c39..2b311d0 100644
--- a/chatllms/model/compute_metrics.py
+++ b/chatllms/model/compute_metrics.py
@@ -10,14 +10,14 @@
@dataclass
class ComputeMetrics:
- """
- Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
- Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
+ """Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+ Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
"""
+
def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
- """
- Initialize the ComputeMetrics class with a pre-trained tokenizer object.
+ """Initialize the ComputeMetrics class with a pre-trained tokenizer
+ object.
Args:
tokenizer (PreTrainedTokenizer): A pre-trained tokenizer object to be used for decoding tokenized sequences.
@@ -27,8 +27,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
def __call__(
self, eval_preds: List[Union[np.ndarray, Tuple[np.ndarray]]]
) -> Dict[str, float]:
- """
- Computes evaluation metrics for model predictions.
+ """Computes evaluation metrics for model predictions.
Args:
eval_preds (List[Union[np.ndarray, Tuple[np.ndarray]]]): List of tuples containing prediction and label arrays.
diff --git a/chatllms/model/llm_perplexity.py b/chatllms/model/llm_perplexity.py
index 2da2fe7..30a1c79 100644
--- a/chatllms/model/llm_perplexity.py
+++ b/chatllms/model/llm_perplexity.py
@@ -18,8 +18,7 @@
class LLMPerplexity:
- """
- Language model to compute perplexity.
+ """Language model to compute perplexity.
Args:
cache_dir (str): Directory to cache models.
@@ -31,6 +30,7 @@ class LLMPerplexity:
fp16 (bool): Whether to use 16-bit precision.
device (str): Device to load model to.
"""
+
def __init__(
self,
cache_dir: str = None,
@@ -93,8 +93,7 @@ def __init__(
def get_perplexity(self,
input_texts: Union[str, List[str]],
batch_size: int = None) -> Union[float, List[float]]:
- """
- Compute perplexity on input text(s).
+ """Compute perplexity on input text(s).
Args:
input_texts (Union[str, List[str]]): Input text(s) to compute perplexity for.
diff --git a/chatllms/model/load_pretrain_model.py b/chatllms/model/load_pretrain_model.py
index b917bad..a8e5c38 100644
--- a/chatllms/model/load_pretrain_model.py
+++ b/chatllms/model/load_pretrain_model.py
@@ -27,9 +27,8 @@ def load_model_tokenizer(
is_trainable: Optional[bool] = True,
logger=None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
- """
- Returns a language model and tokenizer for text generation that can be trained with mixed precision.
- Support both training and inference.
+ """Returns a language model and tokenizer for text generation that can be
+ trained with mixed precision. Support both training and inference.
Args:
args: A dictionary containing various hyperparameters.
diff --git a/chatllms/model/mmlueval_callback.py b/chatllms/model/mmlueval_callback.py
index 839b5b1..2e6e9c2 100644
--- a/chatllms/model/mmlueval_callback.py
+++ b/chatllms/model/mmlueval_callback.py
@@ -17,10 +17,9 @@
@dataclass
class MMLUEvalCallback(TrainerCallback):
- """
- A callback function called after each evaluation step during training to evaluate \
- the performance of a model on an
- MMLU (Mean Length of Utterance) dataset.
+ """A callback function called after each evaluation step during training to
+ evaluate the performance of a model on an MMLU (Mean Length of Utterance)
+ dataset.
Args:
trainer (Trainer): The trainer instance to be used.
@@ -28,6 +27,7 @@ class MMLUEvalCallback(TrainerCallback):
data_dir (str): The directory where the MMLU dataset is stored.
args (argparse.Namespace): The command line arguments for the current run.
"""
+
def __init__(
self,
trainer: 'Trainer',
@@ -98,8 +98,8 @@ def on_evaluate(
model: PreTrainedModel,
**kwargs: Any,
) -> None:
- """
- Iterate over the batches of the evaluation dataset and make predictions for MMLU.
+ """Iterate over the batches of the evaluation dataset and make
+ predictions for MMLU.
Args:
args (Dict[str, Any]): Dictionary containing the evaluation arguments.
diff --git a/chatllms/model/sample_generate_callback.py b/chatllms/model/sample_generate_callback.py
index 6c453fa..10a2788 100644
--- a/chatllms/model/sample_generate_callback.py
+++ b/chatllms/model/sample_generate_callback.py
@@ -7,13 +7,14 @@
@dataclass
class SampleGenerateCallback(TrainerCallback):
- """
- A callback that generates text samples from a pre-trained language model during training.
+ """A callback that generates text samples from a pre-trained language model
+ during training.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer used to preprocess inputs.
max_new_tokens (int): The maximum number of tokens to generate in response to each input.
"""
+
def __init__(self, tokenizer: PreTrainedTokenizer,
generation_config: argparse.Namespace, logger: None):
self.tokenizer = tokenizer
@@ -49,8 +50,7 @@ def __init__(self, tokenizer: PreTrainedTokenizer,
def on_evaluate(self, args: Any, state: Dict[str, Any], control: Any,
**kwargs: Any) -> None:
- """
- Generates text samples from the language model during evaluation.
+ """Generates text samples from the language model during evaluation.
Args:
args (Any): Trainer arguments, not used in this method.
diff --git a/chatllms/model/save_peft_model_callback.py b/chatllms/model/save_peft_model_callback.py
index 2c93e8c..fb37f12 100644
--- a/chatllms/model/save_peft_model_callback.py
+++ b/chatllms/model/save_peft_model_callback.py
@@ -7,16 +7,15 @@
class SavePeftModelCallback(TrainerCallback):
- """
- Callback to save PEFT model checkpoints during training.
+ """Callback to save PEFT model checkpoints during training.
Saves both the full model and the adapter model to separate directories
within the checkpoint directory.
"""
+
def save_model(self, args: Any, state: TrainingArguments,
kwargs: Dict[str, Any]) -> None:
- """
- Saves the PEFT model checkpoint.
+ """Saves the PEFT model checkpoint.
Args:
args (Any): The command line arguments passed to the script.
@@ -52,8 +51,8 @@ def save_model(self, args: Any, state: TrainingArguments,
def on_save(self, args: Any, state: TrainingArguments,
control: TrainerControl,
**kwargs: Dict[str, Any]) -> TrainerControl:
- """
- Callback method that calls save_model() and returns `control` argument.
+ """Callback method that calls save_model() and returns `control`
+ argument.
Args:
args (Any): The command line arguments passed to the script.
@@ -74,8 +73,8 @@ def on_save(self, args: Any, state: TrainingArguments,
def on_train_end(self, args: Any, state: TrainingArguments,
control: TrainerControl, **kwargs: Dict[str,
Any]) -> None:
- """
- Callback method that saves the model checkpoint and creates a 'completed' file in the output directory.
+ """Callback method that saves the model checkpoint and creates a
+ 'completed' file in the output directory.
Args:
args (Any): The command line arguments passed to the script.
diff --git a/chatllms/train/training.py b/chatllms/train/training.py
index 0a2a2c8..c473969 100644
--- a/chatllms/train/training.py
+++ b/chatllms/train/training.py
@@ -11,8 +11,7 @@
def train_and_evaluate(trainer: transformers.Trainer, args: argparse.Namespace,
logger: None) -> None:
- """
- Trains and evaluates a machine learning model.
+ """Trains and evaluates a machine learning model.
Args:
trainer (Trainer): The training object to use for training and evaluation.
@@ -75,10 +74,8 @@ def predict_and_save(trainer: transformers.Trainer,
tokenizer: transformers.PreTrainedTokenizer,
predict_dataset: Dataset, args: argparse.Namespace,
logger: None) -> None:
- """
- Make predictions on new data, save them to a file along with input examples,
- and update the overall metrics.
- """
+ """Make predictions on new data, save them to a file along with input
+ examples, and update the overall metrics."""
logger.info('=' * 80)
logger.info('*** Predict ***')
logger.info('=' * 80)
diff --git a/chatllms/utils/apply_lora.py b/chatllms/utils/apply_lora.py
index b46639e..0291f4e 100644
--- a/chatllms/utils/apply_lora.py
+++ b/chatllms/utils/apply_lora.py
@@ -1,5 +1,4 @@
-"""
-Apply the LoRA weights on top of a base model.
+"""Apply the LoRA weights on top of a base model.
Usage:
python3 apply_lora.py --base_model_path ~/model_weights/llama-7b --target_model_path ~/model_weights/baize-7b \
@@ -24,7 +23,8 @@ def apply_lora(
use_auth_token: str = True,
trust_remote_code: bool = True,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
- """Applies the LoRA adapter to a base model and saves the resulting target model (optional).
+ """Applies the LoRA adapter to a base model and saves the resulting target
+ model (optional).
Args:
base_model_path (str): The path to the base model to which the LoRA adapter will be applied.
@@ -36,7 +36,6 @@ def apply_lora(
Returns:
Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the target model and its tokenizer.
-
"""
# Load the base model and tokenizer
print(f'Loading the base model from {base_model_path}')
diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py
index a053f5d..af4b092 100644
--- a/chatllms/utils/model_utils.py
+++ b/chatllms/utils/model_utils.py
@@ -9,19 +9,18 @@
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from transformers.trainer_utils import get_last_checkpoint
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
+
from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)
def add_special_tokens_if_missing(tokenizer: PreTrainedTokenizer,
model: PreTrainedModel) -> None:
- """
- If 'llama' or 'baichuan' is in the model name or path, check if the special tokens are set correctly.
- Add any missing special tokens to prevent them from being parsed into different tokens.
- Note that these special tokens are present in the vocabulary.
- Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
+ """If 'llama' or 'baichuan' is in the model name or path, check if the
+ special tokens are set correctly. Add any missing special tokens to prevent
+ them from being parsed into different tokens. Note that these special
+ tokens are present in the vocabulary. Note also that
+ `model.config.pad_token_id` is 0 which corresponds to `` token.
Args:
tokenizer: The pre-trained tokenizer.
@@ -54,8 +53,7 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict[str, str],
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel) -> None:
"""Resize tokenizer and embedding to accommodate new special tokens.
- 改变tokenizer和embedding的尺寸。
- 一般需要将tokenizer和embedding的尺寸设置为64的倍数,方便GPU加速。
+ 改变tokenizer和embedding的尺寸。 一般需要将tokenizer和embedding的尺寸设置为64的倍数,方便GPU加速。
Args:
special_tokens_dict (Dict[str, str]): A dictionary of special tokens to be added to the tokenizer.
@@ -99,11 +97,9 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict[str, str],
def find_all_linear_names(args: argparse.Namespace,
model: torch.nn.Module) -> List[str]:
- """
- Returns a list of names of all linear layers present in the given model.
+ """Returns a list of names of all linear layers present in the given model.
如果args.bits是4,使用bitsandbytes库中的bnb.nn.Linear4bit层;
- 如果args.bits是8,使用bitsandbytes库中的bnb.nn.Linear8bitLt层;
- 否则,使用torch.nn.Linear层;
+ 如果args.bits是8,使用bitsandbytes库中的bnb.nn.Linear8bitLt层; 否则,使用torch.nn.Linear层;
并记录下这些层的名称,保存在lora_module_names集合中。
Args:
@@ -154,8 +150,7 @@ def find_all_linear_names(args: argparse.Namespace,
def print_trainable_parameters(args: argparse.Namespace,
model: torch.nn.Module) -> None:
- """
- Prints the number of trainable parameters in the given model.
+ """Prints the number of trainable parameters in the given model.
Args:
args (argparse.Namespace): A namespace containing arguments of the script. Must contain the 'bits' argument.
@@ -198,8 +193,7 @@ def print_trainable_parameters(args: argparse.Namespace,
def verify_dtypes(model: torch.nn.Module) -> None:
- """
- 检查模型参数的数据类型,并输出各个数据类型在这些张量中所占的比例.
+ """检查模型参数的数据类型,并输出各个数据类型在这些张量中所占的比例.
:param model: 待检查的模型.
:return: 无返回值.
@@ -225,9 +219,9 @@ def verify_dtypes(model: torch.nn.Module) -> None:
def check_training_finished(args: argparse.Namespace,
logger=None) -> Tuple[str, bool]:
- """
- Given a directory containing previous saved checkpoints, returns the path to the last checkpoint
- if available along with a boolean flag indicating whether training has already been completed.
+ """Given a directory containing previous saved checkpoints, returns the
+ path to the last checkpoint if available along with a boolean flag
+ indicating whether training has already been completed.
Args:
checkpoint_dir (str): Path to the directory containing the saved checkpoints.
@@ -276,24 +270,10 @@ def find_last_checkpoint(checkpoint_dir):
last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
return last_checkpoint
-# Avoid runtime error in model.generate(do_sample=True).
-class InvalidScoreLogitsProcessor(LogitsProcessor):
- def __call__(self, input_ids: torch.LongTensor,
- scores: torch.FloatTensor) -> torch.FloatTensor:
- if torch.isnan(scores).any() or torch.isinf(scores).any():
- scores.zero_()
- scores[..., 0] = 1.0
- return scores
-
-
-def get_logits_processor() -> LogitsProcessorList:
- logits_processor = LogitsProcessorList()
- logits_processor.append(InvalidScoreLogitsProcessor())
- return logits_processor
-
# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
+
def __call__(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
diff --git a/chatllms/utils/stream_server.py b/chatllms/utils/stream_server.py
index 66bf42f..93288b3 100644
--- a/chatllms/utils/stream_server.py
+++ b/chatllms/utils/stream_server.py
@@ -1,5 +1,5 @@
-"""
-Helpers to support streaming generate output.
+"""Helpers to support streaming generate output.
+
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
"""
import traceback
@@ -10,6 +10,7 @@
class Stream(transformers.StoppingCriteria):
+
def __init__(self, callback_func=None):
self.callback_func = callback_func
@@ -20,10 +21,9 @@ def __call__(self, input_ids, scores) -> bool:
class Iteratorize:
- """
- Transforms a function that takes a callback
- into a lazy iterator (generator).
- """
+ """Transforms a function that takes a callback into a lazy iterator
+ (generator)."""
+
def __init__(self, func, kwargs={}, callback=None):
self.mfunc = func
self.c_callback = callback
diff --git a/chatllms/utils/template.py b/chatllms/utils/template.py
index a96b8fa..e865331 100644
--- a/chatllms/utils/template.py
+++ b/chatllms/utils/template.py
@@ -7,8 +7,7 @@
@dataclass
class PromptTemplate(object):
- """
- A template for formatting a conversation prompt.
+ """A template for formatting a conversation prompt.
Args:
name: Name of template
@@ -16,7 +15,6 @@ class PromptTemplate(object):
prompt: Prompt text
sep: Separator between prompts
use_history: Whether to use conversation history
-
"""
name: str
@@ -31,8 +29,7 @@ def get_prompt(
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None,
) -> str:
- """
- Returns a string containing prompt without response.
+ """Returns a string containing prompt without response.
Args:
query (str): The input query text.
@@ -67,8 +64,7 @@ def format_example(self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
prefix: Optional[str] = None) -> List[str]:
- """
- Formats the conversation example.
+ """Formats the conversation example.
Args:
query (str): The input query text.
@@ -100,8 +96,7 @@ def register_template(
sep: str,
use_history: Optional[bool] = True,
) -> None:
- """
- Registers a new conversation template.
+ """Registers a new conversation template.
Args:
prefix (str): The prefix text for the prompt.
@@ -116,13 +111,9 @@ def register_template(
self.use_history = use_history
def __post_init__(self):
- """
- Initializes the instance of the class.
- """
+ """Initializes the instance of the class."""
if self.name == 'default':
- """
- Supports language model inference without histories.
- """
+ """Supports language model inference without histories."""
self.register_template(name='vanilla',
prefix='',
prompt='{query}',
diff --git a/cli_demo.py b/cli_demo.py
index 959be76..85e34f0 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -22,8 +22,8 @@ def generate_response(
model: PreTrainedModel,
generation_args: dict,
) -> List[str]:
- """
- Generates a response to the given query using GPT-3.5 model and prints it to the console.
+ """Generates a response to the given query using GPT-3.5 model and prints
+ it to the console.
Args:
query (str): The input query for which a response is to be generated.
diff --git a/data/README.md b/data/README.md
index 5612380..eff7c4f 100644
--- a/data/README.md
+++ b/data/README.md
@@ -1,4 +1,3 @@
-
# How to use the data
## Datasets Supported by the Framework
@@ -44,7 +43,6 @@ We provide the following datasets for the experiments in this framework.
数据集说明:开源了数据规模为145k的价值对齐数据集,该数据集对于每个prompt包括了拒绝&正向建议,(safe and reponsibility) > 拒绝为主(safe) > 风险回复(unsafe)三种类型,可用于增强SFT模型的安全性或用于训练reward模型。
- [CValues-Comparison中文大模型价值观比较数据集](https://modelscope.cn/datasets/damo/CValues-Comparison/summary)
-
## Dataset formation
The `dataset_info.yaml` file contains all the datasets can be used in the experiments. The following is the format of the datasets, main including the following fields.
diff --git a/data/alpaca_zh_pcyn.yaml b/data/alpaca_zh_pcyn.yaml
deleted file mode 100644
index 84b6a01..0000000
--- a/data/alpaca_zh_pcyn.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# The dataset_info.yaml file contains the information of the datasets used in the experiments.
-coig:
- hf_hub_url: BAAI/COIG
- local_path: /userhome/jianzhnie/prompt_data/COIG/train_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-cvalues_comparison_train:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/train_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-cvalues_comparison_test:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/test_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-olcc:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/olcc/olcc_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-100PoisonMpts:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/100PoisonMpts/train_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-safety_prompt_part1:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json
- dataset_format: alpaca
- multi_turn: False
-
-safety_prompt_part2:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
- dataset_format: alpaca
- multi_turn: False
diff --git a/data/dataset_info.py b/data/dataset_info.py
index 7f38e6c..62c96f5 100644
--- a/data/dataset_info.py
+++ b/data/dataset_info.py
@@ -2,9 +2,9 @@
def get_dataset_info(dataset_dir):
- """
- Returns the datasets info to a dataset based on a pre-defined map of dataset names to their corresponding URLs on the internet
- or local file paths.
+ """Returns the datasets info to a dataset based on a pre-defined map of
+ dataset names to their corresponding URLs on the internet or local file
+ paths.
Args:
dataset_dir (str): The local directory where the dataset is stored; this is used for datasets that are stored locally.
diff --git a/data/vicuna_zh_pcyn.yaml b/data/vicuna_zh_pcyn.yaml
deleted file mode 100644
index 610d8e1..0000000
--- a/data/vicuna_zh_pcyn.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# The dataset_info.yaml file contains the information of the datasets used in the experiments.
-coig:
- hf_hub_url: BAAI/COIG
- local_path: /userhome/jianzhnie/prompt_data/COIG/train_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-cvalues_comparison_train:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/train_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-cvalues_comparison_test:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/test_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-olcc:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/olcc/olcc_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-100PoisonMpts:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/100PoisonMpts/train_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-safety_prompt_part1:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/attack_scenarios_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
-
-safety_prompt_part2:
- hf_hub_url: ''
- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json
- dataset_format: sharegpt
- multi_turn: True
diff --git a/examples/clean_sharegpt/clean_sharegpt.py b/examples/clean_sharegpt/clean_sharegpt.py
index 017fecb..ed4fab7 100644
--- a/examples/clean_sharegpt/clean_sharegpt.py
+++ b/examples/clean_sharegpt/clean_sharegpt.py
@@ -99,8 +99,8 @@ def format_roles(
def filter_invalid_roles(
raw_data: List[Dict[str,
any]]) -> List[Dict[str, List[Dict[str, any]]]]:
- """
- Filter out invalid contents based on the roles assigned to each conversation.
+ """Filter out invalid contents based on the roles assigned to each
+ conversation.
Args:
raw_data: A list of dictionaries containing conversation data.
diff --git a/examples/clean_sharegpt/hardcoded_questions.py b/examples/clean_sharegpt/hardcoded_questions.py
index cb89c63..abafaf2 100644
--- a/examples/clean_sharegpt/hardcoded_questions.py
+++ b/examples/clean_sharegpt/hardcoded_questions.py
@@ -2,9 +2,8 @@
def identity_questions():
- """ "
- Adopted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py
- """
+ """" Adopted from https://github.com/young-
+ geng/koala_data_pipeline/blob/main/process_hard_coded_data.py."""
content = []
name = 'Vicuna'
diff --git a/examples/clean_sharegpt/merge.py b/examples/clean_sharegpt/merge.py
index bdf2b40..727b9a8 100644
--- a/examples/clean_sharegpt/merge.py
+++ b/examples/clean_sharegpt/merge.py
@@ -1,5 +1,4 @@
-"""
-Merge two conversation files into one
+"""Merge two conversation files into one.
Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
"""
diff --git a/examples/clean_sharegpt/split_long_conversation.py b/examples/clean_sharegpt/split_long_conversation.py
index e14b952..1e551c6 100644
--- a/examples/clean_sharegpt/split_long_conversation.py
+++ b/examples/clean_sharegpt/split_long_conversation.py
@@ -1,5 +1,4 @@
-"""
-Split long conversations based on certain max length.
+"""Split long conversations based on certain max length.
Usage: python3 -m split_long_conversation.py \
--in sharegpt_clean.json \
@@ -23,8 +22,8 @@
def make_sample(sample: Dict[str, any], start_idx: int,
end_idx: int) -> Dict[str, any]:
- """
- Create a new sample dictionary by selecting conversations from the given sample.
+ """Create a new sample dictionary by selecting conversations from the given
+ sample.
Args:
sample (Dict[str, any]): The original sample dictionary.
@@ -43,8 +42,8 @@ def make_sample(sample: Dict[str, any], start_idx: int,
def split_one_sample(sample: Dict[str, any]) -> List[Dict[str, any]]:
- """
- Split a single sample into multiple samples based on conversation lengths.
+ """Split a single sample into multiple samples based on conversation
+ lengths.
Args:
sample (Dict[str, any]): The original sample dictionary.
@@ -94,8 +93,8 @@ def worker(input_data: List[Dict[str, Any]]):
def split_all(raw_data: List[Dict[str, Any]],
tokenizer_: transformers.PreTrainedTokenizer,
max_length_: int) -> List[Dict[str, Any]]:
- """
- Split the content into smaller parts based on the max token length constraint.
+ """Split the content into smaller parts based on the max token length
+ constraint.
Args:
raw_data (List[Dict[str, Any]]): The list of samples to split.
diff --git a/examples/finetune_llm/baichuan7b_demo.py b/examples/finetune_llm/baichuan7b_demo.py
index 73169e0..cfdca1b 100644
--- a/examples/finetune_llm/baichuan7b_demo.py
+++ b/examples/finetune_llm/baichuan7b_demo.py
@@ -19,5 +19,5 @@ def main(load_in_8bit=True, model_path=''):
if __name__ == '__main__':
load_in_8bit = True
- model_path = '/home/robin/work_dir/llm/llm_pretrain_model/baichuan'
+ model_path = 'baichuan'
main(load_in_8bit, model_path)
diff --git a/examples/finetune_llm/finetune_llama_with_qlora.py b/examples/finetune_llm/finetune_llama_with_qlora.py
index d9dd3f2..a7bc1ee 100644
--- a/examples/finetune_llm/finetune_llama_with_qlora.py
+++ b/examples/finetune_llm/finetune_llama_with_qlora.py
@@ -15,9 +15,7 @@
def print_trainable_parameters(model: AutoModelForCausalLM) -> None:
- """
- Prints the number of trainable parameters in the model.
- """
+ """Prints the number of trainable parameters in the model."""
trainable_params, all_param = 0, 0
for _, param in model.named_parameters():
all_param += param.numel()
diff --git a/examples/finetune_llm/qlora_int4_finetune.py b/examples/finetune_llm/qlora_int4_finetune.py
index 123029c..3e7d46a 100644
--- a/examples/finetune_llm/qlora_int4_finetune.py
+++ b/examples/finetune_llm/qlora_int4_finetune.py
@@ -284,6 +284,7 @@ def find_all_linear_names(args, model):
class SavePeftModelCallback(transformers.TrainerCallback):
+
def save_model(self, args, state, kwargs):
logger.info('Saving PEFT checkpoint...')
if state.best_model_checkpoint is not None:
@@ -307,6 +308,7 @@ def on_save(self, args, state, control, **kwargs):
return control
def on_train_end(self, args, state, control, **kwargs):
+
def touch(fname, times=None):
with open(fname, 'a'):
os.utime(fname, times)
@@ -408,9 +410,7 @@ def get_accelerate_model(args, checkpoint_dir):
def print_trainable_parameters(args, model):
- """
- Prints the number of trainable parameters in the model.
- """
+ """Prints the number of trainable parameters in the model."""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
@@ -576,9 +576,8 @@ def local_dataset(dataset_name):
def make_data_module(tokenizer: transformers.PreTrainedTokenizer,
args) -> Dict:
- """
- Make dataset and collator for supervised fine-tuning.
- Datasets are expected to have the following columns: { `input`, `output` }
+ """Make dataset and collator for supervised fine-tuning. Datasets are
+ expected to have the following columns: { `input`, `output` }
Available datasets to be selected with `dataset` argument:
- alpaca, 52002 examples
@@ -597,8 +596,8 @@ def make_data_module(tokenizer: transformers.PreTrainedTokenizer,
- supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
- flan (FLAN v2), up to 20M examples available
- vicuna
-
"""
+
def load_data(dataset_name):
if dataset_name == 'alpaca':
return load_dataset('tatsu-lab/alpaca')
@@ -832,6 +831,7 @@ def train():
accuracy = evaluate.load('accuracy')
class MMLUEvalCallback(transformers.TrainerCallback):
+
def on_evaluate(self, args, state, control, model, **kwargs):
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
source_max_len = trainer.data_collator.source_max_len
diff --git a/examples/format_data/convert_oasst1.py b/examples/format_data/convert_oasst1.py
index 31a29e8..ac965dc 100644
--- a/examples/format_data/convert_oasst1.py
+++ b/examples/format_data/convert_oasst1.py
@@ -16,12 +16,14 @@ def json_load(in_file):
def convert_oasst1_data(data_dir, output_dir):
- '''
- For OASST1, because it's in a tree structure, where every user input might get multiple replies,
- we have to save every path from the root node to the assistant reply (including both leaf node and intemediate node).
- This results in some of the messages being duplicated among different paths (instances).
- Be careful when using this dataset for training. Ideally, you should only minimize the loss of the last message in each path.
- '''
+ """For OASST1, because it's in a tree structure, where every user input
+ might get multiple replies, we have to save every path from the root node
+ to the assistant reply (including both leaf node and intermediate node).
+
+ This results in some of the messages being duplicated among different paths
+ (instances). Be careful when using this dataset for training. Ideally, you
+ should only minimize the loss of the last message in each path.
+ """
conversations = []
with open(os.path.join(data_dir, '2023-04-12_oasst_ready.trees.jsonl'),
'r') as fin:
diff --git a/examples/format_data/merge.py b/examples/format_data/merge.py
index 8da9846..84324e7 100644
--- a/examples/format_data/merge.py
+++ b/examples/format_data/merge.py
@@ -1,5 +1,4 @@
-"""
-Merge two conversation files into one
+"""Merge two conversation files into one.
Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
"""
diff --git a/examples/vllm/apil_chient.py b/examples/vllm/apil_chient.py
deleted file mode 100644
index d3ba848..0000000
--- a/examples/vllm/apil_chient.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Example Python client for vllm.entrypoints.api_server"""
-
-import argparse
-import json
-from typing import Iterable, List
-
-import requests
-
-
-def clear_line(n: int = 1) -> None:
- LINE_UP = '\033[1A'
- LINE_CLEAR = '\x1b[2K'
- for _ in range(n):
- print(LINE_UP, end=LINE_CLEAR, flush=True)
-
-
-def post_http_request(prompt: str,
- api_url: str,
- n: int = 1,
- stream: bool = False) -> requests.Response:
- headers = {'User-Agent': 'Test Client'}
- pload = {
- 'prompt': prompt,
- 'n': n,
- 'use_beam_search': True,
- 'temperature': 0.0,
- 'max_tokens': 16,
- 'stream': stream,
- }
- response = requests.post(api_url, headers=headers, json=pload, stream=True)
- return response
-
-
-def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
- for chunk in response.iter_lines(chunk_size=8192,
- decode_unicode=False,
- delimiter=b'\0'):
- if chunk:
- data = json.loads(chunk.decode('utf-8'))
- output = data['text']
- yield output
-
-
-def get_response(response: requests.Response) -> List[str]:
- data = json.loads(response.content)
- output = data['text']
- return output
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--host', type=str, default='localhost')
- parser.add_argument('--port', type=int, default=8000)
- parser.add_argument('--n', type=int, default=4)
- parser.add_argument('--prompt', type=str, default='San Francisco is a')
- parser.add_argument('--stream', action='store_true')
- args = parser.parse_args()
- prompt = args.prompt
- api_url = f'http://{args.host}:{args.port}/generate'
- n = args.n
- stream = args.stream
-
- print(f'Prompt: {prompt!r}\n', flush=True)
- response = post_http_request(prompt, api_url, n, stream)
-
- if stream:
- num_printed_lines = 0
- for h in get_streaming_response(response):
- clear_line(num_printed_lines)
- num_printed_lines = 0
- for i, line in enumerate(h):
- num_printed_lines += 1
- print(f'Beam candidate {i}: {line!r}', flush=True)
- else:
- output = get_response(response)
- for i, line in enumerate(output):
- print(f'Beam candidate {i}: {line!r}', flush=True)
diff --git a/examples/vllm/vllm_demo.py b/examples/vllm/vllm_demo.py
deleted file mode 100644
index f23be9b..0000000
--- a/examples/vllm/vllm_demo.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from vllm import LLM, SamplingParams
-
-prompts = [
- 'Hello, my name is',
- 'The president of the United States is',
- 'The capital of France is',
- 'The future of AI is',
-]
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-llm = LLM(model='decapoda-research/llama-7b-hf', gpu_memory_utilization=0.9)
-
-# Print the outputs.
-for i in range(10):
- outputs = llm.generate(prompts, sampling_params)
- for output in outputs:
- prompt = output.prompt
- generated_text = output.outputs[0].text
- print(f'Prompt: {prompt!r}, Generated text: {generated_text!r}')
diff --git a/py b/py
new file mode 100644
index 0000000..c114e12
--- /dev/null
+++ b/py
@@ -0,0 +1,2249 @@
+diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
+index f12c8b8..37b2203 100644
+--- a/.pre-commit-config.yaml
++++ b/.pre-commit-config.yaml
+@@ -1,18 +1,18 @@
+ repos:
+- - repo: https://github.com/PyCQA/flake8
+- rev: 3.8.3
++ - repo: https://gitee.com/openmmlab/mirrors-flake8
++ rev: 5.0.4
+ hooks:
+ - id: flake8
+- - repo: https://github.com/PyCQA/isort
+- rev: 5.10.1
++ - repo: https://gitee.com/openmmlab/mirrors-isort
++ rev: 5.11.5
+ hooks:
+ - id: isort
+- - repo: https://github.com/pre-commit/mirrors-yapf
+- rev: v0.30.0
++ - repo: https://gitee.com/openmmlab/mirrors-yapf
++ rev: v0.32.0
+ hooks:
+ - id: yapf
+- - repo: https://github.com/pre-commit/pre-commit-hooks
+- rev: v4.1.0
++ - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
++ rev: v4.3.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-yaml
+@@ -23,4 +23,4 @@ repos:
+ - id: fix-encoding-pragma
+ args: ["--remove"]
+ - id: mixed-line-ending
+- args: ["--fix=lf"]
++ args: ["--fix=lf"]
+\ No newline at end of file
+diff --git a/README.md b/README.md
+index 65d5e07..735a27e 100644
+--- a/README.md
++++ b/README.md
+@@ -10,25 +10,24 @@
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/)
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
+-
+
+
+ This is the repo for the `Efficient Finetuning of Quantized LLMs` project, which aims to build and share instruction-following Chinese `baichuan-7b/LLaMA/Pythia/GLM` model tuning methods which can be trained on **a single Nvidia RTX-2080TI**, multi-round chatbot which can be trained on **a single Nvidia RTX-3090** with the context len 2048.
+
+ We uses [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for quantization and is integrated with Huggingface's [PEFT](https://github.com/huggingface/peft) and [transformers](https://github.com/huggingface/transformers/) libraries.
+
+-
+ ## News
+
+ - [23/07/20] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path Llama-2-7b-hf` argument to use the LLaMA-2 model.
+@@ -67,6 +66,7 @@ We uses [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for quantiza
+ As of now, we support the following datasets, most of which are all available in the [Hugging Face datasets library](https://huggingface.co/datasets/).
+
+ - For supervised fine-tuning:
++
+ - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
+ - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+ - [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
+@@ -88,6 +88,7 @@ As of now, we support the following datasets, most of which are all available in
+ - [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)
+
+ - For reward model training:
++
+ - [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+ - [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
+ - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+@@ -109,7 +110,6 @@ We provide a number of data preprocessing tools in the [data](./chatllms/data) f
+ - [sft_dataset.py](./chatllms/data/sft_dataset.py) : Supervised fine-tuning dataset class and collator
+ - [conv_dataset.py](./chatllms/data/conv_dataset.py) : Conversation dataset class and collator
+
+-
+ ## Model Zoo
+
+ We provide a number of models in the [Hugging Face model hub](https://huggingface.co/decapoda-research). These models are trained with QLoRA and can be used for inference and finetuning. We provide the following models:
+@@ -129,13 +129,17 @@ We provide a number of models in the [Hugging Face model hub](https://huggingfac
+ - CUDA >= 11.0
+
+ - Python 3.8+ and PyTorch 1.13.1+
++
+ - 🤗Transformers, Datasets, Accelerate, PEFT and bitsandbytes
++
+ - jieba, rouge_chinese and nltk (used at evaluation)
++
+ - gradio (used in gradio_webserver.py)
+
+ ### Install required packages
+
+ To load models in 4bits with transformers and bitsandbytes, you have to install accelerate and transformers from source and make sure you have the latest version of the bitsandbytes library (0.39.0). You can achieve the above with the following commands:
++
+ ```bash
+ pip install -q -U bitsandbytes
+ pip install -q -U git+https://github.com/huggingface/transformers.git
+@@ -154,11 +158,11 @@ cd Efficient-Tuning-LLMs
+
+ ## Getting Started
+
+-| main function | Useage | Scripts |
+-| ---------------------------------------- | ------------------------------------------------------------------------------------ | ------------------------------------------ |
+-| [train.py](./train.py) | Full finetune LLMs on SFT datasets | [full_finetune](./scripts/full_finetune) |
+-| [train_lora.py](./train_lora.py) | Finetune LLMs by using Lora (Low-Rank Adaptation of Large Language Models finetune) | [lora_finetune](./scripts/lora_finetune) |
+-| [train_qlora.py](train_qlora.py) | Finetune LLMs by using QLora (QLoRA: Efficient Finetuning of Quantized LLMs) | [qlora_finetune](./scripts/qlora_finetune) |
++| main function                    | Usage                                                                                  | Scripts                                     |
++| -------------------------------- | ------------------------------------------------------------------------------------ | ------------------------------------------ |
++| [train.py](./train.py) | Full finetune LLMs on SFT datasets | [full_finetune](./scripts/full_finetune) |
++| [train_lora.py](./train_lora.py) | Finetune LLMs by using Lora (Low-Rank Adaptation of Large Language Models finetune) | [lora_finetune](./scripts/lora_finetune) |
++| [train_qlora.py](train_qlora.py) | Finetune LLMs by using QLora (QLoRA: Efficient Finetuning of Quantized LLMs) | [qlora_finetune](./scripts/qlora_finetune) |
+
+ ### QLora int4 Finetune
+
+@@ -170,6 +174,7 @@ python train_qlora.py --model_name_or_path
+ ```
+
+ For models larger than 13B, we recommend adjusting the learning rate:
++
+ ```bash
+ python train_qlora.py --learning_rate 0.0001 --model_name_or_path
+ ```
+@@ -220,7 +225,9 @@ python train_qlora.py \
+ To find more scripts for finetuning and inference, please refer to the `scripts` folder.
+
+ ## Quantization
++
+ Quantization parameters are controlled from the `BitsandbytesConfig` ([see HF documenation](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig)) as follows:
++
+ - Loading in 4 bits is activated through `load_in_4bit`
+ - The datatype used for the linear layer computations with `bnb_4bit_compute_dtype`
+ - Nested quantization is activated through `bnb_4bit_use_double_quant`
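+
+ As a concrete sketch of how these options fit together (illustrative only; the model id below is a placeholder, not a repo default):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ # 4-bit loading with NF4 quantization, bf16 compute and nested (double) quantization.
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type='nf4',
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     'huggyllama/llama-7b',  # placeholder model id
+     quantization_config=bnb_config,
+     device_map='auto',
+ )
+ ```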
+@@ -245,20 +252,25 @@ Quantization parameters are controlled from the `BitsandbytesConfig` ([see HF do
+ ## Tutorials and Demonstrations
+
+ We provide two Google Colab notebooks to demonstrate the use of 4bit models in inference and fine-tuning. These notebooks are intended to be a starting point for further research and development.
++
+ - [Basic usage Google Colab notebook](https://colab.research.google.com/drive/1ge2F1QSK8Q7h0hn3YKuBCOAS0bK8E0wf?usp=sharing) - This notebook shows how to use 4bit models in inference with all their variants, and how to run GPT-neo-X (a 20B parameter model) on a free Google Colab instance 🤯
+ - [Fine tuning Google Colab notebook](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing) - This notebook shows how to fine-tune a 4bit model on a downstream task using the Hugging Face ecosystem. We show that it is possible to fine tune GPT-neo-X 20B on a Google Colab instance!
+
+ Other examples are found under the examples/ folder.
++
+ - Finetune LLama-7B (ex1)
+ - Finetune GPT-neo-X 20B (ex2)
+
+ ## Using Local Datasets
++
+ You can specify the path to your dataset using the --dataset argument. If the --dataset_format argument is not set, it will default to the Alpaca format. Here are a few examples:
+
+ - Training with an alpaca format dataset:
++
+ ```bash
+ python train_qlora.py --dataset="path/to/your/dataset"
+ ```
++
+ - Training with a self-instruct format dataset:
+
+ ```bash
+@@ -266,9 +278,11 @@ python train_qlora.py --dataset="path/to/your/dataset" --dataset_format="self-in
+ ```
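+
+ For reference, here is a sketch of what a single record can look like in each format (the field names follow the common Alpaca and self-instruct conventions and are assumptions; check them against your own files):
+
+ ```python
+ # Alpaca-style record: instruction + optional input + output.
+ alpaca_record = {
+     'instruction': 'Summarize the following paragraph.',
+     'input': 'Large language models are ...',
+     'output': 'The paragraph explains ...',
+ }
+
+ # Self-instruct-style record: prompt/completion pairs.
+ self_instruct_record = {
+     'prompt': 'Summarize the following paragraph.\n\nLarge language models are ...',
+     'completion': 'The paragraph explains ...',
+ }
+ ```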
+
+ ## Multi GPU
++
+ Multi GPU training and inference work out-of-the-box with Hugging Face's Accelerate. Note that the per_device_train_batch_size and per_device_eval_batch_size arguments are global batch sizes, unlike what their names suggest.
+
+ When loading a model for training or inference on multiple GPUs you should pass something like the following to AutoModelForCausalLM.from_pretrained():
++
+ ```python
+ device_map = "auto"
+ max_memory = {i: '46000MB' for i in range(torch.cuda.device_count())}
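+
+ # Illustrative continuation (not necessarily the repo's exact call): pass both
+ # mappings to from_pretrained so the weights are sharded across the GPUs.
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+     'huggyllama/llama-7b',  # placeholder model id
+     device_map=device_map,
+     max_memory=max_memory,
+ )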
+@@ -303,15 +317,15 @@ python gradio_webserver.py \
+ --lora_model_name_or_path `path/to/your/model_dir`
+ ```
+
+-
+ ## Sample Outputs
++
+ We provide generations for the models described in the paper for both OA and Vicuna queries in the `eval/generations` folder. These are intended to foster further research on model evaluation and analysis.
+
+ Can you distinguish ChatGPT from Guanaco? Give it a try!
+ You can access [the model response Colab here](https://colab.research.google.com/drive/1kK6xasHiav9nhiRUJjPMZb4fAED4qRHb?usp=sharing) comparing ChatGPT and Guanaco 65B on Vicuna prompts.
+
+-
+ ## Known Issues and Limitations
++
+ Here is a list of known issues and bugs. If your issue is not reported here, please open a new issue and describe the problem.
+
+ 1. 4-bit inference is slow. Currently, our 4-bit inference implementation is not yet integrated with the 4-bit matrix multiplication
+@@ -319,13 +333,12 @@ Here a list of known issues and bugs. If your issue is not reported here, please
+ 3. Currently, using `bnb_4bit_compute_type='fp16'` can lead to instabilities. For 7B LLaMA, only 80% of finetuning runs complete without error. We have solutions, but they are not integrated yet into bitsandbytes.
+ 4. Make sure that `tokenizer.bos_token_id = 1` to avoid generation issues.
+
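+ A quick way to enforce this before generation (a minimal sketch; assumes a LLaMA-family tokenizer whose BOS id should be 1):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')  # placeholder model id
+ if tokenizer.bos_token_id != 1:
+     tokenizer.bos_token_id = 1
+ ```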
+-
+ ## License
+
+ `Efficient Finetuning of Quantized LLMs` is released under the Apache 2.0 license.
+
+-
+ ## Acknowledgements
++
+ We thank the Huggingface team, in particular Younes Belkada, for their support integrating QLoRA with PEFT and transformers libraries.
+
+ We appreciate the work by many open-source contributors, especially:
+@@ -338,7 +351,6 @@ We appreciate the work by many open-source contributors, especially:
+ - [Vicuna](https://github.com/lm-sys/FastChat/)
+ - [xTuring](https://github.com/stochasticai/xTuring)
+
+-
+ ## Citation
+
+ Please cite the repo if you use the data or code in this repo.
+diff --git a/README_zh.md b/README_zh.md
+index 00b7078..3b2d883 100644
+--- a/README_zh.md
++++ b/README_zh.md
+@@ -13,22 +13,22 @@
+
+
+ 这里是`Efficient Finetuning of Quantized LLMs`项目的存储库,旨在构建和开源 遵循指令的`baichuan/LLaMA/Pythia/GLM`中文大模型微调训练方法,该方法可以在**单个 Nvidia RTX-2080TI**上进行训练,多轮聊天机器人可以在**单个 Nvidia RTX-3090**上进行上下文长度 2048的模型训练。
+
+ 我们使用[bitsandbytes](https://github.com/TimDettmers/bitsandbytes)进行量化,并与Huggingface的[PEFT](https://github.com/huggingface/peft)和 [transformers](https://github.com/huggingface/transformers/)库集成。
+
+- 本项目主要内容如下:
++本项目主要内容如下:
+
+ - 📗 支持全量参数指令微调、LoRA指令微调(后续将会提供支持), QLoRA低成本高效指令微调。
+ - 📗 支持绝大部分主流的开源大模型,如百川 baichuan、Ziya、Bloom、LLaMA、Pythia、OPT等。
+@@ -88,6 +88,7 @@ QLora 引入了多种创新,旨在在不牺牲性能的情况下减少内存
+ 截至目前,我们支持以下数据集,这些数据集都可以在 [Hugging Face Datasets](https://huggingface.co/datasets) 上找到。我们将在未来添加更多数据集。
+
+ - For supervised fine-tuning:
++
+ - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
+ - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+ - [Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)
+@@ -103,16 +104,16 @@ QLora 引入了多种创新,旨在在不牺牲性能的情况下减少内存
+ - [Evol-Instruct](https://huggingface.co/datasets/victor123/evol_instruct_70k)
+
+ - For reward model training:
++
+ - [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+ - [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
+ - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+ - [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+-
+ 请参考 [data/README.md](data/README.md) 了解如何使用这些数据集训练自己的 ChatGPT。如果您想探索更多数据集,请参考 [awesome-instruction-datasets](https://github.com/jianzhnie/awesome-instruction-datasets). 默认情况下,我们使用 [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 数据集进行训练和微调。
+
+-
+ 部分数据集需要 huggingface 的账号认证确认才能使用,我们建议使用以下命令登录您的 Hugging Face 账户。
++
+ ```bash
+ pip install --upgrade huggingface_hub
+ huggingface-cli login
+@@ -126,7 +127,6 @@ huggingface-cli login
+ - sft_dataset.py:有监督的对话数据集类
+ - conv_dataset.py:多轮对话数据集类
+
+-
+ ## 模型仓库
+
+ 我们在 [Hugging Face ](https://huggingface.co/GaussianTech/)提供了许多模型。这些模型经过Self- Instruct 数据集的训练,可用于推理和微调:
+diff --git a/assets/guanaco.svg b/assets/guanaco.svg
+index 64e341f..c126704 100644
+--- a/assets/guanaco.svg
++++ b/assets/guanaco.svg
+diff --git a/chatbot.py b/chatbot.py
+index 31021c0..fa44857 100644
+--- a/chatbot.py
++++ b/chatbot.py
+@@ -1,27 +1,31 @@
+-import openai
+ import gradio as gr
++import openai
+
+-
+-if __name__ == "__main__":
+- openai.api_key = "Your API key"
++if __name__ == '__main__':
++ openai.api_key = 'Your API key'
+
+ messages = [
+- {"role": "system", "content": "You are a helpful and kind AI Assistant."},
++ {
++ 'role': 'system',
++ 'content': 'You are a helpful and kind AI Assistant.'
++ },
+ ]
+
+ def chatbot(input):
+ if input:
+- messages.append({"role": "user", "content": input})
+- chat = openai.ChatCompletion.create(
+- model="gpt-3.5-turbo", messages=messages
+- )
++ messages.append({'role': 'user', 'content': input})
++ chat = openai.ChatCompletion.create(model='gpt-3.5-turbo',
++ messages=messages)
+ reply = chat.choices[0].message.content
+- messages.append({"role": "assistant", "content": reply})
++ messages.append({'role': 'assistant', 'content': reply})
+ return reply
+
+- inputs = gr.inputs.Textbox(lines=7, label="Chat with AI")
+- outputs = gr.outputs.Textbox(label="Reply")
++ inputs = gr.inputs.Textbox(lines=7, label='Chat with AI')
++ outputs = gr.outputs.Textbox(label='Reply')
+
+- gr.Interface(fn=chatbot, inputs=inputs, outputs=outputs, title="AI Chatbot",
+- description="Ask anything you want",
+- theme="compact").launch(share=True)
+\ No newline at end of file
++ gr.Interface(fn=chatbot,
++ inputs=inputs,
++ outputs=outputs,
++ title='AI Chatbot',
++ description='Ask anything you want',
++ theme='compact').launch(share=True)
+diff --git a/chatllms/configs/gen_args.py b/chatllms/configs/gen_args.py
+index ca5df19..8ed84fd 100644
+--- a/chatllms/configs/gen_args.py
++++ b/chatllms/configs/gen_args.py
+@@ -4,9 +4,7 @@ from typing import Any, Dict, Optional
+
+ @dataclass
+ class GenerationArguments:
+- """
+- Arguments pertaining to specify the model generation parameters.
+- """
++ """Arguments pertaining to specify the model generation parameters."""
+ # generation parameters
+ # 是否使用cache
+ use_cache: Optional[bool] = field(default=True)
+diff --git a/chatllms/data/conv_dataset.py b/chatllms/data/conv_dataset.py
+index 4e885f3..5f9f3e9 100644
+--- a/chatllms/data/conv_dataset.py
++++ b/chatllms/data/conv_dataset.py
+@@ -13,22 +13,22 @@ from chatllms.data.sft_dataset import DataCollatorForSupervisedDataset
+
+ @dataclass
+ class VicunaDataset(Dataset):
+- """
+- Dataset for multi-turn conversations using a Transformer model.
++ """Dataset for multi-turn conversations using a Transformer model.
+
+ Attributes:
+ raw_data: The preprocessed dataset dict to load
+ tokenizer: Pretrained tokenizer to encode text
+ max_seq_length: Maximum sequence length for model inputs
+ """
++
+ def __init__(
+ self,
+ raw_data: datasets.DatasetDict,
+ tokenizer: PreTrainedTokenizer,
+ max_seq_length: int = 1024,
+ ):
+- """
+- Initialize the dataset with conversations, tokenizer, and max sequence length.
++ """Initialize the dataset with conversations, tokenizer, and max
++ sequence length.
+
+ Args:
+ raw_data: The preprocessed dataset dict to load
+@@ -51,8 +51,7 @@ class VicunaDataset(Dataset):
+ def tokenize_conversation(
+ self,
+ conversation: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor]:
+- """
+- Tokenize a single conversation into input IDs and labels.
++ """Tokenize a single conversation into input IDs and labels.
+
+ Args:
+ conversation: List of turns in the conversation
+@@ -105,8 +104,7 @@ class VicunaDataset(Dataset):
+ return torch.tensor(input_ids), torch.tensor(labels)
+
+ def _get_human_prefix(self, turn_id: int, role: str) -> str:
+- """
+- Get the prefix for a human turn.
++ """Get the prefix for a human turn.
+
+ Args:
+ turn_id: Index of the current turn
+@@ -126,8 +124,7 @@ class VicunaDataset(Dataset):
+ return len(self.raw_data)
+
+ def __getitem__(self, index: int) -> Dict:
+- """
+- Get the input IDs and labels for a specific conversation.
++ """Get the input IDs and labels for a specific conversation.
+
+ Args:
+ index: Index of the conversation
+@@ -147,23 +144,22 @@ class VicunaDataset(Dataset):
+
+ @dataclass
+ class ConversationDataset(Dataset):
+- """
+- Dataset for multi-turn conversations using Transformer model.
++ """Dataset for multi-turn conversations using Transformer model.
+
+ Attributes:
+ raw_data: The preprocessed dataset dict to load
+ tokenizer: Pretrained tokenizer
+ max_seq_length: Maximum length of sequence
+ """
++
+ def __init__(
+ self,
+ raw_data: datasets.DatasetDict,
+ tokenizer: PreTrainedTokenizer,
+ max_seq_length: int = 1024,
+ ):
+- """
+- Initialize the dataset with conversations, tokenizer and max sequence length.
+- """
++ """Initialize the dataset with conversations, tokenizer and max
++ sequence length."""
+ self.raw_data = raw_data
+ self.tokenizer = tokenizer
+ self.max_seq_length = max_seq_length
+@@ -174,8 +170,7 @@ class ConversationDataset(Dataset):
+ self,
+ conversation: List[Dict],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+- """
+- Tokenize a single conversation into input IDs and labels.
++ """Tokenize a single conversation into input IDs and labels.
+
+ Args:
+ conversation: List of turns in the conversation
+@@ -218,8 +213,7 @@ class ConversationDataset(Dataset):
+ return len(self.raw_data)
+
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+- """
+- Get the input IDs and labels for a specific conversation.
++ """Get the input IDs and labels for a specific conversation.
+
+ Args:
+ index: Index of the conversation
+@@ -248,9 +242,9 @@ class ConversationDataset(Dataset):
+
+ @dataclass
+ class ConversationDataCollator(object):
+- """
+- Collate and pad a batch of conversation examples to prepare for training.
+- """
++ """Collate and pad a batch of conversation examples to prepare for
++ training."""
++
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+@@ -300,8 +294,7 @@ class ConversationDataCollator(object):
+
+ def make_conversation_data_module(tokenizer: PreTrainedTokenizer,
+ args) -> Dict[str, Dataset]:
+- """
+- Create dataset and collator for conversation modeling.
++ """Create dataset and collator for conversation modeling.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): The tokenizer object.
+@@ -310,7 +303,6 @@ def make_conversation_data_module(tokenizer: PreTrainedTokenizer,
+
+ Returns:
+ dict: A dictionary containing the train_dataset and eval_dataset.
+-
+ """
+ # Determine the appropriate dataset class based on dataset_type flag
+ dataset_cls = (VicunaDataset if args.conversation_template == 'vicuna' else
+diff --git a/chatllms/data/data_utils.py b/chatllms/data/data_utils.py
+index 01e96a2..7c03cfa 100644
+--- a/chatllms/data/data_utils.py
++++ b/chatllms/data/data_utils.py
+@@ -87,8 +87,7 @@ def extract_default_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+
+
+ def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+- """
+- Extracts input from an example in the Alpaca dataset.
++ """Extracts input from an example in the Alpaca dataset.
+
+ Args:
+ example: A dictionary containing a single example from the Alpaca dataset.
+@@ -100,7 +99,6 @@ def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+ >>> example = {'input': 'example input', 'output': 'example output'}
+ >>> extract_alpaca_dataset(example)
+ {'input': 'example input'}
+-
+ """
+ if example.get('input', '') != '':
+ prompt_format = ALPACA_PROMPT_DICT['prompt_input']
+@@ -110,8 +108,8 @@ def extract_alpaca_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+
+
+ def extract_vicuna_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+- """
+- Extracts the input and output portions of a single conversation example from the Vicuña format.
++ """Extracts the input and output portions of a single conversation example
++ from the Vicuña format.
+
+ Args:
+ example (Dict[str, Any]): A single conversation example in the Vicuña format.
+@@ -184,8 +182,8 @@ def extract_random_prompt_dataset(example: Dict[str, Any]) -> Dict[str, str]:
+
+ def local_dataset(dataset_path: str,
+ eval_dataset_size: float = 0.1) -> Tuple[Dataset, Dataset]:
+- """
+- Reads in a dataset from a file and returns it as a split train-test dataset.
++ """Reads in a dataset from a file and returns it as a split train-test
++ dataset.
+
+ Args:
+ dataset_path (str): The name of the dataset file to read in. \
+@@ -195,7 +193,6 @@ def local_dataset(dataset_path: str,
+ A tuple containing two datasets - the training subset and the testing subset.
+ Raises:
+ ValueError: If the specified file format is unsupported.
+-
+ """
+
+ # Read in the full dataset from file based on the file format
+@@ -221,8 +218,7 @@ def local_dataset(dataset_path: str,
+ def load_data(
+ dataset_path: str,
+ eval_dataset_size: float = 0.1) -> Union[Dict[str, Dataset], None]:
+- """
+- Load a dataset based on its name.
++ """Load a dataset based on its name.
+
+ Args:
+ dataset_path: A string representing the path to the dataset to be loaded.
+@@ -238,7 +234,6 @@ def load_data(
+ Examples:
+ >>> load_data('alpaca')
+ {'train': Dataset(...), 'validation': Dataset(...), 'test': Dataset(...)}
+-
+ """
+ if not os.path.exists(dataset_path):
+ # Download dataset from HuggingFace Datasets
+@@ -263,9 +258,7 @@ def formate_instruction_dataset(
+ dataset_name: str,
+ dataset_format: str,
+ instruction_template: str = 'default') -> Optional[Dict[str, Dataset]]:
+- """
+- Formats a given dataset based on its name and format.
+-
++ """Formats a given dataset based on its name and format.
+
+ Removes unused columns, renames columns to 'input' and 'output',
+ and applies dataset-specific formatting based on the dataset_name.
+@@ -283,6 +276,7 @@ def formate_instruction_dataset(
+ specified format.
+ None if the dataset does not exist or if the format is not recognized.
+ """
++
+ def _format_dolly15k(dataset: Dataset) -> Dataset:
+ """Format Dolly-15k dataset."""
+ dataset = dataset.rename_column('context', 'input')
+@@ -376,8 +370,8 @@ def split_train_eval(
+ do_train: bool = True,
+ max_train_samples: int = None,
+ ) -> Dict[str, Dataset]:
+- """
+- Prepare the training and evaluation datasets for a machine learning model.
++ """Prepare the training and evaluation datasets for a machine learning
++ model.
+
+ Args:
+ dataset (DatasetDict): The complete dataset containing train, validation, and test splits.
+@@ -435,9 +429,8 @@ def split_train_eval(
+
+
+ def make_data_module(args):
+- """
+- Make dataset and collator for supervised fine-tuning.
+- Datasets are expected to have the following columns: { `input`, `output` }
++ """Make dataset and collator for supervised fine-tuning. Datasets are
++ expected to have the following columns: { `input`, `output` }
+
+ Available datasets to be selected with `dataset` argument:
+ - alpaca, 52002 examples
+@@ -456,7 +449,6 @@ def make_data_module(args):
+ - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
+ - flan (FLAN v2), up to 20M examples available
+ - vicuna
+-
+ """
+ train_datasets: List[Dataset] = []
+ eval_datasets: List[Dataset] = []
+diff --git a/chatllms/data/sft_dataset.py b/chatllms/data/sft_dataset.py
+index 71383dd..a4c0177 100644
+--- a/chatllms/data/sft_dataset.py
++++ b/chatllms/data/sft_dataset.py
+@@ -16,8 +16,7 @@ logger = logging.getLogger(__name__)
+
+
+ class SFTInstructionDataset(Dataset):
+- """
+- Dataset for supervised fine-tuning of instruction following models.
++ """Dataset for supervised fine-tuning of instruction following models.
+
+ Converts raw dataset containing source/target instructions
+ into tokenized input/target pairs with truncation and padding.
+@@ -26,14 +25,13 @@ class SFTInstructionDataset(Dataset):
+ dataset: The raw dataset containing source/target examples
+ tokenizer: Tokenizer to use for encoding text
+ max_seq_len: Maximum sequence length for truncation
+-
+ """
++
+ def __init__(self,
+ raw_data: DatasetDict,
+ tokenizer: PreTrainedTokenizer,
+ max_seq_len: int = 1024):
+- """
+- Initialize the dataset with the raw data and tokenizer.
++ """Initialize the dataset with the raw data and tokenizer.
+
+ Args:
+ raw_data: Raw dataset containing source/target examples
+@@ -45,12 +43,11 @@ class SFTInstructionDataset(Dataset):
+ self.max_seq_len = max_seq_len
+
+ def __len__(self) -> int:
+- """Return number of examples in dataset"""
++ """Return number of examples in dataset."""
+ return len(self.dataset)
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+- """
+- Convert an raw example into tokenized input/target pair.
++ """Convert an raw example into tokenized input/target pair.
+
+ Args:
+ idx: Index of the example in the dataset
+@@ -107,14 +104,15 @@ class SFTInstructionDataset(Dataset):
+ class SupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+
+- Args:
+- hf_dataset (dataset): The preprocesed dataset to load.
+- tokenizer (PreTrainedTokenizer): The tokenizer to use when tokenizing the data.
+- source_max_len (int): The maximum length allowed for the source text.
+- target_max_len (int): The maximum length allowed for the target text.
+- train_on_source (bool): If True, the model will be trained on the source text as well as the target text.
+- predict_with_generate (bool): If True, the model will generate predictions instead of training.
++ Args:
++ hf_dataset (dataset): The preprocessed dataset to load.
++ tokenizer (PreTrainedTokenizer): The tokenizer to use when tokenizing the data.
++ source_max_len (int): The maximum length allowed for the source text.
++ target_max_len (int): The maximum length allowed for the target text.
++ train_on_source (bool): If True, the model will be trained on the source text as well as the target text.
++ predict_with_generate (bool): If True, the model will generate predictions instead of training.
+ """
++
+ def __init__(
+ self,
+ hf_dataset: datasets.DatasetDict,
+@@ -185,9 +183,7 @@ class SupervisedDataset(Dataset):
+
+ @dataclass
+ class DataCollatorForSupervisedDataset:
+- """
+- Collate and pad examples for supervised training.
+- """
++ """Collate and pad examples for supervised training."""
+
+ tokenizer: PreTrainedTokenizer
+ predict_with_generate: bool = False
+@@ -196,8 +192,7 @@ class DataCollatorForSupervisedDataset:
+ self,
+ examples: List[Dict[str,
+ torch.Tensor]]) -> Dict[str, torch.Tensor]:
+- """
+- Collate examples into dictionary for supervised training.
++ """Collate examples into dictionary for supervised training.
+
+ Args:
+ examples: List of examples, each containing 'input_ids' and 'labels'
+diff --git a/chatllms/data/utils/conversation.py b/chatllms/data/utils/conversation.py
+index 1ef229a..bcbf4ea 100644
+--- a/chatllms/data/utils/conversation.py
++++ b/chatllms/data/utils/conversation.py
+@@ -1,6 +1,4 @@
+-"""
+-Conversation prompt templates.
+-"""
++"""Conversation prompt templates."""
+
+ import dataclasses
+ from enum import Enum, auto
+@@ -25,7 +23,8 @@ class SeparatorStyle(Enum):
+
+ @dataclasses.dataclass
+ class Conversation:
+- """A class that manages prompt templates and keeps all conversation history."""
++ """A class that manages prompt templates and keeps all conversation
++ history."""
+
+ # The name of this template
+ name: str
+@@ -163,8 +162,9 @@ class Conversation:
+ def update_last_message(self, message: str):
+ """Update the last output.
+
+- The last message is typically set to be None when constructing the prompt,
+- so we need to update it in-place after getting the response from a model.
++ The last message is typically set to be None when constructing the
++ prompt, so we need to update it in-place after getting the response
++ from a model.
+ """
+ self.messages[-1][1] = message
+
+diff --git a/chatllms/data/utils/convert_alpaca.py b/chatllms/data/utils/convert_alpaca.py
+index 8af5520..4513072 100644
+--- a/chatllms/data/utils/convert_alpaca.py
++++ b/chatllms/data/utils/convert_alpaca.py
+@@ -1,5 +1,4 @@
+-"""
+-Convert alpaca dataset into sharegpt format.
++"""Convert alpaca dataset into sharegpt format.
+
+ Usage: python3 -m chatllms.data.convert_alpaca --in alpaca_data.json
+ """
+diff --git a/chatllms/data/utils/convert_oasst1.py b/chatllms/data/utils/convert_oasst1.py
+index a6f06ae..35a538d 100644
+--- a/chatllms/data/utils/convert_oasst1.py
++++ b/chatllms/data/utils/convert_oasst1.py
+@@ -167,9 +167,9 @@ def test_add_open_assistant(fixup_personality,
+ only_personality,
+ deberta_grading,
+ save_json=True):
+- """
+- Flatten tree structure into one row per path from root to leaf
+- Also turn into human_bot prompting format:
++ """Flatten tree structure into one row per path from root to leaf Also turn
++ into human_bot prompting format:
++
+ : question\n: answer : question2\n: answer2 Etc.
+ Also saves a .json locally as side-effect
+ returns list of dicts, containing input, prompt_type and source
+diff --git a/chatllms/data/vicuna_dataset.py b/chatllms/data/vicuna_dataset.py
+index 392d6fe..f5fcbf1 100644
+--- a/chatllms/data/vicuna_dataset.py
++++ b/chatllms/data/vicuna_dataset.py
+@@ -49,8 +49,7 @@ def tokenize_conversations(
+ conversations: List[str],
+ tokenizer: PreTrainedTokenizer,
+ ) -> torch.Tensor:
+- """Tokenize conversations
+- """
++ """Tokenize conversations."""
+
+ input_ids = tokenizer(
+ conversations,
+@@ -70,7 +69,9 @@ def mask_targets(
+ tokenizer: PreTrainedTokenizer,
+ conv: Conversation,
+ ) -> None:
+- """Mask targets. Only compute loss on the assistant outputs.
++ """Mask targets.
++
++ Only compute loss on the assistant outputs.
+ """
+
+ # Mask targets
+@@ -108,8 +109,7 @@ def mask_targets(
+
+ def preprocess(sources: Sequence[Dict[str, str]],
+ tokenizer: PreTrainedTokenizer) -> Dict[str, List[int]]:
+- """
+- Preprocesses the data by tokenizing it.
++ """Preprocesses the data by tokenizing it.
+
+ Args:
+ sources (Sequence[Dict[str, str]]): List of conversation sources.
+@@ -131,13 +131,13 @@ def preprocess(sources: Sequence[Dict[str, str]],
+
+
+ class SupervisedDataset(Dataset):
+- """
+- Dataset for supervised fine-tuning.
++ """Dataset for supervised fine-tuning.
+
+ Args:
+ raw_data (List[Dict]): Raw input data.
+ tokenizer (PreTrainedTokenizer): Tokenizer for preprocessing the data.
+ """
++
+ def __init__(self, raw_data: List[Dict[str, List[str]]],
+ tokenizer: PreTrainedTokenizer) -> None:
+ super().__init__()
+@@ -160,8 +160,7 @@ class SupervisedDataset(Dataset):
+ return len(self.input_ids)
+
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+- """
+- Get an example from the dataset at the specified index.
++ """Get an example from the dataset at the specified index.
+
+ Args:
+ index (int): Index of the example to retrieve.
+@@ -177,16 +176,14 @@ class SupervisedDataset(Dataset):
+
+
+ class LazySupervisedDataset(Dataset):
+- """
+- Dataset for supervised fine-tuning.
+- """
++ """Dataset for supervised fine-tuning."""
++
+ def __init__(
+ self,
+ raw_data: List[Dict[str, str]],
+ tokenizer: PreTrainedTokenizer,
+ ):
+- """
+- Initialize the LazySupervisedDataset.
++ """Initialize the LazySupervisedDataset.
+
+ Args:
+ raw_data (List[Dict[str, str]]): The raw input data for the dataset.
+@@ -198,8 +195,7 @@ class LazySupervisedDataset(Dataset):
+ self.cached_data_dict: Dict[int, Dict[str, torch.Tensor]] = {}
+
+ def __len__(self) -> int:
+- """
+- Get the length of the dataset.
++ """Get the length of the dataset.
+
+ Returns:
+ int: The length of the dataset.
+@@ -207,8 +203,7 @@ class LazySupervisedDataset(Dataset):
+ return len(self.raw_data)
+
+ def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
+- """
+- Get an item from the dataset at the given index.
++ """Get an item from the dataset at the given index.
+
+ Args:
+ i (int): The index of the item to retrieve.
+@@ -240,8 +235,7 @@ def make_conversation_data_module(
+ lazy_preprocess: bool,
+ data_path: str,
+ ) -> Dict[str, Dataset]:
+- """
+- Make dataset and collator for supervised fine-tuning.
++ """Make dataset and collator for supervised fine-tuning.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): The tokenizer object.
+@@ -250,7 +244,6 @@ def make_conversation_data_module(
+
+ Returns:
+ dict: A dictionary containing the train_dataset and eval_dataset.
+-
+ """
+ # Determine the appropriate dataset class based on lazy_preprocess flag
+
+diff --git a/chatllms/evaluation/evaluate_zh.py b/chatllms/evaluation/evaluate_zh.py
+index f23d611..ca2c26e 100644
+--- a/chatllms/evaluation/evaluate_zh.py
++++ b/chatllms/evaluation/evaluate_zh.py
+@@ -95,8 +95,7 @@ class CEval(object):
+ data_path: str = 'ceval/ceval-exam',
+ output_dir: str = 'ceval_output',
+ ) -> None:
+- """
+- Initialize the CEval object.
++ """Initialize the CEval object.
+
+ Args:
+ model (PreTrainedModel): Pre-trained model for question answering.
+@@ -112,8 +111,7 @@ class CEval(object):
+ self.output_dir = output_dir
+
+ def run(self, shot: int, split: str) -> None:
+- """
+- Run the evaluation for all tasks.
++ """Run the evaluation for all tasks.
+
+ Args:
+ shot (int): Number of additional examples to include in the prompt.
+@@ -144,8 +142,7 @@ class CEval(object):
+
+ def run_single_task(self, task_name: str, shot: int,
+ split: str) -> Tuple[List[Dict[str, str]], float]:
+- """
+- Run the evaluation for a single task.
++ """Run the evaluation for a single task.
+
+ Args:
+ task_name (str): Name of the task.
+@@ -198,8 +195,7 @@ class CEval(object):
+ def build_example(self,
+ data: Dict[str, str],
+ with_answer: bool = True) -> str:
+- """
+- Builds an example string based on the given data.
++ """Builds an example string based on the given data.
+
+ Args:
+ data (Dict[str, str]): A dictionary containing the question, choices, and answer.
+diff --git a/chatllms/model/compute_metrics.py b/chatllms/model/compute_metrics.py
+index 6488c39..2b311d0 100644
+--- a/chatllms/model/compute_metrics.py
++++ b/chatllms/model/compute_metrics.py
+@@ -10,14 +10,14 @@ from transformers import PreTrainedTokenizer
+
+ @dataclass
+ class ComputeMetrics:
+- """
+- Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+- Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
++ """Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+
++ Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
+ """
++
+ def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
+- """
+- Initialize the ComputeMetrics class with a pre-trained tokenizer object.
++ """Initialize the ComputeMetrics class with a pre-trained tokenizer
++ object.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): A pre-trained tokenizer object to be used for decoding tokenized sequences.
+@@ -27,8 +27,7 @@ class ComputeMetrics:
+ def __call__(
+ self, eval_preds: List[Union[np.ndarray, Tuple[np.ndarray]]]
+ ) -> Dict[str, float]:
+- """
+- Computes evaluation metrics for model predictions.
++ """Computes evaluation metrics for model predictions.
+
+ Args:
+ eval_preds (List[Union[np.ndarray, Tuple[np.ndarray]]]): List of tuples containing prediction and label arrays.
+diff --git a/chatllms/model/llm_perplexity.py b/chatllms/model/llm_perplexity.py
+index 2da2fe7..30a1c79 100644
+--- a/chatllms/model/llm_perplexity.py
++++ b/chatllms/model/llm_perplexity.py
+@@ -18,8 +18,7 @@ from chatllms.utils.model_utils import add_special_tokens_if_missing
+
+
+ class LLMPerplexity:
+- """
+- Language model to compute perplexity.
++ """Language model to compute perplexity.
+
+ Args:
+ cache_dir (str): Directory to cache models.
+@@ -31,6 +30,7 @@ class LLMPerplexity:
+ fp16 (bool): Whether to use 16-bit precision.
+ device (str): Device to load model to.
+ """
++
+ def __init__(
+ self,
+ cache_dir: str = None,
+@@ -93,8 +93,7 @@ class LLMPerplexity:
+ def get_perplexity(self,
+ input_texts: Union[str, List[str]],
+ batch_size: int = None) -> Union[float, List[float]]:
+- """
+- Compute perplexity on input text(s).
++ """Compute perplexity on input text(s).
+
+ Args:
+ input_texts (Union[str, List[str]]): Input text(s) to compute perplexity for.
+diff --git a/chatllms/model/load_pretrain_model.py b/chatllms/model/load_pretrain_model.py
+index b917bad..a8e5c38 100644
+--- a/chatllms/model/load_pretrain_model.py
++++ b/chatllms/model/load_pretrain_model.py
+@@ -27,9 +27,8 @@ def load_model_tokenizer(
+ is_trainable: Optional[bool] = True,
+ logger=None,
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+- """
+- Returns a language model and tokenizer for text generation that can be trained with mixed precision.
+- Support both training and inference.
++ """Returns a language model and tokenizer for text generation that can be
++ trained with mixed precision. Supports both training and inference.
+
+ Args:
+ args: A dictionary containing various hyperparameters.
+diff --git a/chatllms/model/mmlueval_callback.py b/chatllms/model/mmlueval_callback.py
+index 839b5b1..2e6e9c2 100644
+--- a/chatllms/model/mmlueval_callback.py
++++ b/chatllms/model/mmlueval_callback.py
+@@ -17,10 +17,9 @@ from chatllms.data.sft_dataset import SupervisedDataset
+
+ @dataclass
+ class MMLUEvalCallback(TrainerCallback):
+- """
+- A callback function called after each evaluation step during training to evaluate \
+- the performance of a model on an
+- MMLU (Mean Length of Utterance) dataset.
++ """A callback function called after each evaluation step during training to
++ evaluate the performance of a model on the MMLU (Massive Multitask Language
++ Understanding) dataset.
+
+ Args:
+ trainer (Trainer): The trainer instance to be used.
+@@ -28,6 +27,7 @@ class MMLUEvalCallback(TrainerCallback):
+ data_dir (str): The directory where the MMLU dataset is stored.
+ args (argparse.Namespace): The command line arguments for the current run.
+ """
++
+ def __init__(
+ self,
+ trainer: 'Trainer',
+@@ -98,8 +98,8 @@ class MMLUEvalCallback(TrainerCallback):
+ model: PreTrainedModel,
+ **kwargs: Any,
+ ) -> None:
+- """
+- Iterate over the batches of the evaluation dataset and make predictions for MMLU.
++ """Iterate over the batches of the evaluation dataset and make
++ predictions for MMLU.
+
+ Args:
+ args (Dict[str, Any]): Dictionary containing the evaluation arguments.
+diff --git a/chatllms/model/sample_generate_callback.py b/chatllms/model/sample_generate_callback.py
+index 6c453fa..10a2788 100644
+--- a/chatllms/model/sample_generate_callback.py
++++ b/chatllms/model/sample_generate_callback.py
+@@ -7,13 +7,14 @@ from transformers import PreTrainedTokenizer, TrainerCallback
+
+ @dataclass
+ class SampleGenerateCallback(TrainerCallback):
+- """
+- A callback that generates text samples from a pre-trained language model during training.
++ """A callback that generates text samples from a pre-trained language model
++ during training.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): The tokenizer used to preprocess inputs.
+ max_new_tokens (int): The maximum number of tokens to generate in response to each input.
+ """
++
+ def __init__(self, tokenizer: PreTrainedTokenizer,
+ generation_config: argparse.Namespace, logger: None):
+ self.tokenizer = tokenizer
+@@ -49,8 +50,7 @@ class SampleGenerateCallback(TrainerCallback):
+
+ def on_evaluate(self, args: Any, state: Dict[str, Any], control: Any,
+ **kwargs: Any) -> None:
+- """
+- Generates text samples from the language model during evaluation.
++ """Generates text samples from the language model during evaluation.
+
+ Args:
+ args (Any): Trainer arguments, not used in this method.
+diff --git a/chatllms/model/save_peft_model_callback.py b/chatllms/model/save_peft_model_callback.py
+index 2c93e8c..fb37f12 100644
+--- a/chatllms/model/save_peft_model_callback.py
++++ b/chatllms/model/save_peft_model_callback.py
+@@ -7,16 +7,15 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+
+ class SavePeftModelCallback(TrainerCallback):
+- """
+- Callback to save PEFT model checkpoints during training.
++ """Callback to save PEFT model checkpoints during training.
+
+ Saves both the full model and the adapter model to separate directories
+ within the checkpoint directory.
+ """
++
+ def save_model(self, args: Any, state: TrainingArguments,
+ kwargs: Dict[str, Any]) -> None:
+- """
+- Saves the PEFT model checkpoint.
++ """Saves the PEFT model checkpoint.
+
+ Args:
+ args (Any): The command line arguments passed to the script.
+@@ -52,8 +51,8 @@ class SavePeftModelCallback(TrainerCallback):
+ def on_save(self, args: Any, state: TrainingArguments,
+ control: TrainerControl,
+ **kwargs: Dict[str, Any]) -> TrainerControl:
+- """
+- Callback method that calls save_model() and returns `control` argument.
++ """Callback method that calls save_model() and returns `control`
++ argument.
+
+ Args:
+ args (Any): The command line arguments passed to the script.
+@@ -74,8 +73,8 @@ class SavePeftModelCallback(TrainerCallback):
+ def on_train_end(self, args: Any, state: TrainingArguments,
+ control: TrainerControl, **kwargs: Dict[str,
+ Any]) -> None:
+- """
+- Callback method that saves the model checkpoint and creates a 'completed' file in the output directory.
++ """Callback method that saves the model checkpoint and creates a
++ 'completed' file in the output directory.
+
+ Args:
+ args (Any): The command line arguments passed to the script.
+diff --git a/chatllms/train/training.py b/chatllms/train/training.py
+index 0a2a2c8..c473969 100644
+--- a/chatllms/train/training.py
++++ b/chatllms/train/training.py
+@@ -11,8 +11,7 @@ from torch.utils.data import Dataset
+
+ def train_and_evaluate(trainer: transformers.Trainer, args: argparse.Namespace,
+ logger: None) -> None:
+- """
+- Trains and evaluates a machine learning model.
++ """Trains and evaluates a machine learning model.
+
+ Args:
+ trainer (Trainer): The training object to use for training and evaluation.
+@@ -75,10 +74,8 @@ def predict_and_save(trainer: transformers.Trainer,
+ tokenizer: transformers.PreTrainedTokenizer,
+ predict_dataset: Dataset, args: argparse.Namespace,
+ logger: None) -> None:
+- """
+- Make predictions on new data, save them to a file along with input examples,
+- and update the overall metrics.
+- """
++ """Make predictions on new data, save them to a file along with input
++ examples, and update the overall metrics."""
+ logger.info('=' * 80)
+ logger.info('*** Predict ***')
+ logger.info('=' * 80)
+diff --git a/chatllms/utils/apply_lora.py b/chatllms/utils/apply_lora.py
+index b46639e..0291f4e 100644
+--- a/chatllms/utils/apply_lora.py
++++ b/chatllms/utils/apply_lora.py
+@@ -1,5 +1,4 @@
+-"""
+-Apply the LoRA weights on top of a base model.
++"""Apply the LoRA weights on top of a base model.
+
+ Usage:
+ python3 apply_lora.py --base_model_path ~/model_weights/llama-7b --target_model_path ~/model_weights/baize-7b \
+@@ -24,7 +23,8 @@ def apply_lora(
+ use_auth_token: str = True,
+ trust_remote_code: bool = True,
+ ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+- """Applies the LoRA adapter to a base model and saves the resulting target model (optional).
++ """Applies the LoRA adapter to a base model and saves the resulting target
++ model (optional).
+
+ Args:
+ base_model_path (str): The path to the base model to which the LoRA adapter will be applied.
+@@ -36,7 +36,6 @@ def apply_lora(
+
+ Returns:
+ Tuple[AutoModelForCausalLM, AutoTokenizer]: A tuple containing the target model and its tokenizer.
+-
+ """
+ # Load the base model and tokenizer
+ print(f'Loading the base model from {base_model_path}')
+diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py
+index a053f5d..af4b092 100644
+--- a/chatllms/utils/model_utils.py
++++ b/chatllms/utils/model_utils.py
+@@ -9,19 +9,18 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, Trainer
+ from transformers.generation.logits_process import LogitsProcessor
+ from transformers.generation.utils import LogitsProcessorList
+ from transformers.trainer_utils import get_last_checkpoint
+-from transformers.generation.logits_process import LogitsProcessor
+-from transformers.generation.utils import LogitsProcessorList
++
+ from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN,
+ DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN)
+
+
+ def add_special_tokens_if_missing(tokenizer: PreTrainedTokenizer,
+ model: PreTrainedModel) -> None:
+- """
+- If 'llama' or 'baichuan' is in the model name or path, check if the special tokens are set correctly.
+- Add any missing special tokens to prevent them from being parsed into different tokens.
+- Note that these special tokens are present in the vocabulary.
+- Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
++ """If 'llama' or 'baichuan' is in the model name or path, check if the
++ special tokens are set correctly. Add any missing special tokens to prevent
++ them from being parsed into different tokens. Note that these special
++ tokens are present in the vocabulary. Note also that
++ `model.config.pad_token_id` is 0 which corresponds to `` token.
+
+ Args:
+ tokenizer: The pre-trained tokenizer.
+@@ -54,8 +53,7 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict[str, str],
+ tokenizer: PreTrainedTokenizer,
+ model: PreTrainedModel) -> None:
+ """Resize tokenizer and embedding to accommodate new special tokens.
+- 改变tokenizer和embedding的尺寸。
+- 一般需要将tokenizer和embedding的尺寸设置为64的倍数,方便GPU加速。
++ 改变tokenizer和embedding的尺寸。 一般需要将tokenizer和embedding的尺寸设置为64的倍数,方便GPU加速。
+
+ Args:
+ special_tokens_dict (Dict[str, str]): A dictionary of special tokens to be added to the tokenizer.
+@@ -99,11 +97,9 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict[str, str],
+
+ def find_all_linear_names(args: argparse.Namespace,
+ model: torch.nn.Module) -> List[str]:
+- """
+- Returns a list of names of all linear layers present in the given model.
++ """Returns a list of names of all linear layers present in the given model.
+ 如果args.bits是4,使用bitsandbytes库中的bnb.nn.Linear4bit层;
+- 如果args.bits是8,使用bitsandbytes库中的bnb.nn.Linear8bitLt层;
+- 否则,使用torch.nn.Linear层;
++ 如果args.bits是8,使用bitsandbytes库中的bnb.nn.Linear8bitLt层; 否则,使用torch.nn.Linear层;
+ 并记录下这些层的名称,保存在lora_module_names集合中。
+
+ Args:
+@@ -154,8 +150,7 @@ def find_all_linear_names(args: argparse.Namespace,
+
+ def print_trainable_parameters(args: argparse.Namespace,
+ model: torch.nn.Module) -> None:
+- """
+- Prints the number of trainable parameters in the given model.
++ """Prints the number of trainable parameters in the given model.
+
+ Args:
+ args (argparse.Namespace): A namespace containing arguments of the script. Must contain the 'bits' argument.
+@@ -198,8 +193,7 @@ def print_trainable_parameters(args: argparse.Namespace,
+
+
+ def verify_dtypes(model: torch.nn.Module) -> None:
+- """
+- 检查模型参数的数据类型,并输出各个数据类型在这些张量中所占的比例.
++ """检查模型参数的数据类型,并输出各个数据类型在这些张量中所占的比例.
+
+ :param model: 待检查的模型.
+ :return: 无返回值.
+@@ -225,9 +219,9 @@ def verify_dtypes(model: torch.nn.Module) -> None:
+
+ def check_training_finished(args: argparse.Namespace,
+ logger=None) -> Tuple[str, bool]:
+- """
+- Given a directory containing previous saved checkpoints, returns the path to the last checkpoint
+- if available along with a boolean flag indicating whether training has already been completed.
++ """Given a directory containing previous saved checkpoints, returns the
++ path to the last checkpoint if available along with a boolean flag
++ indicating whether training has already been completed.
+
+ Args:
+ checkpoint_dir (str): Path to the directory containing the saved checkpoints.
+@@ -276,24 +270,10 @@ def find_last_checkpoint(checkpoint_dir):
+ last_checkpoint = join(checkpoint_dir, f'checkpoint-{max_step}')
+ return last_checkpoint
+
+-# Avoid runtime error in model.generate(do_sample=True).
+-class InvalidScoreLogitsProcessor(LogitsProcessor):
+- def __call__(self, input_ids: torch.LongTensor,
+- scores: torch.FloatTensor) -> torch.FloatTensor:
+- if torch.isnan(scores).any() or torch.isinf(scores).any():
+- scores.zero_()
+- scores[..., 0] = 1.0
+- return scores
+-
+-
+-def get_logits_processor() -> LogitsProcessorList:
+- logits_processor = LogitsProcessorList()
+- logits_processor.append(InvalidScoreLogitsProcessor())
+- return logits_processor
+-
+
+ # Avoid runtime error in model.generate(do_sample=True).
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
++
+ def __call__(self, input_ids: torch.LongTensor,
+ scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+diff --git a/chatllms/utils/stream_server.py b/chatllms/utils/stream_server.py
+index 66bf42f..93288b3 100644
+--- a/chatllms/utils/stream_server.py
++++ b/chatllms/utils/stream_server.py
+@@ -1,5 +1,5 @@
+-"""
+-Helpers to support streaming generate output.
++"""Helpers to support streaming generate output.
++
+ Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
+ """
+ import traceback
+@@ -10,6 +10,7 @@ import transformers
+
+
+ class Stream(transformers.StoppingCriteria):
++
+ def __init__(self, callback_func=None):
+ self.callback_func = callback_func
+
+@@ -20,10 +21,9 @@ class Stream(transformers.StoppingCriteria):
+
+
+ class Iteratorize:
+- """
+- Transforms a function that takes a callback
+- into a lazy iterator (generator).
+- """
++ """Transforms a function that takes a callback into a lazy iterator
++ (generator)."""
++
+ def __init__(self, func, kwargs={}, callback=None):
+ self.mfunc = func
+ self.c_callback = callback
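`Iteratorize` in `stream_server.py` wraps a callback-style generation function so tokens can be consumed lazily. The following is a simplified illustration of that callback-to-generator pattern using a background thread and a queue; it is not the repo's exact implementation, and `produce` is only a stand-in for the real generation call:

```python
import threading
from queue import Queue


def iteratorize(func, *args, **kwargs):
    """Run a callback-based producer in a thread and expose its outputs lazily."""
    q, sentinel = Queue(), object()

    def callback(value):
        q.put(value)

    def worker():
        try:
            func(*args, callback=callback, **kwargs)
        finally:
            q.put(sentinel)  # signal that the producer has finished

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not sentinel:
        yield item


def produce(n, callback=None):
    # Stand-in for a generate() call that reports tokens through a callback.
    for i in range(n):
        callback(f'token-{i}')


if __name__ == '__main__':
    for tok in iteratorize(produce, 5):
        print(tok)
```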
+diff --git a/chatllms/utils/template.py b/chatllms/utils/template.py
+index a96b8fa..e865331 100644
+--- a/chatllms/utils/template.py
++++ b/chatllms/utils/template.py
+@@ -7,8 +7,7 @@ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class PromptTemplate(object):
+- """
+- A template for formatting a conversation prompt.
++ """A template for formatting a conversation prompt.
+
+ Args:
+ name: Name of template
+@@ -16,7 +15,6 @@ class PromptTemplate(object):
+ prompt: Prompt text
+ sep: Separator between prompts
+ use_history: Whether to use conversation history
+-
+ """
+
+ name: str
+@@ -31,8 +29,7 @@ class PromptTemplate(object):
+ history: Optional[List[Tuple[str, str]]] = None,
+ prefix: Optional[str] = None,
+ ) -> str:
+- """
+- Returns a string containing prompt without response.
++ """Returns a string containing prompt without response.
+
+ Args:
+ query (str): The input query text.
+@@ -67,8 +64,7 @@ class PromptTemplate(object):
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ prefix: Optional[str] = None) -> List[str]:
+- """
+- Formats the conversation example.
++ """Formats the conversation example.
+
+ Args:
+ query (str): The input query text.
+@@ -100,8 +96,7 @@ class PromptTemplate(object):
+ sep: str,
+ use_history: Optional[bool] = True,
+ ) -> None:
+- """
+- Registers a new conversation template.
++ """Registers a new conversation template.
+
+ Args:
+ prefix (str): The prefix text for the prompt.
+@@ -116,13 +111,9 @@ class PromptTemplate(object):
+ self.use_history = use_history
+
+ def __post_init__(self):
+- """
+- Initializes the instance of the class.
+- """
++ """Initializes the instance of the class."""
+ if self.name == 'default':
+- """
+- Supports language model inference without histories.
+- """
++ """Supports language model inference without histories."""
+ self.register_template(name='vanilla',
+ prefix='',
+ prompt='{query}',
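To make the `PromptTemplate` changes easier to follow, here is a rough, self-contained sketch of how such a template assembles a prompt from a query and optional history. The field names mirror the dataclass above, but the concrete prompt text and join logic are illustrative assumptions, not the repo's registered templates:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class SimplePromptTemplate:
    name: str = 'vanilla'
    prefix: str = ''
    prompt: str = 'Human: {query}\nAssistant: '
    sep: str = '\n'
    use_history: bool = True

    def get_prompt(self,
                   query: str,
                   history: Optional[List[Tuple[str, str]]] = None) -> str:
        """Concatenate past (query, response) turns and the current query."""
        history = history if (self.use_history and history) else []
        turns = [self.prompt.format(query=q) + r for q, r in history]
        turns.append(self.prompt.format(query=query))
        return self.prefix + self.sep.join(turns)


if __name__ == '__main__':
    template = SimplePromptTemplate()
    print(template.get_prompt('How do I resize a tokenizer?',
                              history=[('Hi', 'Hello! How can I help?')]))
```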
+diff --git a/cli_demo.py b/cli_demo.py
+index 959be76..85e34f0 100644
+--- a/cli_demo.py
++++ b/cli_demo.py
+@@ -22,8 +22,8 @@ def generate_response(
+ model: PreTrainedModel,
+ generation_args: dict,
+ ) -> List[str]:
+- """
+- Generates a response to the given query using GPT-3.5 model and prints it to the console.
++ """Generates a response to the given query using GPT-3.5 model and prints
++ it to the console.
+
+ Args:
+ query (str): The input query for which a response is to be generated.
+diff --git a/data/README.md b/data/README.md
+index 5612380..eff7c4f 100644
+--- a/data/README.md
++++ b/data/README.md
+@@ -1,4 +1,3 @@
+-
+ # How to use the data
+
+ ## Datasets Supported by the Framework
+@@ -44,7 +43,6 @@ We provide the following datasets for the experiments in this framework.
+ 数据集说明:开源了数据规模为145k的价值对齐数据集,该数据集对于每个prompt包括了拒绝&正向建议,(safe and reponsibility) > 拒绝为主(safe) > 风险回复(unsafe)三种类型,可用于增强SFT模型的安全性或用于训练reward模型。
+ - [CValues-Comparison中文大模型价值观比较数据集](https://modelscope.cn/datasets/damo/CValues-Comparison/summary)
+
+-
+ ## Dataset formation
+
+ The `dataset_info.yaml` file contains all the datasets can be used in the experiments. The following is the format of the datasets, main including the following fields.
+diff --git a/data/alpaca_zh_pcyn.yaml b/data/alpaca_zh_pcyn.yaml
+deleted file mode 100644
+index 84b6a01..0000000
+--- a/data/alpaca_zh_pcyn.yaml
++++ /dev/null
+@@ -1,42 +0,0 @@
+-# The dataset_info.yaml file contains the information of the datasets used in the experiments.
+-coig:
+- hf_hub_url: BAAI/COIG
+- local_path: /userhome/jianzhnie/prompt_data/COIG/train_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-cvalues_comparison_train:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/train_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-cvalues_comparison_test:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/test_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-olcc:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/olcc/olcc_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-100PoisonMpts:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/100PoisonMpts/train_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-safety_prompt_part1:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/attack_scenarios_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+-
+-safety_prompt_part2:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_alpaca.json
+- dataset_format: alpaca
+- multi_turn: False
+diff --git a/data/dataset_info.py b/data/dataset_info.py
+index 7f38e6c..62c96f5 100644
+--- a/data/dataset_info.py
++++ b/data/dataset_info.py
+@@ -2,9 +2,9 @@ from os.path import join
+
+
+ def get_dataset_info(dataset_dir):
+- """
+- Returns the datasets info to a dataset based on a pre-defined map of dataset names to their corresponding URLs on the internet
+- or local file paths.
++ """Returns the datasets info to a dataset based on a pre-defined map of
++ dataset names to their corresponding URLs on the internet or local file
++ paths.
+
+ Args:
+ dataset_dir (str): The local directory where the dataset is stored; this is used for datasets that are stored locally.
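The deleted `alpaca_zh_pcyn.yaml`/`vicuna_zh_pcyn.yaml` files and `get_dataset_info` all revolve around the same per-dataset fields (`hf_hub_url`, `local_path`, `dataset_format`, `multi_turn`). A hedged sketch of how such an entry could be resolved into a 🤗 `datasets` object is shown here; the file path, entry name, and fallback logic are assumptions for illustration rather than the repo's actual loader:

```python
import yaml
from datasets import load_dataset


def load_from_dataset_info(yaml_path: str, name: str):
    """Resolve one dataset_info.yaml entry to a datasets.Dataset (illustrative)."""
    with open(yaml_path, 'r', encoding='utf-8') as f:
        info = yaml.safe_load(f)[name]
    if info.get('hf_hub_url'):
        # Entry points at a dataset on the Hugging Face hub.
        return load_dataset(info['hf_hub_url'])
    # Otherwise fall back to a local JSON file via the generic json builder.
    return load_dataset('json', data_files=info['local_path'])


if __name__ == '__main__':
    ds = load_from_dataset_info('data/dataset_info.yaml', 'alpaca')  # placeholder name
    print(ds)
```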
+diff --git a/data/vicuna_zh_pcyn.yaml b/data/vicuna_zh_pcyn.yaml
+deleted file mode 100644
+index 610d8e1..0000000
+--- a/data/vicuna_zh_pcyn.yaml
++++ /dev/null
+@@ -1,42 +0,0 @@
+-# The dataset_info.yaml file contains the information of the datasets used in the experiments.
+-coig:
+- hf_hub_url: BAAI/COIG
+- local_path: /userhome/jianzhnie/prompt_data/COIG/train_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-cvalues_comparison_train:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/train_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-cvalues_comparison_test:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/CValues-Comparison/test_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-olcc:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/olcc/olcc_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-100PoisonMpts:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/100PoisonMpts/train_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-safety_prompt_part1:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/attack_scenarios_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+-
+-safety_prompt_part2:
+- hf_hub_url: ''
+- local_path: /userhome/jianzhnie/prompt_data/Safety-Prompts/safety_scenarios_vicuna.json
+- dataset_format: sharegpt
+- multi_turn: True
+diff --git a/examples/clean_sharegpt/clean_sharegpt.py b/examples/clean_sharegpt/clean_sharegpt.py
+index 017fecb..ed4fab7 100644
+--- a/examples/clean_sharegpt/clean_sharegpt.py
++++ b/examples/clean_sharegpt/clean_sharegpt.py
+@@ -99,8 +99,8 @@ def format_roles(
+ def filter_invalid_roles(
+ raw_data: List[Dict[str,
+ any]]) -> List[Dict[str, List[Dict[str, any]]]]:
+- """
+- Filter out invalid contents based on the roles assigned to each conversation.
++ """Filter out invalid contents based on the roles assigned to each
++ conversation.
+
+ Args:
+ raw_data: A list of dictionaries containing conversation data.
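`filter_invalid_roles` drops conversations whose turns are not well-formed. As a reading aid, here is a hedged sketch of that role check, keeping only samples whose turns alternate `human`/`gpt` starting with `human`; it illustrates the idea rather than reproducing the repo's exact filtering rules:

```python
from typing import Any, Dict, List


def filter_invalid_roles(raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep conversations whose roles strictly alternate human/gpt."""
    expected = ['human', 'gpt']
    kept = []
    for sample in raw_data:
        roles = [turn['from'] for turn in sample.get('conversations', [])]
        if roles and all(r == expected[i % 2] for i, r in enumerate(roles)):
            kept.append(sample)
    return kept


if __name__ == '__main__':
    data = [
        {'conversations': [{'from': 'human', 'value': 'hi'},
                           {'from': 'gpt', 'value': 'hello'}]},
        {'conversations': [{'from': 'gpt', 'value': 'orphan reply'}]},
    ]
    print(len(filter_invalid_roles(data)))  # -> 1
```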
+diff --git a/examples/clean_sharegpt/hardcoded_questions.py b/examples/clean_sharegpt/hardcoded_questions.py
+index cb89c63..abafaf2 100644
+--- a/examples/clean_sharegpt/hardcoded_questions.py
++++ b/examples/clean_sharegpt/hardcoded_questions.py
+@@ -2,9 +2,8 @@ import json
+
+
+ def identity_questions():
+- """ "
+- Adopted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py
+- """
++    """Adopted from
++    https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py."""
+ content = []
+
+ name = 'Vicuna'
+diff --git a/examples/clean_sharegpt/merge.py b/examples/clean_sharegpt/merge.py
+index bdf2b40..727b9a8 100644
+--- a/examples/clean_sharegpt/merge.py
++++ b/examples/clean_sharegpt/merge.py
+@@ -1,5 +1,4 @@
+-"""
+-Merge two conversation files into one
++"""Merge two conversation files into one.
+
+ Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
+ """
+diff --git a/examples/clean_sharegpt/split_long_conversation.py b/examples/clean_sharegpt/split_long_conversation.py
+index e14b952..1e551c6 100644
+--- a/examples/clean_sharegpt/split_long_conversation.py
++++ b/examples/clean_sharegpt/split_long_conversation.py
+@@ -1,5 +1,4 @@
+-"""
+-Split long conversations based on certain max length.
++"""Split long conversations based on certain max length.
+
+ Usage: python3 -m split_long_conversation.py \
+ --in sharegpt_clean.json \
+@@ -23,8 +22,8 @@ from tqdm import tqdm
+
+ def make_sample(sample: Dict[str, any], start_idx: int,
+ end_idx: int) -> Dict[str, any]:
+- """
+- Create a new sample dictionary by selecting conversations from the given sample.
++ """Create a new sample dictionary by selecting conversations from the given
++ sample.
+
+ Args:
+ sample (Dict[str, any]): The original sample dictionary.
+@@ -43,8 +42,8 @@ def make_sample(sample: Dict[str, any], start_idx: int,
+
+
+ def split_one_sample(sample: Dict[str, any]) -> List[Dict[str, any]]:
+- """
+- Split a single sample into multiple samples based on conversation lengths.
++ """Split a single sample into multiple samples based on conversation
++ lengths.
+
+ Args:
+ sample (Dict[str, any]): The original sample dictionary.
+@@ -94,8 +93,8 @@ def worker(input_data: List[Dict[str, Any]]):
+ def split_all(raw_data: List[Dict[str, Any]],
+ tokenizer_: transformers.PreTrainedTokenizer,
+ max_length_: int) -> List[Dict[str, Any]]:
+- """
+- Split the content into smaller parts based on the max token length constraint.
++ """Split the content into smaller parts based on the max token length
++ constraint.
+
+ Args:
+ raw_data (List[Dict[str, Any]]): The list of samples to split.
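The functions touched in `split_long_conversation.py` break overlong multi-turn samples into chunks that fit a token budget. A minimal sketch of that splitting logic is given below under the assumption of ShareGPT-style `conversations` records; the tokenizer checkpoint is a placeholder and the chunking rule is simplified:

```python
from typing import Any, Dict, List

from transformers import AutoTokenizer


def split_one_sample(sample: Dict[str, Any], tokenizer,
                     max_length: int) -> List[Dict[str, Any]]:
    """Start a new chunk whenever the accumulated token count would exceed max_length."""
    pieces, current, length = [], [], 0
    for turn in sample['conversations']:
        n_tokens = len(tokenizer(turn['value']).input_ids)
        if current and length + n_tokens > max_length:
            pieces.append({'id': sample['id'], 'conversations': current})
            current, length = [], 0
        current.append(turn)
        length += n_tokens
    if current:
        pieces.append({'id': sample['id'], 'conversations': current})
    return pieces


if __name__ == '__main__':
    tok = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
    sample = {'id': 'demo', 'conversations': [
        {'from': 'human', 'value': 'Hello ' * 400},
        {'from': 'gpt', 'value': 'Hi ' * 400},
    ]}
    print(len(split_one_sample(sample, tok, max_length=256)))
```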
+diff --git a/examples/finetune_llm/baichuan7b_demo.py b/examples/finetune_llm/baichuan7b_demo.py
+index 73169e0..cfdca1b 100644
+--- a/examples/finetune_llm/baichuan7b_demo.py
++++ b/examples/finetune_llm/baichuan7b_demo.py
+@@ -19,5 +19,5 @@ def main(load_in_8bit=True, model_path=''):
+
+ if __name__ == '__main__':
+ load_in_8bit = True
+- model_path = '/home/robin/work_dir/llm/llm_pretrain_model/baichuan'
++ model_path = 'baichuan'
+ main(load_in_8bit, model_path)
+diff --git a/examples/finetune_llm/finetune_llama_with_qlora.py b/examples/finetune_llm/finetune_llama_with_qlora.py
+index d9dd3f2..a7bc1ee 100644
+--- a/examples/finetune_llm/finetune_llama_with_qlora.py
++++ b/examples/finetune_llm/finetune_llama_with_qlora.py
+@@ -15,9 +15,7 @@ DEFAULT_UNK_TOKEN = ''
+
+
+ def print_trainable_parameters(model: AutoModelForCausalLM) -> None:
+- """
+- Prints the number of trainable parameters in the model.
+- """
++ """Prints the number of trainable parameters in the model."""
+ trainable_params, all_param = 0, 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
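The hunk above only shows the first lines of `print_trainable_parameters`, so a complete, self-contained version of the same idea is sketched here for reference (the toy model in the `__main__` block is obviously not one of the repo's LLMs):

```python
import torch.nn as nn


def print_trainable_parameters(model: nn.Module) -> None:
    """Print how many parameters will actually receive gradients."""
    trainable_params, all_param = 0, 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f'trainable params: {trainable_params} || all params: {all_param} || '
          f'trainable%: {100 * trainable_params / all_param:.4f}')


if __name__ == '__main__':
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 2))
    model[0].weight.requires_grad_(False)  # freeze part of the model
    print_trainable_parameters(model)
```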
+diff --git a/examples/finetune_llm/qlora_int4_finetune.py b/examples/finetune_llm/qlora_int4_finetune.py
+index 123029c..3e7d46a 100644
+--- a/examples/finetune_llm/qlora_int4_finetune.py
++++ b/examples/finetune_llm/qlora_int4_finetune.py
+@@ -284,6 +284,7 @@ def find_all_linear_names(args, model):
+
+
+ class SavePeftModelCallback(transformers.TrainerCallback):
++
+ def save_model(self, args, state, kwargs):
+ logger.info('Saving PEFT checkpoint...')
+ if state.best_model_checkpoint is not None:
+@@ -307,6 +308,7 @@ class SavePeftModelCallback(transformers.TrainerCallback):
+ return control
+
+ def on_train_end(self, args, state, control, **kwargs):
++
+ def touch(fname, times=None):
+ with open(fname, 'a'):
+ os.utime(fname, times)
+@@ -408,9 +410,7 @@ def get_accelerate_model(args, checkpoint_dir):
+
+
+ def print_trainable_parameters(args, model):
+- """
+- Prints the number of trainable parameters in the model.
+- """
++ """Prints the number of trainable parameters in the model."""
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+@@ -576,9 +576,8 @@ def local_dataset(dataset_name):
+
+ def make_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ args) -> Dict:
+- """
+- Make dataset and collator for supervised fine-tuning.
+- Datasets are expected to have the following columns: { `input`, `output` }
++ """Make dataset and collator for supervised fine-tuning. Datasets are
++ expected to have the following columns: { `input`, `output` }
+
+ Available datasets to be selected with `dataset` argument:
+ - alpaca, 52002 examples
+@@ -597,8 +596,8 @@ def make_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
+ - flan (FLAN v2), up to 20M examples available
+ - vicuna
+-
+ """
++
+ def load_data(dataset_name):
+ if dataset_name == 'alpaca':
+ return load_dataset('tatsu-lab/alpaca')
+@@ -832,6 +831,7 @@ def train():
+ accuracy = evaluate.load('accuracy')
+
+ class MMLUEvalCallback(transformers.TrainerCallback):
++
+ def on_evaluate(self, args, state, control, model, **kwargs):
+ data_loader = trainer.get_eval_dataloader(mmlu_dataset)
+ source_max_len = trainer.data_collator.source_max_len
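`SavePeftModelCallback` and `MMLUEvalCallback` are `transformers.TrainerCallback` subclasses. For readers unfamiliar with the pattern, here is a hedged sketch of the adapter-saving callback: it persists only the PEFT adapter weights at each save event. The directory layout is an assumption, and the real callback in this file also handles `on_train_end`:

```python
import os

from transformers import TrainerCallback


class SavePeftModelCallback(TrainerCallback):
    """Save only the PEFT adapter instead of the full base model at each checkpoint."""

    def on_save(self, args, state, control, **kwargs):
        checkpoint_dir = os.path.join(args.output_dir,
                                      f'checkpoint-{state.global_step}')
        peft_dir = os.path.join(checkpoint_dir, 'adapter_model')
        # kwargs['model'] is the PeftModel; save_pretrained writes adapter weights only.
        kwargs['model'].save_pretrained(peft_dir)
        return control


# Typical registration: trainer.add_callback(SavePeftModelCallback())
```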
+diff --git a/examples/format_data/convert_oasst1.py b/examples/format_data/convert_oasst1.py
+index 31a29e8..ac965dc 100644
+--- a/examples/format_data/convert_oasst1.py
++++ b/examples/format_data/convert_oasst1.py
+@@ -16,12 +16,14 @@ def json_load(in_file):
+
+
+ def convert_oasst1_data(data_dir, output_dir):
+- '''
+- For OASST1, because it's in a tree structure, where every user input might get multiple replies,
+- we have to save every path from the root node to the assistant reply (including both leaf node and intemediate node).
+- This results in some of the messages being duplicated among different paths (instances).
+- Be careful when using this dataset for training. Ideally, you should only minimize the loss of the last message in each path.
+- '''
++ """For OASST1, because it's in a tree structure, where every user input
++ might get multiple replies, we have to save every path from the root node
++    to the assistant reply (including both leaf node and intermediate node).
++
++ This results in some of the messages being duplicated among different paths
++ (instances). Be careful when using this dataset for training. Ideally, you
++ should only minimize the loss of the last message in each path.
++ """
+ conversations = []
+ with open(os.path.join(data_dir, '2023-04-12_oasst_ready.trees.jsonl'),
+ 'r') as fin:
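The docstring rewritten above explains that OASST1 is a message tree, so every root-to-assistant path has to be materialised as its own training instance. A hedged sketch of that tree walk is shown below; the node keys (`role`, `text`, `replies`) follow the OASST1 export format, and the toy tree is illustrative only:

```python
from typing import Any, Dict, List


def collect_paths(node: Dict[str, Any],
                  prefix: List[Dict[str, str]]) -> List[List[Dict[str, str]]]:
    """Return every path from the root down to an assistant reply."""
    turn = {'role': node['role'], 'text': node['text']}
    path = prefix + [turn]
    paths = [path] if node['role'] == 'assistant' else []
    for child in node.get('replies', []):
        paths.extend(collect_paths(child, path))
    return paths


if __name__ == '__main__':
    tree = {'role': 'prompter', 'text': 'Hi', 'replies': [
        {'role': 'assistant', 'text': 'Hello!', 'replies': [
            {'role': 'prompter', 'text': 'Tell me a joke', 'replies': [
                {'role': 'assistant', 'text': 'Why did the GPU...', 'replies': []},
            ]},
        ]},
    ]}
    for p in collect_paths(tree, []):
        print([t['role'] for t in p])
```

Note that intermediate assistant replies appear in several paths, which is exactly the duplication the docstring warns about.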
+diff --git a/examples/format_data/merge.py b/examples/format_data/merge.py
+index 8da9846..84324e7 100644
+--- a/examples/format_data/merge.py
++++ b/examples/format_data/merge.py
+@@ -1,5 +1,4 @@
+-"""
+-Merge two conversation files into one
++"""Merge two conversation files into one.
+
+ Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json
+ """
+diff --git a/examples/vllm/apil_chient.py b/examples/vllm/apil_chient.py
+deleted file mode 100644
+index d3ba848..0000000
+--- a/examples/vllm/apil_chient.py
++++ /dev/null
+@@ -1,77 +0,0 @@
+-"""Example Python client for vllm.entrypoints.api_server"""
+-
+-import argparse
+-import json
+-from typing import Iterable, List
+-
+-import requests
+-
+-
+-def clear_line(n: int = 1) -> None:
+- LINE_UP = '\033[1A'
+- LINE_CLEAR = '\x1b[2K'
+- for _ in range(n):
+- print(LINE_UP, end=LINE_CLEAR, flush=True)
+-
+-
+-def post_http_request(prompt: str,
+- api_url: str,
+- n: int = 1,
+- stream: bool = False) -> requests.Response:
+- headers = {'User-Agent': 'Test Client'}
+- pload = {
+- 'prompt': prompt,
+- 'n': n,
+- 'use_beam_search': True,
+- 'temperature': 0.0,
+- 'max_tokens': 16,
+- 'stream': stream,
+- }
+- response = requests.post(api_url, headers=headers, json=pload, stream=True)
+- return response
+-
+-
+-def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
+- for chunk in response.iter_lines(chunk_size=8192,
+- decode_unicode=False,
+- delimiter=b'\0'):
+- if chunk:
+- data = json.loads(chunk.decode('utf-8'))
+- output = data['text']
+- yield output
+-
+-
+-def get_response(response: requests.Response) -> List[str]:
+- data = json.loads(response.content)
+- output = data['text']
+- return output
+-
+-
+-if __name__ == '__main__':
+- parser = argparse.ArgumentParser()
+- parser.add_argument('--host', type=str, default='localhost')
+- parser.add_argument('--port', type=int, default=8000)
+- parser.add_argument('--n', type=int, default=4)
+- parser.add_argument('--prompt', type=str, default='San Francisco is a')
+- parser.add_argument('--stream', action='store_true')
+- args = parser.parse_args()
+- prompt = args.prompt
+- api_url = f'http://{args.host}:{args.port}/generate'
+- n = args.n
+- stream = args.stream
+-
+- print(f'Prompt: {prompt!r}\n', flush=True)
+- response = post_http_request(prompt, api_url, n, stream)
+-
+- if stream:
+- num_printed_lines = 0
+- for h in get_streaming_response(response):
+- clear_line(num_printed_lines)
+- num_printed_lines = 0
+- for i, line in enumerate(h):
+- num_printed_lines += 1
+- print(f'Beam candidate {i}: {line!r}', flush=True)
+- else:
+- output = get_response(response)
+- for i, line in enumerate(output):
+- print(f'Beam candidate {i}: {line!r}', flush=True)
+diff --git a/examples/vllm/vllm_demo.py b/examples/vllm/vllm_demo.py
+deleted file mode 100644
+index f23be9b..0000000
+--- a/examples/vllm/vllm_demo.py
++++ /dev/null
+@@ -1,19 +0,0 @@
+-from vllm import LLM, SamplingParams
+-
+-prompts = [
+- 'Hello, my name is',
+- 'The president of the United States is',
+- 'The capital of France is',
+- 'The future of AI is',
+-]
+-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+-
+-llm = LLM(model='decapoda-research/llama-7b-hf', gpu_memory_utilization=0.9)
+-
+-# Print the outputs.
+-for i in range(10):
+- outputs = llm.generate(prompts, sampling_params)
+- for output in outputs:
+- prompt = output.prompt
+- generated_text = output.outputs[0].text
+- print(f'Prompt: {prompt!r}, Generated text: {generated_text!r}')
+diff --git a/scripts/server/apply_lora_to_base_model.sh b/scripts/server/apply_lora_to_base_model.sh
+index d4a21f5..040f901 100644
+--- a/scripts/server/apply_lora_to_base_model.sh
++++ b/scripts/server/apply_lora_to_base_model.sh
+@@ -1,4 +1,4 @@
+ CUDA_VISIBLE_DEVICES=0 python chatllms/utils/apply_lora.py \
+ --base-model-path ~/checkpoints/baichuan7b/ \
+ --lora-model-path ./work_dir/vicuna_merge_vicuna-baichuan-7b-1gpu/checkpoint-15000 \
+- --target-model-path ./work_dir/vicuna_merge_vicuna-baichuan-7b-1gpu/merged_model
+\ No newline at end of file
++ --target-model-path ./work_dir/vicuna_merge_vicuna-baichuan-7b-1gpu/merged_model
+diff --git a/scripts/server/run_inference.sh b/scripts/server/run_inference.sh
+index e316b23..ff65545 100755
+--- a/scripts/server/run_inference.sh
++++ b/scripts/server/run_inference.sh
+@@ -1,3 +1,3 @@
+ # generated_chat_vicuna
+ CUDA_VISIBLE_DEVICES=0 python single_chat.py \
+- --model_name_or_path ./work_dir/vicuna_merge_vicuna-baichuan-7b-1gpu/merged_model
+\ No newline at end of file
++ --model_name_or_path ./work_dir/vicuna_merge_vicuna-baichuan-7b-1gpu/merged_model
+diff --git a/server/gradio_qlora_webserver.py b/server/gradio_qlora_webserver.py
+index 3493495..9e6eb16 100644
+--- a/server/gradio_qlora_webserver.py
++++ b/server/gradio_qlora_webserver.py
+@@ -32,12 +32,11 @@ logger = logging.getLogger(__name__)
+
+
+ class Prompter:
+- """
+- A class for generating prompts and extracting responses from generated text.
+- """
++ """A class for generating prompts and extracting responses from generated
++ text."""
++
+ def __init__(self, prompt_template: str = None):
+- """
+- Initializes a new instance of the Prompter class.
++ """Initializes a new instance of the Prompter class.
+
+ Args:
+ prompt_template (str): The name of the prompt template to use. Default is None.
+@@ -50,8 +49,7 @@ class Prompter:
+ instruction: str,
+ input: Union[str, None] = None,
+ response: Union[str, None] = None) -> str:
+- """
+- Generates a prompt based on the specified inputs.
++ """Generates a prompt based on the specified inputs.
+
+ Args:
+ instruction (str): The instruction to include in the prompt.
+@@ -76,8 +74,7 @@ class Prompter:
+ return prompt_text
+
+ def get_response(self, output: str) -> str:
+- """
+- Extracts the response from the generated text.
++ """Extracts the response from the generated text.
+
+ Args:
+ output (str): The generated text to extract the response from.
+diff --git a/server/gradio_webserver.py b/server/gradio_webserver.py
+index 795ba54..0592fcd 100644
+--- a/server/gradio_webserver.py
++++ b/server/gradio_webserver.py
+@@ -11,6 +11,7 @@ from chatllms.utils.stream_server import Iteratorize, Stream
+
+
+ class Prompter(object):
++
+ def __init__(self) -> None:
+ self.PROMPT_DICT = {
+ 'prompt_input':
+diff --git a/server/multi_chat.py b/server/multi_chat.py
+index f368aaa..18b927e 100644
+--- a/server/multi_chat.py
++++ b/server/multi_chat.py
+@@ -12,9 +12,7 @@ from chatllms.utils.model_utils import get_logits_processor
+
+
+ def main(model_server_args, generation_args):
+- """
+- 多轮对话,不具有对话历史的记忆功能
+- """
++ """多轮对话,不具有对话历史的记忆功能."""
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = AutoModelForCausalLM.from_pretrained(
+ model_server_args.model_name_or_path,
+diff --git a/server/single_chat.py b/server/single_chat.py
+index a85c270..ef679a4 100644
+--- a/server/single_chat.py
++++ b/server/single_chat.py
+@@ -17,8 +17,8 @@ from chatllms.utils.model_utils import get_logits_processor
+ def generate_response(query: str, tokenizer: PreTrainedTokenizer,
+ model: PreTrainedModel,
+ generation_args: dict) -> List[str]:
+- """
+- Generates a response to the given query using GPT-3.5 model and prints it to the console.
++ """Generates a response to the given query using GPT-3.5 model and prints
++ it to the console.
+
+ Args:
+ query (str): The input query for which a response is to be generated.
+@@ -63,9 +63,7 @@ def generate_response(query: str, tokenizer: PreTrainedTokenizer,
+
+
+ def main():
+- """
+- 单轮对话,不具有对话历史的记忆功能
+- Run conversational agent loop with input/output.
++ """单轮对话,不具有对话历史的记忆功能 Run conversational agent loop with input/output.
+
+ Args:
+ model_args: Arguments for loading model
+diff --git a/train.py b/train.py
+index 3606ef4..8519c83 100644
+--- a/train.py
++++ b/train.py
+@@ -15,8 +15,8 @@ from chatllms.utils.model_utils import (add_special_tokens_if_missing,
+
+
+ def load_model_tokenizer(args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+- """
+- Load a pre-trained model and tokenizer for natural language processing tasks.
++ """Load a pre-trained model and tokenizer for natural language processing
++ tasks.
+
+ Args:
+ args: An object containing the input arguments.
+@@ -66,8 +66,7 @@ def load_model_tokenizer(args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+
+
+ def train() -> None:
+- """
+- Trains a language model using Hugging Face's Transformers library.
++ """Trains a language model using Hugging Face's Transformers library.
+
+ Args:
+ model_args (ModelArguments): The arguments for the model configuration.
+@@ -76,7 +75,6 @@ def train() -> None:
+
+ Returns:
+ None
+-
+ """
+ parser = HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+diff --git a/train_lora.py b/train_lora.py
+index e497cfc..81560e7 100644
+--- a/train_lora.py
++++ b/train_lora.py
+@@ -32,9 +32,9 @@ class LoraArguments:
+
+
+ def maybe_zero_3(param: Union[torch.Tensor, object]) -> torch.Tensor:
+- """
+- Applies zero.GatheredParameters to gather the parameter if it has ds_id attribute,
+- and clones and detaches the tensor data if ds_status is ZeroParamStatus.NOT_AVAILABLE.
++ """Applies zero.GatheredParameters to gather the parameter if it has ds_id
++ attribute, and clones and detaches the tensor data if ds_status is
++ ZeroParamStatus.NOT_AVAILABLE.
+
+ Args:
+ param: The parameter to be processed.
+@@ -58,8 +58,7 @@ def maybe_zero_3(param: Union[torch.Tensor, object]) -> torch.Tensor:
+ # Borrowed from peft.utils.get_peft_model_state_dict
+ def get_peft_state_maybe_zero_3(named_params: List[Tuple[str, torch.Tensor]],
+ bias: str) -> Dict[str, torch.Tensor]:
+- """
+- Filters and processes named parameters based on the specified bias.
++ """Filters and processes named parameters based on the specified bias.
+
+ Args:
+ named_params: An iterable containing tuples of parameter names and their corresponding values.
+@@ -107,8 +106,8 @@ def get_peft_state_maybe_zero_3(named_params: List[Tuple[str, torch.Tensor]],
+ def load_model_tokenizer(
+ args: argparse.Namespace
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+- """
+- Load a pre-trained model and tokenizer for natural language processing tasks.
++ """Load a pre-trained model and tokenizer for natural language processing
++ tasks.
+
+ Args:
+ args: An object containing the input arguments.
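Finally, the `maybe_zero_3` docstring reformatted above describes gathering DeepSpeed ZeRO-3 partitioned parameters before cloning them to CPU. A hedged sketch of that logic, closely following the pattern popularised by FastChat and requiring `deepspeed` to be installed, looks like this:

```python
import torch
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus


def maybe_zero_3(param: torch.nn.Parameter) -> torch.Tensor:
    """Gather a ZeRO-3 partitioned parameter before copying it to CPU."""
    if hasattr(param, 'ds_id'):  # parameter is managed (partitioned) by ZeRO-3
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            return param.data.detach().cpu().clone()
    return param.detach().cpu().clone()
```

Without DeepSpeed, the `ds_id` attribute is absent and the function reduces to a plain detach-and-clone.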