Commit 4e4a149

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into develop
cyber-pioneer committed Mar 4, 2024
2 parents 8d49caf + 17cc169 commit 4e4a149
Showing 39 changed files with 3,713 additions and 138 deletions.
6 changes: 5 additions & 1 deletion examples/RLHF/ppo_trainer.py
@@ -415,7 +415,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler._scale
optimizer_was_run = not self.scaler._cache_founf_inf
# Compatible with paddlepaddle 2.6.0 using typo word.
if hasattr(self.scaler, "_cache_founf_inf"):
optimizer_was_run = not self.scaler._cache_founf_inf
else:
optimizer_was_run = not self.scaler._cache_found_inf
if not optimizer_was_run:
scale_before_value = scale_before.cpu().numpy()
scale_after_value = scale_after.cpu().numpy()
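A minimal sketch of the compatibility shim above: Paddle 2.6.0 exposes the cached found-inf flag under the misspelled `_cache_founf_inf`, so the code probes that spelling first and falls back to `_cache_found_inf`. The helper below is illustrative, not PaddleNLP code.

```python
def scaler_step_was_skipped(scaler):
    """Return the GradScaler's cached found-inf flag, whichever spelling it uses."""
    if hasattr(scaler, "_cache_founf_inf"):  # paddlepaddle 2.6.0 typo
        return scaler._cache_founf_inf
    return scaler._cache_found_inf

# usage in the training step (sketch):
#   optimizer_was_run = not scaler_step_was_skipped(self.scaler)
```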
1 change: 1 addition & 0 deletions llm/argument.py
@@ -129,6 +129,7 @@ class ModelArgument:

# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
num_prefix_tokens: int = field(default=128, metadata={"help": "Number of prefix tokens"})

from_aistudio: bool = field(default=False, metadata={"help": "Whether to load model from aistudio"})
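For readers unfamiliar with this pattern, here is a minimal, non-PaddleNLP sketch of how a dataclass field such as `prefix_path` maps onto a command-line flag: the field name becomes the flag, `default` its default value, and `metadata["help"]` the help text. The parser loop and the checkpoint path are illustrative; the real parsing is done by PaddleNLP's argument parser.

```python
import argparse
from dataclasses import dataclass, field, fields

@dataclass
class ModelArgumentSketch:
    prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
    prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
    num_prefix_tokens: int = field(default=128, metadata={"help": "Number of prefix tokens"})

parser = argparse.ArgumentParser()
for f in fields(ModelArgumentSketch):
    parser.add_argument(f"--{f.name}", default=f.default, help=f.metadata.get("help", ""))

# hypothetical checkpoint directory, only for illustration
args = parser.parse_args(["--prefix_path", "./checkpoints/llama_pt_ckpts"])
print(args.prefix_path)  # ./checkpoints/llama_pt_ckpts
```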
6 changes: 3 additions & 3 deletions llm/data.py
@@ -44,11 +44,11 @@ def get_convert_example(model):

if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen"]:
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
return convert_example_common
else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama."
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
)


@@ -107,7 +107,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):
# 0. prepare data
context_data = example.get("context", {})
context_data["is_training"] = True

example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]]
example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]]

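The dispatch in `get_convert_example` can equally be read as a lookup table; a minimal sketch with stubbed converters (the stubs are placeholders, not the real implementations):

```python
def convert_example_chatglm(example, tokenizer, data_args, is_test=False):
    ...  # placeholder for the real chatglm converter

def convert_example_common(example, tokenizer, data_args, is_test=False):
    ...  # placeholder for the real shared converter

_CONVERTERS = {"chatglm": convert_example_chatglm}
_CONVERTERS.update(
    {prefix: convert_example_common
     for prefix in ("chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral")}
)

def get_convert_example_sketch(base_model_prefix):
    try:
        return _CONVERTERS[base_model_prefix]
    except KeyError:
        raise ValueError(
            f"Unknown base_model_prefix: {base_model_prefix}. "
            f"Supported: {', '.join(_CONVERTERS)}"
        ) from None
```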
19 changes: 12 additions & 7 deletions llm/finetune_generation.py
@@ -398,12 +398,18 @@ def neft_post_hook(module, input, output):
multi_query_group_num=prefix_tuning_params["multi_query_group_num"],
dtype=dtype,
)
model = PrefixModelForCausalLM(
model=model,
prefix_config=prefix_config,
postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
)
model.mark_only_prefix_as_trainable()
if model_args.prefix_path is None:
model = PrefixModelForCausalLM(
model=model,
prefix_config=prefix_config,
postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
)
else:
model = PrefixModelForCausalLM.from_pretrained(
model=model,
prefix_path=model_args.prefix_path,
postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
)
model.print_trainable_parameters()

if model_args.lora:
@@ -422,7 +428,6 @@ def neft_post_hook(module, input, output):
model = LoRAModel(model, lora_config)
else:
model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
model.mark_only_lora_as_trainable()
model.print_trainable_parameters()

def compute_metrics_do_generation(eval_preds):
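Read as a whole, the prefix-tuning branch amounts to a small load-or-init helper; a sketch assuming `PrefixModelForCausalLM` is importable from `paddlenlp.peft` (the helper name is ours, not PaddleNLP's):

```python
from paddlenlp.peft import PrefixModelForCausalLM  # assumed import path

def build_prefix_model(model, prefix_config, prefix_path, postprocess_past_key_value):
    """Wrap `model` for prefix tuning, either fresh or from a saved prefix checkpoint."""
    if prefix_path is None:
        # no checkpoint given: initialize prefix parameters from prefix_config
        return PrefixModelForCausalLM(
            model=model,
            prefix_config=prefix_config,
            postprocess_past_key_value=postprocess_past_key_value,
        )
    # prefix_path given: restore the prefix state dict saved by a previous run
    return PrefixModelForCausalLM.from_pretrained(
        model=model,
        prefix_path=prefix_path,
        postprocess_past_key_value=postprocess_past_key_value,
    )
```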
16 changes: 8 additions & 8 deletions llm/llama/auto_parallel/run_pretrain_auto.py
@@ -542,16 +542,16 @@ def main():

print("Final pre-training config:", config)

# Set the dtype for loading model
dtype = "float32"
if training_args.fp16_opt_level == "O2":
if training_args.fp16:
dtype = "float16"
if training_args.bf16:
dtype = "bfloat16"
# # Set the dtype for loading model
# dtype = "float32"
# if training_args.fp16_opt_level == "O2":
# if training_args.fp16:
# dtype = "float16"
# if training_args.bf16:
# dtype = "bfloat16"

with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
model = model_class.from_config(config, dtype="float32")
criterion = criterion_class(config)

for param in model.parameters():
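For reference, the dtype selection that this change comments out, kept as a standalone helper (a sketch; the auto-parallel script now loads the model in float32 unconditionally):

```python
def select_model_dtype(training_args):
    """Pick the load dtype from fp16/bf16 flags, as the commented-out block did."""
    dtype = "float32"
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        if training_args.bf16:
            dtype = "bfloat16"
    return dtype
```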
2 changes: 1 addition & 1 deletion llm/llama/pt_argument.json
@@ -29,4 +29,4 @@
"prefix_tuning": true,
"zero_padding": false,
"use_flash_attention": false
}
}
32 changes: 32 additions & 0 deletions llm/mixtral/lora_argument.json
@@ -0,0 +1,32 @@
{
"model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mixtral_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
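As a quick sanity check on this configuration (a sketch, assuming a single machine with 8 GPUs): with tensor_parallel_degree 8 and pipeline_parallel_degree 1 there is one data-parallel group, so the effective global batch size is per_device_train_batch_size × gradient_accumulation_steps × dp_degree.

```python
import json

cfg = json.load(open("llm/mixtral/lora_argument.json"))
num_gpus = 8  # assumption: one node with 8 GPUs
dp_degree = num_gpus // (cfg["tensor_parallel_degree"] * cfg["pipeline_parallel_degree"])
global_batch = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"] * dp_degree
print(dp_degree, global_batch)  # 1 16
```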
30 changes: 30 additions & 0 deletions llm/mixtral/sft_argument.json
@@ -0,0 +1,30 @@
{
"model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mixtral_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"sharding": "stage2",
"pipeline_parallel_degree": 1
}
10 changes: 10 additions & 0 deletions llm/utils.py
@@ -149,6 +149,16 @@ def get_lora_target_modules(model):
".*mlp.w2.*",
".*mlp.c_proj.*",
]
elif model.base_model_prefix == "mixtral":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*w1.*",
".*w2.*",
".*w3.*",
]
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
return target_modules
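A rough illustration of how such regex patterns pick out layers for LoRA: each pattern is matched against parameter or sublayer names, and a hit marks that projection for adaptation. The parameter names below are made-up Mixtral-style names, and the exact matching inside LoRAModel may differ in detail.

```python
import re

mixtral_patterns = [r".*q_proj.*", r".*k_proj.*", r".*v_proj.*", r".*o_proj.*",
                    r".*w1.*", r".*w2.*", r".*w3.*"]

example_names = [
    "mixtral.layers.0.self_attn.q_proj.weight",               # hypothetical name
    "mixtral.layers.0.block_sparse_moe.experts.3.w1.weight",  # hypothetical name
    "mixtral.layers.0.input_layernorm.weight",                # hypothetical name
]

for name in example_names:
    hit = any(re.fullmatch(p, name) for p in mixtral_patterns)
    print(f"{name}: {'LoRA target' if hit else 'skipped'}")
```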
2 changes: 2 additions & 0 deletions model_zoo/ernie-1.0/preprocess/README.md
@@ -175,6 +175,8 @@ common config:
Logging interval, measured in the number of processed text lines / docs between updates.
--workers WORKERS Number of worker processes to launch
Number of worker processes used to convert text to token ids.
--max_repeated_len Max length of repeated chars to keep
Maximum number of repeated characters to keep.
```
Running the conversion script below yields the processed pretraining data: token ids in `baike_sample.bin` and the document index in `baike_sample.idx`.

19 changes: 19 additions & 0 deletions model_zoo/ernie-1.0/preprocess/create_pretraining_data.py
@@ -104,6 +104,9 @@ def get_args():
group.add_argument("--log_interval", type=int, default=100, help="Interval between progress updates")
group.add_argument("--workers", type=int, default=1, help="Number of worker processes to launch")
group.add_argument("--max_doc_num", type=int, default=sys.maxsize, help="Number of worker processes to launch")
group.add_argument(
"--max_repeated_len", type=int, default=100, help="The maximum length of the repeated characters to keep"
)

args = parser.parse_args()
return args
@@ -278,8 +281,24 @@ def process(text):

Converter.process = process

def remove_repeated_chars(text, max_repeated_len=100):
"""
Collapses runs of a single repeated character in the given text when the run
is longer than max_repeated_len.
Args:
text (str): The input text to clean.
max_repeated_len (int, optional): Longest run of a repeated character to keep; longer runs are collapsed to a single character. Defaults to 100.
Returns:
str: The text with over-long character runs collapsed.
"""
pattern = r"(.)\1{" + str(max_repeated_len) + ",}"
return re.sub(pattern, r"\1", text)

def encode(self, json_line):
text = json.loads(json_line)[self.args.json_key]
text = Converter.remove_repeated_chars(text, self.args.max_repeated_len)
doc_ids = []
for sentence in Converter.splitter.tokenize(text):
sentence_ids = Converter.process(sentence.strip())
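A tiny worked example of what the new --max_repeated_len cleanup does (a sketch with a small threshold so the strings stay readable): the pattern `(.)\1{N,}` matches a character followed by N or more copies of itself, and re.sub collapses the whole run to a single character, while shorter runs are kept.

```python
import re

max_repeated_len = 5  # small value so the example stays readable
pattern = r"(.)\1{" + str(max_repeated_len) + ",}"

text = "ok" + "-" * 3 + "ok" + "!" * 12
print(re.sub(pattern, r"\1", text))  # ok---ok!  (the 12 '!' collapse, the 3 '-' stay)
```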
38 changes: 32 additions & 6 deletions paddlenlp/data/causal_dataset.py
@@ -363,26 +363,52 @@ def __getitem__(self, idx):
if doc_index_f == doc_index_l:
doc_ids.append(self.doc_idx[doc_index_f])

sample = self.indexed_dataset.get(
sample, mask = self.indexed_dataset.get(
self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
)
else:
# Otherwise, get the rest of the initial document.
doc_ids.append(self.doc_idx[doc_index_f])
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
append_mask = True
if mask is None:
append_mask = False

sample_list = [sample]
mask_list = []
mask_list = [mask]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
doc_ids.append(self.doc_idx[i])
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
sample, mask = self.indexed_dataset.get(self.doc_idx[i])
sample_list.append(sample)
if append_mask:
mask_list.append(mask)

# And finally add the relevant portion of last document.
doc_ids.append(self.doc_idx[doc_index_l])
sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
sample_list.append(sample)
if append_mask:
mask_list.append(mask)
sample = np.concatenate(sample_list)
if append_mask:
mask = np.concatenate(mask_list)
# print(sample)
if self.return_doc_ids: # for retro preprocessing
return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
if mask is None:
return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
else:
return {
"text": np.array(sample, dtype=np.int64),
"doc_ids": np.array(doc_ids, dtype=np.int64),
"mask": np.array(mask, dtype=np.int64),
}
else:
return {"text": np.array(sample, dtype=np.int64)}
if mask is None:
return {"text": np.array(sample, dtype=np.int64)}
else:
return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)}


def _build_index_mappings(
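A toy illustration of the stitching logic above, assuming the loss-mask file exists: per-document token arrays and masks are collected in parallel lists and concatenated into one training sample (the numbers are made up):

```python
import numpy as np

sample_list = [np.array([11, 12, 13]), np.array([21, 22]), np.array([31])]
mask_list = [np.array([1, 1, 0], dtype=np.uint8),
             np.array([1, 0], dtype=np.uint8),
             np.array([1], dtype=np.uint8)]

sample = np.concatenate(sample_list)
mask = np.concatenate(mask_list)
item = {"text": sample.astype(np.int64), "mask": mask.astype(np.int64)}
print(item["text"], item["mask"])  # [11 12 13 21 22 31] [1 1 0 1 0 1]
```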
31 changes: 27 additions & 4 deletions paddlenlp/data/indexed_dataset.py
@@ -124,6 +124,10 @@ def data_file_path(prefix_path):
return prefix_path + ".bin"


def loss_mask_file_path(prefix_path):
return prefix_path + ".lsm"


def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
@@ -444,6 +448,7 @@ def __init__(self, path, skip_warmup=False):
self._path = None
self._index = None
self._bin_buffer = None
self._loss_mask_buffer = None

self._do_init(path, skip_warmup)

@@ -466,12 +471,18 @@ def _do_init(self, path, skip_warmup):
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
if os.path.exists(loss_mask_file_path(self._path)):
self._loss_mask_buffer_mmap = np.memmap(loss_mask_file_path(self._path), mode="r", order="C")
self._loss_mask_buffer = memoryview(self._loss_mask_buffer_mmap)
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)

def __del__(self):
self._bin_buffer_mmap._mmap.close()
if hasattr(self, "_loss_mask_buffer_mmap"):
self._loss_mask_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._loss_mask_buffer
del self._index

def __len__(self):
@@ -507,8 +518,12 @@ def get(self, idx, offset=0, length=None):
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
mask_ptr = ptr // np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
return np_array
mask_array = None
if self._loss_mask_buffer is not None:
mask_array = np.frombuffer(self._loss_mask_buffer, dtype=np.uint8, count=length, offset=mask_ptr)
return np_array, mask_array

@property
def sizes(self):
@@ -533,20 +548,28 @@ def exists(path):
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


def make_builder(out_file, impl, save_dtype):
def make_builder(out_file, impl, save_dtype, loss_mask_file=None):
if impl == "mmap":
return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype)
return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file)
else:
return IndexedDatasetBuilder(out_file, dtype=save_dtype)


class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype):
def __init__(self, out_file, dtype, loss_mask_file=None):
self._data_file = open(out_file, "wb")
self._loss_mask_file = None
if loss_mask_file is not None:
self._loss_mask_file = open(loss_mask_file, "wb")
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]

def flush_loss_mask_item(self, loss_mask_lst):
for loss_mask in loss_mask_lst:
tensor = np.array(loss_mask, dtype=np.uint8)
self._loss_mask_file.write(tensor.tobytes(order="C"))

def add_item(self, tensor):
tensor = np.array(tensor, dtype=self._dtype)
self._data_file.write(tensor.tobytes(order="C"))
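A self-contained sketch of the new loss-mask layout, not the full MMapIndexedDataset: tokens live in `.bin` as fixed-width integers and masks in `.lsm` as one uint8 per token, so a byte offset into `.bin` divided by the token itemsize gives the matching `.lsm` offset (the `toy.bin` / `toy.lsm` file names are illustrative).

```python
import numpy as np

token_dtype = np.int32
tokens = np.array([5, 9, 2, 7, 1, 4], dtype=token_dtype)  # toy token ids
loss_mask = np.array([1, 1, 0, 0, 1, 1], dtype=np.uint8)  # 1 = contributes to loss

with open("toy.bin", "wb") as f:
    f.write(tokens.tobytes(order="C"))
with open("toy.lsm", "wb") as f:
    f.write(loss_mask.tobytes(order="C"))

bin_buffer = memoryview(np.memmap("toy.bin", mode="r", order="C"))
lsm_buffer = memoryview(np.memmap("toy.lsm", mode="r", order="C"))

# Read tokens 2..4 of the document: ptr is a byte offset into .bin,
# mask_ptr the corresponding element offset into .lsm.
offset, length = 2, 3
ptr = offset * np.dtype(token_dtype).itemsize
mask_ptr = ptr // np.dtype(token_dtype).itemsize
np_array = np.frombuffer(bin_buffer, dtype=token_dtype, count=length, offset=ptr)
mask_array = np.frombuffer(lsm_buffer, dtype=np.uint8, count=length, offset=mask_ptr)
print(np_array, mask_array)  # [2 7 1] [0 0 1]
```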