diff --git a/examples/RLHF/ppo_trainer.py b/examples/RLHF/ppo_trainer.py index 32a9c148f193..84577d6d2c01 100644 --- a/examples/RLHF/ppo_trainer.py +++ b/examples/RLHF/ppo_trainer.py @@ -415,7 +415,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler._scale - optimizer_was_run = not self.scaler._cache_founf_inf + # Compatible with paddlepaddle 2.6.0 using typo word. + if hasattr(self.scaler, "_cache_founf_inf"): + optimizer_was_run = not self.scaler._cache_founf_inf + else: + optimizer_was_run = not self.scaler._cache_found_inf if not optimizer_was_run: scale_before_value = scale_before.cpu().numpy() scale_after_value = scale_after.cpu().numpy() diff --git a/llm/argument.py b/llm/argument.py index fcec69a93dea..50add643675e 100644 --- a/llm/argument.py +++ b/llm/argument.py @@ -129,6 +129,7 @@ class ModelArgument: # prefix tuning related parameters prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) + prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."}) num_prefix_tokens: int = field(default=128, metadata={"help": "Number of prefix tokens"}) from_aistudio: bool = field(default=False, metadata={"help": "Whether to load model from aistudio"}) diff --git a/llm/data.py b/llm/data.py index 6a38a7096042..5d44c72c8abd 100644 --- a/llm/data.py +++ b/llm/data.py @@ -44,11 +44,11 @@ def get_convert_example(model): if base_model_prefix == "chatglm": return convert_example_chatglm - elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen"]: + elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]: return convert_example_common else: raise ValueError( - f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama." + f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral" ) @@ -107,7 +107,7 @@ def tokenize_rounds_example(tokenizer, example, data_args): # 0. 
prepare data context_data = example.get("context", {}) context_data["is_training"] = True - + example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]] example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]] diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index d9a54a0e6226..be9b8f3cb4d1 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -398,12 +398,18 @@ def neft_post_hook(module, input, output): multi_query_group_num=prefix_tuning_params["multi_query_group_num"], dtype=dtype, ) - model = PrefixModelForCausalLM( - model=model, - prefix_config=prefix_config, - postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"], - ) - model.mark_only_prefix_as_trainable() + if model_args.prefix_path is None: + model = PrefixModelForCausalLM( + model=model, + prefix_config=prefix_config, + postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"], + ) + else: + model = PrefixModelForCausalLM.from_pretrained( + model=model, + prefix_path=model_args.prefix_path, + postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"], + ) model.print_trainable_parameters() if model_args.lora: @@ -422,7 +428,6 @@ def neft_post_hook(module, input, output): model = LoRAModel(model, lora_config) else: model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path) - model.mark_only_lora_as_trainable() model.print_trainable_parameters() def compute_metrics_do_generation(eval_preds): diff --git a/llm/llama/auto_parallel/run_pretrain_auto.py b/llm/llama/auto_parallel/run_pretrain_auto.py index 0866dce8fada..8bd848c6c8a8 100644 --- a/llm/llama/auto_parallel/run_pretrain_auto.py +++ b/llm/llama/auto_parallel/run_pretrain_auto.py @@ -542,16 +542,16 @@ def main(): print("Final pre-training config:", config) - # Set the dtype for loading model - dtype = "float32" - if training_args.fp16_opt_level == "O2": - if training_args.fp16: - dtype = "float16" - if training_args.bf16: - dtype = "bfloat16" + # # Set the dtype for loading model + # dtype = "float32" + # if training_args.fp16_opt_level == "O2": + # if training_args.fp16: + # dtype = "float16" + # if training_args.bf16: + # dtype = "bfloat16" with paddle.LazyGuard(): - model = model_class.from_config(config, dtype=dtype) + model = model_class.from_config(config, dtype="float32") criterion = criterion_class(config) for param in model.parameters(): diff --git a/llm/llama/pt_argument.json b/llm/llama/pt_argument.json index 311c5e886a49..501e09c47160 100644 --- a/llm/llama/pt_argument.json +++ b/llm/llama/pt_argument.json @@ -29,4 +29,4 @@ "prefix_tuning": true, "zero_padding": false, "use_flash_attention": false - } \ No newline at end of file + } diff --git a/llm/mixtral/lora_argument.json b/llm/mixtral/lora_argument.json new file mode 100644 index 000000000000..507c0f76e798 --- /dev/null +++ b/llm/mixtral/lora_argument.json @@ -0,0 +1,32 @@ +{ + "model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/mixtral_lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-04, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "fp16": true, + "fp16_opt_level": "O2", + "do_train": true, + 
"do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 8, + "pipeline_parallel_degree": 1, + "lora": true, + "zero_padding": false, + "use_flash_attention": false + } diff --git a/llm/mixtral/sft_argument.json b/llm/mixtral/sft_argument.json new file mode 100644 index 000000000000..3e778b913ffc --- /dev/null +++ b/llm/mixtral/sft_argument.json @@ -0,0 +1,30 @@ +{ + "model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/mixtral_sft_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "num_train_epochs": 3, + "learning_rate": 3e-05, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "epoch", + "save_strategy": "epoch", + "src_length": 1024, + "max_length": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 8, + "sharding": "stage2", + "pipeline_parallel_degree": 1 +} diff --git a/llm/utils.py b/llm/utils.py index bee1529f8ecc..8bcc52ae33ab 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -149,6 +149,16 @@ def get_lora_target_modules(model): ".*mlp.w2.*", ".*mlp.c_proj.*", ] + elif model.base_model_prefix == "mixtral": + target_modules = [ + ".*q_proj.*", + ".*k_proj.*", + ".*v_proj.*", + ".*o_proj.*", + ".*w1.*", + ".*w2.*", + ".*w3.*", + ] else: raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.") return target_modules diff --git a/model_zoo/ernie-1.0/preprocess/README.md b/model_zoo/ernie-1.0/preprocess/README.md index 49ac81efcf27..893eaad5cff7 100644 --- a/model_zoo/ernie-1.0/preprocess/README.md +++ b/model_zoo/ernie-1.0/preprocess/README.md @@ -175,6 +175,8 @@ common config: 打印日志间隔,interval表示处理 文本行数/doc数的 间隔。 --workers WORKERS Number of worker processes to launch 处理文本id化的进程个数。 + --max_repeated_len Max length of repeated chars to keep + 最大保留重复的字符个数。 ``` 通过下面脚本转化,我们可以得到处理好的预训练数据,token ids:`baike_sample.bin`, 文章索引信息`baike_sample.idx`. diff --git a/model_zoo/ernie-1.0/preprocess/create_pretraining_data.py b/model_zoo/ernie-1.0/preprocess/create_pretraining_data.py index ea63297936ed..c1874f8b936a 100644 --- a/model_zoo/ernie-1.0/preprocess/create_pretraining_data.py +++ b/model_zoo/ernie-1.0/preprocess/create_pretraining_data.py @@ -104,6 +104,9 @@ def get_args(): group.add_argument("--log_interval", type=int, default=100, help="Interval between progress updates") group.add_argument("--workers", type=int, default=1, help="Number of worker processes to launch") group.add_argument("--max_doc_num", type=int, default=sys.maxsize, help="Number of worker processes to launch") + group.add_argument( + "--max_repeated_len", type=int, default=100, help="The maximum length of the repeated characters to keep" + ) args = parser.parse_args() return args @@ -278,8 +281,24 @@ def process(text): Converter.process = process + def remove_repeated_chars(text, max_repeated_len=100): + """ + Removes repeated characters from the given text, where the length of + the repeated characters is greater than or equal to the specified length. 
+ + Args: + text (str): The input text from which to remove repeated characters. + length (int, optional): The minimum length of the repeated characters. Defaults to 15. + + Returns: + str: The modified text with the repeated characters removed. + """ + pattern = r"(.)\1{" + str(max_repeated_len) + ",}" + return re.sub(pattern, r"\1", text) + def encode(self, json_line): text = json.loads(json_line)[self.args.json_key] + text = Converter.remove_repeated_chars(text, self.args.max_repeated_len) doc_ids = [] for sentence in Converter.splitter.tokenize(text): sentence_ids = Converter.process(sentence.strip()) diff --git a/paddlenlp/data/causal_dataset.py b/paddlenlp/data/causal_dataset.py index d469c231545f..75a2ba193dca 100644 --- a/paddlenlp/data/causal_dataset.py +++ b/paddlenlp/data/causal_dataset.py @@ -363,26 +363,52 @@ def __getitem__(self, idx): if doc_index_f == doc_index_l: doc_ids.append(self.doc_idx[doc_index_f]) - sample = self.indexed_dataset.get( + sample, mask = self.indexed_dataset.get( self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 ) else: # Otherwise, get the rest of the initial document. doc_ids.append(self.doc_idx[doc_index_f]) - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + append_mask = True + if mask is None: + append_mask = False + + sample_list = [sample] + mask_list = [] + mask_list = [mask] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): doc_ids.append(self.doc_idx[i]) - sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + sample, mask = self.indexed_dataset.get(self.doc_idx[i]) + sample_list.append(sample) + if append_mask: + mask_list.append(mask) + # And finally add the relevant portion of last document. 
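As an aside on the `--max_repeated_len` option added to `create_pretraining_data.py` above: the regex collapses any run of a single character longer than the threshold down to one occurrence. A standalone sketch of the same helper (the threshold of 3 below is only for illustration; the patch defaults to 100):

```python
import re

def remove_repeated_chars(text, max_repeated_len=100):
    # collapse any run of one character longer than max_repeated_len to a single occurrence
    pattern = r"(.)\1{" + str(max_repeated_len) + ",}"
    return re.sub(pattern, r"\1", text)

assert remove_repeated_chars("a" * 3 + "b", max_repeated_len=3) == "aaab"  # short runs are kept
assert remove_repeated_chars("a" * 50 + "b", max_repeated_len=3) == "ab"   # long runs collapse to one char
```

The `causal_dataset.py` hunk continues below.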
doc_ids.append(self.doc_idx[doc_index_l]) - sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) + sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + sample_list.append(sample) + if append_mask: + mask_list.append(mask) sample = np.concatenate(sample_list) + if append_mask: + mask = np.concatenate(mask_list) # print(sample) if self.return_doc_ids: # for retro preprocessing - return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)} + if mask is None: + return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)} + else: + return { + "text": np.array(sample, dtype=np.int64), + "doc_ids": np.array(doc_ids, dtype=np.int64), + "mask": np.array(mask, dtype=np.int64), + } else: - return {"text": np.array(sample, dtype=np.int64)} + if mask is None: + return {"text": np.array(sample, dtype=np.int64)} + else: + return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)} def _build_index_mappings( diff --git a/paddlenlp/data/indexed_dataset.py b/paddlenlp/data/indexed_dataset.py index 8bd7b05e7249..fd3f77bcf378 100644 --- a/paddlenlp/data/indexed_dataset.py +++ b/paddlenlp/data/indexed_dataset.py @@ -124,6 +124,10 @@ def data_file_path(prefix_path): return prefix_path + ".bin" +def loss_mask_file_path(prefix_path): + return prefix_path + ".lsm" + + def create_doc_idx(sizes): doc_idx = [0] for i, s in enumerate(sizes): @@ -444,6 +448,7 @@ def __init__(self, path, skip_warmup=False): self._path = None self._index = None self._bin_buffer = None + self._loss_mask_buffer = None self._do_init(path, skip_warmup) @@ -466,12 +471,18 @@ def _do_init(self, path, skip_warmup): _warmup_mmap_file(data_file_path(self._path)) print_rank_0(" creating numpy buffer of mmap...") self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C") + if os.path.exists(loss_mask_file_path(self._path)): + self._loss_mask_buffer_mmap = np.memmap(loss_mask_file_path(self._path), mode="r", order="C") + self._loss_mask_buffer = memoryview(self._loss_mask_buffer_mmap) print_rank_0(" creating memory view of numpy buffer...") self._bin_buffer = memoryview(self._bin_buffer_mmap) def __del__(self): self._bin_buffer_mmap._mmap.close() + if hasattr(self, "_loss_mask_buffer_mmap"): + self._loss_mask_buffer_mmap._mmap.close() del self._bin_buffer_mmap + del self._loss_mask_buffer del self._index def __len__(self): @@ -507,8 +518,12 @@ def get(self, idx, offset=0, length=None): if length is None: length = size - offset ptr += offset * np.dtype(self._index.dtype).itemsize + mask_ptr = ptr // np.dtype(self._index.dtype).itemsize np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr) - return np_array + mask_array = None + if self._loss_mask_buffer is not None: + mask_array = np.frombuffer(self._loss_mask_buffer, dtype=np.uint8, count=length, offset=mask_ptr) + return np_array, mask_array @property def sizes(self): @@ -533,20 +548,28 @@ def exists(path): return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) -def make_builder(out_file, impl, save_dtype): +def make_builder(out_file, impl, save_dtype, loss_mask_file=None): if impl == "mmap": - return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype) + return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file) else: return IndexedDatasetBuilder(out_file, dtype=save_dtype) class 
MMapIndexedDatasetBuilder(object): - def __init__(self, out_file, dtype): + def __init__(self, out_file, dtype, loss_mask_file=None): self._data_file = open(out_file, "wb") + self._loss_mask_file = None + if loss_mask_file is not None: + self._loss_mask_file = open(loss_mask_file, "wb") self._dtype = dtype self._sizes = [] self._doc_idx = [0] + def flush_loss_mask_item(self, loss_mask_lst): + for loss_mask in loss_mask_lst: + tensor = np.array(loss_mask, dtype=np.uint8) + self._loss_mask_file.write(tensor.tobytes(order="C")) + def add_item(self, tensor): tensor = np.array(tensor, dtype=self._dtype) self._data_file.write(tensor.tobytes(order="C")) diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 83b392e4893f..2bad88e01771 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import gc import math import os import re @@ -32,10 +33,16 @@ ) from ...transformers.conversion_utils import ConversionMixin -from ...transformers.model_utils import PretrainedModel, _add_variant, dtype_guard -from ...transformers.utils import weight_name_suffix +from ...transformers.model_utils import ( + PretrainedModel, + _add_variant, + _load_state_dict_into_model, + dtype_guard, + load_state_dict, +) +from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix from ...utils.distributed import distributed_gather -from ...utils.env import LORA_WEIGHTS_NAME +from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME from ...utils.log import logger from .lora_config import LoRAConfig from .lora_layers import ( @@ -99,6 +106,9 @@ def __init__(self, model, lora_config: LoRAConfig) -> None: ) self.forward = self.model.forward + logger.info("Mark only lora and trainable_module as trainable.") + self.mark_only_lora_as_trainable() + def add_lora_split_mapping(self, module_name, is_column=False): self.lora_split_mapping[module_name] = is_column @@ -113,9 +123,51 @@ def _get_tensor_parallel_mappings(self, config, is_split=True): num_attention_heads=config.num_attention_heads, ) + rename_lora_split_mapping = {} + if issubclass(type(self.model), PipelineLayer): + # rename lora_split_mapping + prefixes = self.model.get_sequential_name_prefixes() + keys = self.lora_split_mapping.keys() + first_key = "" + for k in keys: + first_key = k + break + first_key = first_key.split(".") + use_virtual_pp_degree = first_key[0].isdigit() and first_key[1].isdigit() + + for k in keys: + name_splited = k.split(".") + if use_virtual_pp_degree: + if name_splited[0].isdigit(): + if name_splited[1].isdigit(): + idx = str(int(name_splited[0]) + int(name_splited[1])) + single_name = [prefixes[idx]] + single_name.extend(name_splited[2:]) + else: + single_name = [prefixes[str(len(prefixes) - 1)]] + single_name.extend(name_splited[2:]) + logger.warning( + f"Please check! we treat this key as last layer, get {k}, set origin name as {'.'.join(single_name)}" + ) + else: + raise ValueError(f"Please check! 
{k} is not a valid key.") + else: + idx = name_splited[0] + # for normal pp layer name + if idx.isdigit(): + single_name = [prefixes[idx]] + single_name.extend(name_splited[1:]) + else: + raise ValueError(f"Unexpected key: {k} for pp lora layer.") + rename_lora_split_mapping[".".join(single_name)] = self.lora_split_mapping[k] + + lora_split_mapping = ( + rename_lora_split_mapping if issubclass(type(self.model), PipelineLayer) else self.lora_split_mapping + ) + def get_tensor_parallel_split_mappings(): final_actions = {} - for key, is_col in self.lora_split_mapping.items(): + for key, is_col in lora_split_mapping.items(): final_actions[key] = partial(fn, is_column=is_col) return final_actions @@ -134,6 +186,41 @@ def from_pretrained(cls, model, lora_path, **kwargs): lora_config_tensor_parallel_degree = lora_config.tensor_parallel_degree lora_model = cls(model, lora_config) + lora_model_index_file = os.path.join(lora_path, SAFE_PEFT_WEIGHTS_INDEX_NAME) + if os.path.exists(lora_model_index_file): + # load safetensors format file. + resolved_archieve_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=lora_path, + index_filename=lora_model_index_file, + ) + loaded_keys = sharded_metadata["all_checkpoint_keys"] + expected_keys = set(lora_model.get_trainable_state_dict().keys()) + + missing_keys = expected_keys - set(loaded_keys) + if len(missing_keys) > 0: + raise ValueError(f"missing_keys: {missing_keys}") + + error_msgs = [] + for shard_file in resolved_archieve_file: + pre_tensor_parallel_split = False + if model.config.tensor_parallel_degree > 1: + pre_tensor_parallel_split = True + tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True) + state_dict = load_state_dict( + shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + ) + error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "") + del state_dict + gc.collect() + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + raise RuntimeError( + f"Error(s) in loading state_dict for {lora_model.__class__.__name__}:\n\t{error_msg}" + ) + + return lora_model + # define lora weight name if lora_config_tensor_parallel_degree > 1: lora_weight_name = _add_variant(LORA_WEIGHTS_NAME, f"tp{model.config.tensor_parallel_rank:0>2d}") @@ -176,15 +263,9 @@ def set_state_dict(self, state_dict): logger.info("Load lora weight successfully") def _merge_trainable_tensor_parallel(self, trainable_state_dict): - trainable_name_action_mappings = self._get_tensor_parallel_mappings(self.model.config, is_split=False) - - name_action_mappings = self.model._get_tensor_parallel_mappings(self.model.config, is_split=False) - state_keys_map = ConversionMixin._resolve_prefix_keys( - name_action_mappings.keys(), self.model.state_dict().keys() + trainable_name_action_mappings = self._get_tensor_parallel_convert_actions( + trainable_state_dict.keys(), is_split=False ) - for k, v in state_keys_map.items(): - if v in trainable_state_dict: - trainable_name_action_mappings[v] = name_action_mappings[k] hcg = paddle.distributed.fleet.get_hybrid_communicate_group() mp_group = hcg.get_model_parallel_group() @@ -206,16 +287,21 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict): return trainable_state_dict - def _convert_tensor_parallel(self, lora_state_dict): - lora_name_action_mappings = self._get_tensor_parallel_mappings(self.model.config, is_split=False) - - name_action_mappings = self.model._get_tensor_parallel_mappings(self.model.config, 
is_split=False) + def _get_tensor_parallel_convert_actions(self, loaded_keys, is_split=True, ignore_error=False, config=None): + if config is None: + config = self.model.config + specific_name_action_mappings = self._get_tensor_parallel_mappings(config, is_split=is_split) + name_action_mappings = self.model._get_tensor_parallel_mappings(config, is_split=is_split) state_keys_map = ConversionMixin._resolve_prefix_keys( - name_action_mappings.keys(), self.model.state_dict().keys() + name_action_mappings.keys(), self.model.state_dict().keys(), ignore_error=ignore_error ) for k, v in state_keys_map.items(): - if v in lora_state_dict.keys(): - lora_name_action_mappings[v] = name_action_mappings[k] + if v in loaded_keys: + specific_name_action_mappings[v] = name_action_mappings[k] + return specific_name_action_mappings + + def _convert_tensor_parallel(self, lora_state_dict): + lora_name_action_mappings = self._get_tensor_parallel_convert_actions(lora_state_dict.keys(), is_split=True) for name, action in lora_name_action_mappings.items(): if name in lora_state_dict: diff --git a/paddlenlp/peft/prefix/prefix_model.py b/paddlenlp/peft/prefix/prefix_model.py index 1ccf1f0bc594..29a34442280c 100644 --- a/paddlenlp/peft/prefix/prefix_model.py +++ b/paddlenlp/peft/prefix/prefix_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import os import tempfile from functools import partial @@ -24,9 +25,19 @@ from paddle.distributed import fleet from ...prompt.prompt_utils import signature -from ...transformers.model_utils import _add_variant, dtype_guard +from ...transformers.model_utils import ( + _add_variant, + _load_state_dict_into_model, + dtype_guard, + load_state_dict, +) +from ...transformers.utils import get_checkpoint_shard_files from ...utils.distributed import distributed_gather -from ...utils.env import PAST_KEY_VALUES_FILE_NAME, PREFIX_WEIGHTS_NAME +from ...utils.env import ( + PAST_KEY_VALUES_FILE_NAME, + PREFIX_WEIGHTS_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, +) from ...utils.log import logger from .prefix_config import PrefixConfig @@ -68,6 +79,8 @@ def __init__( logger.warning( f"Reset tensor_parallel_degree of prefix_config to {self.model.config.tensor_parallel_degree}." ) + logger.info("Mark only prefix and trainable_module as trainable.") + self.mark_only_prefix_as_trainable() def forward( self, @@ -300,6 +313,39 @@ def from_pretrained( prefix_config_tensor_parallel_degree = prefix_config.tensor_parallel_degree prefix_model = cls(model, prefix_config, postprocess_past_key_value, pad_attention_mask) + prefix_model_index_file = os.path.join(prefix_path, SAFE_PEFT_WEIGHTS_INDEX_NAME) + if os.path.exists(prefix_model_index_file): + # load safetensors format file. 
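At the call site (see the `finetune_generation.py` hunk earlier in this patch), the new `prefix_path` argument selects between building a fresh prefix encoder and restoring one from a checkpoint. A usage sketch with an illustrative checkpoint path; `model` and `prefix_tuning_params` are as set up in that script:

```python
from paddlenlp.peft import PrefixModelForCausalLM

# hypothetical checkpoint directory; mirrors the model_args.prefix_path branch in finetune_generation.py
model = PrefixModelForCausalLM.from_pretrained(
    model=model,
    prefix_path="./checkpoints/llama_pt_ckpts/checkpoint-100",
    postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
)
model.print_trainable_parameters()
```

Note that the explicit `mark_only_prefix_as_trainable()` / `mark_only_lora_as_trainable()` calls were dropped from the script because both are now invoked inside the respective `__init__` methods. The body of `from_pretrained` continues below with the safetensors branch.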
+ resolved_archieve_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=prefix_path, + index_filename=prefix_model_index_file, + ) + loaded_keys = sharded_metadata["all_checkpoint_keys"] + expected_keys = set(prefix_model.prefix_encoder.state_dict().keys()) + missing_keys = expected_keys - set(loaded_keys) + if len(missing_keys) > 0: + raise ValueError(f"missing_keys: {missing_keys}") + + error_msgs = [] + for shard_file in resolved_archieve_file: + pre_tensor_parallel_split = False + if model.config.tensor_parallel_degree > 1: + pre_tensor_parallel_split = True + tp_actions = prefix_model._get_tensor_parallel_convert_actions(is_split=True) + state_dict = load_state_dict( + shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys + ) + error_msgs += _load_state_dict_into_model(prefix_model.prefix_encoder, state_dict, "") + del state_dict + gc.collect() + + if len(error_msgs) > 0: + error_msgs = "\n\t".join(error_msgs) + raise RuntimeError( + f"Error(s) in loading state_dict for {prefix_model.__class__.__name__}:\n\t{error_msgs}" + ) + return prefix_model + # define prefix weight name if prefix_config_tensor_parallel_degree > 1: prefix_weight_name = _add_variant(PREFIX_WEIGHTS_NAME, f"tp{model.config.tensor_parallel_rank:0>2d}") @@ -389,15 +435,16 @@ def set_state_dict(self, state_dict): self.prefix_encoder.set_state_dict(state_dict) logger.info("Load prefix weight successfully") - def _merge_trainable_tensor_parallel(self, trainable_state_dict): + def _get_tensor_parallel_convert_actions(self, loaded_keys=None, is_split=False, ignore_error=False): from paddlenlp.transformers.conversion_utils import split_or_merge_func fn = split_or_merge_func( - is_split=False, + is_split=is_split, tensor_parallel_degree=self.prefix_config.tensor_parallel_degree, tensor_parallel_rank=self.model.config.tensor_parallel_rank, num_attention_heads=self.model.config.num_attention_heads, ) + if self.prefix_config.prefix_projection: name_action_mappings = { "0.weight": partial(fn, is_column=False), @@ -409,6 +456,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict): name_action_mappings = { "0.weight": partial(fn, is_column=False), } + return name_action_mappings + + def _merge_trainable_tensor_parallel(self, trainable_state_dict): + name_action_mappings = self._get_tensor_parallel_convert_actions(is_split=False) hcg = paddle.distributed.fleet.get_hybrid_communicate_group() mp_group = hcg.get_model_parallel_group() is_dst = paddle.distributed.get_rank(mp_group) == 0 @@ -426,27 +477,7 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict): return trainable_state_dict def _convert_tensor_parallel(self, prefix_state_dict): - from paddlenlp.transformers.conversion_utils import split_or_merge_func - - fn = split_or_merge_func( - is_split=True, - tensor_parallel_degree=self.model.config.tensor_parallel_degree, - tensor_parallel_rank=self.model.config.tensor_parallel_rank, - num_attention_heads=self.model.config.num_attention_heads, - ) - - if self.prefix_config.prefix_projection: - name_action_mappings = { - "0.weight": partial(fn, is_column=False), - "1.weight": partial(fn, is_column=True), - "1.bias": partial(fn, is_column=True), - "3.weight": partial(fn, is_column=False), - } - else: - name_action_mappings = { - "0.weight": partial(fn, is_column=False), - } - + name_action_mappings = self._get_tensor_parallel_convert_actions(is_split=True) for name, action in name_action_mappings.items(): tensor = prefix_state_dict.pop(name) 
prefix_state_dict[name] = action(tensor) diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index c6d6f2962daf..b8fcbaab4adc 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -92,9 +92,12 @@ def _get_meshes_for_loader(self): def _get_mesh(pp_idx=0): return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] + # Note(lizhiyu): If the values returned by `DataLoader` don't have the format `[images, labels]`, + # error may occurs here. meshes = [] - for pp_idx in range(self.args.pipeline_parallel_degree): - meshes.append(_get_mesh(pp_idx)) + meshes.append(_get_mesh(0)) + if self.args.pipeline_parallel_degree > 1: + meshes.append(_get_mesh(self.args.pipeline_parallel_degree - 1)) return meshes def _wrap_for_dist_loader(self, train_dataloader): @@ -438,7 +441,11 @@ def optimizer_step(self): self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler._scale - optimizer_was_run = not self.scaler._cache_founf_inf + # Compatible with paddlepaddle 2.6.0 using typo word. + if hasattr(self.scaler, "_cache_founf_inf"): + optimizer_was_run = not self.scaler._cache_founf_inf + else: + optimizer_was_run = not self.scaler._cache_found_inf if not optimizer_was_run: scale_before_value = scale_before.cpu().numpy() scale_after_value = scale_after.cpu().numpy() diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index 8a5a6ca5131a..69efd002a1b2 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -18,11 +18,13 @@ import multiprocessing import os +import numpy as np import paddle import paddle.distributed as dist from paddle.distributed import fleet from tqdm.auto import tqdm +from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.trainer_utils import ExplicitEnum from paddlenlp.trainer.utils.helper import distributed_file, distributed_isfile from paddlenlp.transformers.model_utils import ( @@ -40,16 +42,22 @@ ) from paddlenlp.utils.distributed import distributed_gather from paddlenlp.utils.env import ( + LORA_WEIGHTS_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME, PADDLE_MASTER_WEIGHTS_NAME, PADDLE_OPTIMIZER_INDEX_NAME, PADDLE_OPTIMIZER_NAME, + PADDLE_PEFT_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_NAME, + PAST_KEY_VALUES_FILE_NAME, + PREFIX_WEIGHTS_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_MASTER_WEIGHTS_NAME, SAFE_OPTIMIZER_INDEX_NAME, SAFE_OPTIMIZER_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, + SAFE_PEFT_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, ) @@ -108,12 +116,20 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati Raises: ValueError: if model is not an instance of `PretrainedModel` and the model cannot be saved """ + if isinstance(model, PretrainedModel): model_to_save = model elif isinstance(unwrap_model(model), PretrainedModel): model_to_save = unwrap_model(model) + elif isinstance(model, PrefixModelForCausalLM) or isinstance(model, LoRAModel): + model_to_save = model else: - raise ValueError("Unified checkpoint only supports PretrainedModel") + raise ValueError("Unified checkpoint only supports PretrainedModel, LoRAModel and PrefixModelForCausalLM!") + + # Under non distributed environment. 
+ if paddle.distributed.get_world_size() <= 1: + save_single_card_checkpoint(args, model_to_save, output_dir) + return skip_save_model_weight = False if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config: @@ -141,15 +157,24 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati ) if sharded_index is not None: - if not safe_serialization: - path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME) + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + index_name = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME else: - path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME) + index_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME + path = os.path.join(output_dir, index_name) if args.should_save: with open(path, "w") as f: json.dump(sharded_index, f, indent=4) + if args.should_save: + # Save prefix model past_key_values + if isinstance(model_to_save, PrefixModelForCausalLM): + save_prefix_past_key_value(model_to_save, save_directory) + model_to_save.prefix_config.save_pretrained(save_directory) + if isinstance(model_to_save, LoRAModel): + model_to_save.lora_config.save_pretrained(save_directory) + # save the config config_to_save = save_config(model_to_save) # Attach architecture to the config @@ -170,6 +195,11 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, Returns: None """ + + if paddle.distributed.get_world_size() <= 1: + load_single_card_checkpoint(args, model, resume_from_checkpoint) + return + local_resume = check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization) if not local_resume: @@ -194,7 +224,7 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa loaded_keys = sharded_metadata["all_checkpoint_keys"] - model_state_dict = model.state_dict() + model_state_dict = get_expected_state_dict(model) expected_keys = set(list(model_state_dict.keys())) missing_keys = expected_keys - set(loaded_keys) @@ -228,7 +258,12 @@ def _remove_unused_keys( if shard_file.endswith(".safetensors") and model.config.tensor_parallel_degree > 1: pre_tensor_parallel_split = True assert loaded_keys is not None, "loaded_keys is not None." - tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(loaded_keys), is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True) # Here we use expected_keys to optimize weights loading for pipeline model. 
Only works for safetensors state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys) @@ -287,22 +322,31 @@ def unified_checkpoint_into_shards( paddle.device.cuda.empty_cache() assert hasattr(model_to_save, "config") - state_dict = model_to_save.state_dict() - + state_dict = get_expected_state_dict(model_to_save) all_filter_keys = filter_params(model_to_save, state_dict) config_to_save = copy.deepcopy(model_to_save.config) if config_to_save.tensor_parallel_degree > 1: - tp_actions = model_to_save.get_tensor_parallel_convert_actions( - model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True - ) + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + tp_actions = model_to_save._get_tensor_parallel_convert_actions( + all_filter_keys, is_split=False, ignore_error=True + ) + else: + tp_actions = model_to_save.get_tensor_parallel_convert_actions( + model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True + ) state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys) # build index json file index_weight_file = {} total_size = 0 - weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME + if isinstance(model_to_save, LoRAModel): + weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else LORA_WEIGHTS_NAME + elif isinstance(model_to_save, PrefixModelForCausalLM): + weights_name = SAFE_PEFT_WEIGHTS_NAME if safe_serialization else PREFIX_WEIGHTS_NAME + else: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME shard_file = get_sharded_file_name(args, weights_name) for key, weight in state_dict.items(): @@ -314,6 +358,11 @@ def unified_checkpoint_into_shards( index_file_list, total_size_list, ) + if sharded_index is not None: + if isinstance(model_to_save, LoRAModel): + sharded_index["type"] = "lora" + elif isinstance(model_to_save, PrefixModelForCausalLM): + sharded_index["type"] = "ptuning" paddle.device.cuda.empty_cache() @@ -330,6 +379,10 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio safe_serialization (bool, optional): Whether to use safetensors. Defaults to False. """ + if paddle.distributed.get_world_size() <= 1: + save_single_card_optimizer(args, model, optimizer, output_dir) + return + # Split into naive optimizer params and master weights. 
results = unified_optimizer_into_shards(args, model, optimizer, safe_serialization=safe_serialization) master_weight_state_dict = None @@ -388,6 +441,11 @@ def load_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_ Returns: None """ + + if paddle.distributed.get_world_size() <= 1: + optim_state_dict = load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint) + return optim_state_dict + local_resume = check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe_serialization) if not local_resume: logger.info("Begin to dynamically load unified optimizer!") @@ -422,7 +480,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin ) has_master_weights = True if sharded_metadata["master_weights"] else False - model_state_dict = model.state_dict() + model_state_dict = get_expected_state_dict(model) model_keys = list(model_state_dict.keys()) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings @@ -466,7 +524,14 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected if shard_file.endswith(".safetensors"): # assert model_keys is not None, "model_keys is None." TODO: correct the assert if model.config.tensor_parallel_degree > 1: - tp_actions = model.get_tensor_parallel_convert_actions(model.config, model_keys, ignore_error=True) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + model_keys, is_split=True, ignore_error=True + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, ignore_error=True + ) if not is_master_weights: tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys) @@ -533,9 +598,15 @@ def unified_optimizer_into_shards( if "LR_Scheduler" in optim_state_dict.keys(): optim_state_dict.pop("LR_Scheduler") + # gather global master_weights status. 
+ global_master_weights = reduce_master_weights_status(master_weights is not None) + if master_weights is None and global_master_weights: + master_weights = {} + # get optimizer param mappings static2struct_name_mappings = {} - for k, v in model.state_dict().items(): + state_dict = get_expected_state_dict(model) + for k, v in state_dict.items(): static2struct_name_mappings[v.name] = k # rename optimizer param @@ -562,9 +633,12 @@ def unified_optimizer_into_shards( base_model_key = key.split("/")[0] if base_model_key not in model_keys: model_keys.append(base_model_key) - tp_actions = model.get_tensor_parallel_convert_actions( - model.config, model_keys, is_split=False, ignore_error=True - ) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions(model_keys, is_split=False, ignore_error=True) + else: + tp_actions = model.get_tensor_parallel_convert_actions( + model.config, model_keys, is_split=False, ignore_error=True + ) optim_state_dict = merge_tensor_parallel_for_optimizer( optim_state_dict, tp_actions, @@ -663,7 +737,8 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa pp_group = hcg.get_pipe_parallel_group() need_files = set() - for key in model.state_dict().keys(): + state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): filename = index["weight_map"][key] need_files.add(filename) diff_filelist = list(need_files.difference(set(existed_files))) @@ -752,7 +827,8 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, local_resume = True if args.data_parallel_rank == 0: need_files = set() - for key in model.state_dict().keys(): + state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): if sharding_group.nranks > 1: static_name = struct2static_name_mappings.get(key, None) param_rank = param2rank.get(static_name, None) @@ -803,6 +879,133 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, return local_resume & local_resume_rw +def save_single_card_checkpoint(args, model_to_save, output_dir): + """Save checkpoint for non-distributed environment.""" + + state_dict = get_expected_state_dict(model_to_save) + if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): + weight_filename = "peft_model-00001-of-00001.safetensors" + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + weight_filename = "model-00001-of-00001.safetensors" + index_filename = SAFE_WEIGHTS_INDEX_NAME + # get index json + index_weight_file = {} + total_size = 0 + for key, weight in state_dict.items(): + index_weight_file[key] = weight_filename + total_size += weight.numel().item() * dtype_byte_size(weight.dtype) + sharded_index_json = {} + sharded_index_json["metadata"] = {"total_size": total_size} + sharded_index_json["weight_map"] = index_weight_file + if isinstance(model_to_save, LoRAModel): + sharded_index_json["type"] = "lora" + elif isinstance(model_to_save, PrefixModelForCausalLM): + sharded_index_json["type"] = "ptuning" + path = os.path.join(output_dir, index_filename) + with open(path, "w") as f: + json.dump(sharded_index_json, f, indent=4) + + # save checkpoint + file_save_async_or_sync(state_dict, os.path.join(output_dir, weight_filename), safe_serialization=True) + + if isinstance(model_to_save, PrefixModelForCausalLM): + save_prefix_past_key_value(model_to_save, output_dir) + + +def save_single_card_optimizer(args, model, optimizer, output_dir): + """ "Save 
optimizer for non-distributed environment.""" + # Split into optimizer params and master weights. + optim_state_dict = nested_copy(optimizer.state_dict()) + master_weights = None + if "master_weights" in optim_state_dict.keys(): + master_weights = optim_state_dict.pop("master_weights") + if "LR_Scheduler" in optim_state_dict.keys(): + optim_state_dict.pop("LR_Scheduler") + + static2struct_name_mappings = {} + state_dict = get_expected_state_dict(model) + for k, v in state_dict.items(): + static2struct_name_mappings[v.name] = k + + # rename optimizer param + for key in list(optim_state_dict.keys()): + static_name, type_name = generate_base_static_name(key) + new_name = static2struct_name_mappings[static_name] + "/" + type_name + optim_state_dict[new_name] = optim_state_dict.pop(key) + if master_weights is not None: + for key in list(master_weights.keys()): + master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + + # save index json + index_optimizer_file, index_master_weight_file = {}, {} + total_optim_size, total_master_weight_size = 0, 0 + for key, weight in optim_state_dict.items(): + index_optimizer_file[key] = "optimizer-00001-of-00001.safetensors" + total_optim_size += weight.numel().item() * dtype_byte_size(weight.dtype) + if master_weights is not None: + for key, weight in master_weights.items(): + index_master_weight_file[key] = "master_weights-00001-of-00001.safetensors" + total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) + path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME) + master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME) + with open(path, "w") as f: + has_master_weights = master_weights is not None + json.dump( + { + "metadata": {"total_size": total_optim_size}, + "weight_map": index_optimizer_file, + "master_weights": has_master_weights, + }, + f, + indent=4, + ) + if master_weights is not None: + with open(master_path, "w") as f: + json.dump( + {"metadata": {"total_size": total_master_weight_size}, "weight_map": index_master_weight_file}, + f, + indent=4, + ) + + # save optimizer state dict + file_save_async_or_sync( + optim_state_dict, os.path.join(output_dir, "optimizer-00001-of-00001.safetensors"), safe_serialization=True + ) + if master_weights is not None: + file_save_async_or_sync( + master_weights, + os.path.join(output_dir, "master_weights-00001-of-00001.safetensors"), + safe_serialization=True, + ) + + +def save_prefix_past_key_value(model_to_save, save_directory): + past_key_value = model_to_save.prefix_encoder(model_to_save.prefix_tokens.unsqueeze(0).expand([1, -1])) + past_key_value = past_key_value.reshape( + [ + model_to_save.prefix_config.num_prefix_tokens, + 2, + model_to_save.prefix_config.num_hidden_layers, + model_to_save.num_heads, + model_to_save.head_dim, + ] + ) + past_key_value = paddle.transpose(past_key_value, perm=[2, 1, 3, 0, 4]).cpu().numpy() + model_to_save.prefix_config.save_pretrained(save_directory) + np.save(os.path.join(save_directory, PAST_KEY_VALUES_FILE_NAME), past_key_value) + + +def get_expected_state_dict(model_to_save): + if isinstance(model_to_save, PretrainedModel): + state_dict = model_to_save.state_dict() + elif isinstance(model_to_save, LoRAModel): + state_dict = model_to_save.get_trainable_state_dict() + elif isinstance(model_to_save, PrefixModelForCausalLM): + state_dict = model_to_save.prefix_encoder.state_dict() + return state_dict + + def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappings, resume_from_checkpoint): 
"""Create dispatch table for dynamically loading state dict. @@ -818,7 +1021,8 @@ def create_dispatch_table(args, model, file_keyname_mappings, file_machine_mappi dispatch_list = [] recv_table = {} if args.dataset_rank == 0: - for (k, v) in model.state_dict().items(): + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): if hasattr(v, "is_distributed") and v.is_distributed: recv_table[k] = [(dist.get_rank(), tp_rank)] else: @@ -863,7 +1067,8 @@ def create_optimizer_dispatch_table( dispatch_list = [] recv_table = {} if args.data_parallel_rank == 0: - for (k, v) in model.state_dict().items(): + state_dict = get_expected_state_dict(model) + for (k, v) in state_dict.items(): if sharding_group.nranks > 1: static_name = struct2static_name_mappings[k] param_rank = param2rank.get(static_name, None) @@ -928,13 +1133,18 @@ def load_unified_checkpoint_dynamically(args, model, optimizer, resume_from_chec tp_actions = {} else: # Get corresponding tensor parallel actions. - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) logger.debug("Distributed send recv for state dict load ...") # Distribute the checkpoint tensor dynamically, using the `send_table` and `recv_table` we create before. state_dict = distributed_send_recv( config_revise, - model.state_dict(), + get_expected_state_dict(model), tp_actions, send_table, recv_table, @@ -985,8 +1195,8 @@ def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_check for key in index["weight_map"].keys(): _, typename = key.split("/") typename_set.add(typename) - struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} - static2struct_name_mappings = {v.name: k for k, v in model.state_dict().items()} + struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} + static2struct_name_mappings = {v.name: k for k, v in get_expected_state_dict(model).items()} # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. 
send_table, recv_table = create_optimizer_dispatch_table( args, @@ -1065,7 +1275,12 @@ def check_optimizer_param(parameter): if len(all_tp_keys) == 0: tp_actions = {} else: - tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + tp_actions = model._get_tensor_parallel_convert_actions( + set(all_tp_keys), is_split=True, ignore_error=True, config=config_revise + ) + else: + tp_actions = model.get_tensor_parallel_convert_actions(config_revise, all_tp_keys, ignore_error=True) optimizer_keys = list(index["weight_map"].keys()) optimizer_tp_actions = mapping_optimizer_tp_actions(tp_actions, optimizer_keys) if has_master_weights: @@ -1121,6 +1336,81 @@ def check_optimizer_param(parameter): return None +def load_single_card_checkpoint(args, model, resume_from_checkpoint: str): + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME + else: + index_filename = SAFE_WEIGHTS_INDEX_NAME + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, index_filename), + ) + + loaded_keys = sharded_metadata["all_checkpoint_keys"] + model_state_dict = get_expected_state_dict(model) + expected_keys = set(list(model_state_dict.keys())) + missing_keys = expected_keys - set(loaded_keys) + + if len(missing_keys) > 0: + raise ValueError(f"Missing keys: {missing_keys}") + + state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys) + error_msgs = _load_state_dict_into_model(model, state_dict, "") + del state_dict + gc.collect() + + if error_msgs: + raise RuntimeError(f"Error(s) in loading state dict for {model.__class__.__name__}:\n\t{error_msgs}") + + +def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: str): + returned_optim_state_dict = nested_copy(optimizer.state_dict()) + + resolved_archive_file, sharded_metadata = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), + ) + has_master_weights = True if sharded_metadata["master_weights"] else False + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + expected_keys = sharded_metadata["all_optimizer_keys"] + + if has_master_weights: + returned_optim_state_dict["master_weights"] = {} + resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files( + optimizer_path=resume_from_checkpoint, + index_filename=os.path.join(resume_from_checkpoint, SAFE_MASTER_WEIGHTS_INDEX_NAME), + ) + expected_keys_mw = sharded_metadata_mw["all_optimizer_keys"] + + state_dict_optim = load_state_dict(resolved_archive_file[0], None, expected_keys) + if has_master_weights: + state_dict_optim_mw = load_state_dict(resolved_archive_file_mw[0], None, expected_keys_mw) + + for key in list(state_dict_optim.keys()): + key_name = key.split("/") + static_name = struct2static_name_mappings[key_name[0]] + if has_master_weights: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) + returned_optim_state_dict[key_name] = state_dict_optim.pop(key) + returned_optim_state_dict[key_name].name = key_name + if has_master_weights: + for key in list(state_dict_optim_mw.keys()): + static_name = 
struct2static_name_mappings[key] + returned_optim_state_dict["master_weights"][static_name] = state_dict_optim_mw.pop(key) + returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER]) + + returned_optim_state_dict = nested_copy_place( + returned_optim_state_dict, + place=paddle.framework._current_expected_place(), + blocking=True, + ) + return returned_optim_state_dict + + def get_file_mappings(index, resume_from_checkpoint): file_keyname_mappings = {} for k, v in index["weight_map"].items(): @@ -1190,7 +1480,6 @@ def distributed_send_recv( for key in file_keyname_mappings[filename]: recv_info = recv_table[key] recv_ranklist = [a for (a, b) in recv_info] - if is_src and global_rank == send_table[key]: py_safe_slice_ = f.get_slice(key) # send @@ -1275,6 +1564,24 @@ def get_sharded_index( return None +def reduce_master_weights_status(has_master_weights=False): + data = paddle.to_tensor([has_master_weights], dtype="int32") + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + sharding_group = hcg.get_sharding_parallel_group() + + if tp_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=tp_group) + if pp_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pp_group) + if sharding_group.nranks > 1: + dist.all_reduce(data, op=dist.ReduceOp.SUM, group=sharding_group) + + return data.item() > 0 + + def gather_sharded_object(index_file, total_size, is_optimizer=False): index_file_list, total_size_list = [], [] @@ -1347,7 +1654,7 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): if tp_rank == 0: tensor_bytes_dict = {} - model_state_dict = model_to_save.state_dict() + model_state_dict = get_expected_state_dict(model_to_save) for (k, v) in state_dict.items(): model_v = model_state_dict[k.split("/")[0]] if is_optimizer else v if hasattr(model_v, "is_distributed") and model_v.is_distributed: @@ -1379,7 +1686,8 @@ def filter_params(model_to_save, state_dict, is_optimizer=False): total_size += weight_size filter_tensor_list.append(current_block) - assert len(filter_tensor_list) == tp_size, "Error, partition failed!" 
+ if len(filter_tensor_list) < tp_size: + filter_tensor_list.extend([[] for i in range(tp_size - len(filter_tensor_list))]) dist.broadcast_object_list( filter_tensor_list, @@ -1511,7 +1819,7 @@ def get_expected_keys(sharded_metadata, model, optimizer): if in_sharding_parallel_model: params2rank = optimizer._param2rank - struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} + struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} expected_keys = [] for key in list(sharded_metadata["all_optimizer_keys"]): @@ -1563,14 +1871,14 @@ def nested_copy(inputs): return inputs -def nested_copy_place(inputs, place=None): +def nested_copy_place(inputs, place=None, blocking=False): if isinstance(inputs, dict): outputs = {} for key in list(inputs.keys()): - outputs[key] = nested_copy_place(inputs[key], place) + outputs[key] = nested_copy_place(inputs[key], place, blocking) return outputs if isinstance(inputs, paddle.Tensor): - inputs = inputs if inputs.place == place else inputs._copy_to(place, False) + inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking) return inputs @@ -1639,7 +1947,11 @@ def select_model_weight_index(args, model, resume_from_checkpoint, safe_serializ """ # find model weight index file - index_filename = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME + if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_PEFT_WEIGHTS_INDEX_NAME + else: + index_filename = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME + index_filename_path = os.path.join(resume_from_checkpoint, index_filename) identify_func = os.path.isfile if local else distributed_isfile diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 298f6e9540a6..7fa1cb171e44 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -94,10 +94,12 @@ from ..utils.env import ( LORA_WEIGHTS_NAME, PADDLE_MASTER_WEIGHTS_INDEX_NAME, + PADDLE_PEFT_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_INDEX_NAME, PADDLE_WEIGHTS_NAME, PREFIX_WEIGHTS_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME, + SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME, ) from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available @@ -355,6 +357,13 @@ def __init__( if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") + if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): + if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: + self.args.unified_checkpoint_config.remove("skip_save_model_weight") + logger.warning( + "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config." + ) + self.do_grad_scaling = False self.enable_autocast_context_manager = False if args.fp16 or args.bf16: @@ -1022,7 +1031,11 @@ def _inner_training_loop( self.scaler.step(self.optimizer) self.scaler.update() scale_after = self.scaler._scale - optimizer_was_run = not self.scaler._cache_founf_inf + # Compatible with paddlepaddle 2.6.0 using typo word. 
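The comment above introduces the `_cache_founf_inf` compatibility shim, which appears three times in this patch (`ppo_trainer.py`, `auto_trainer.py`, and here in `trainer.py`); the patched code follows below. A standalone equivalent, assuming a `paddle.amp.GradScaler`-like object, could look like this:

```python
def optimizer_step_succeeded(scaler):
    # paddlepaddle 2.6.0 stores the flag under the misspelled attribute "_cache_founf_inf";
    # fall back to the corrected spelling on versions without the typo
    found_inf = getattr(scaler, "_cache_founf_inf", None)
    if found_inf is None:
        found_inf = scaler._cache_found_inf
    return not found_inf
```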
+ if hasattr(self.scaler, "_cache_founf_inf"): + optimizer_was_run = not self.scaler._cache_founf_inf + else: + optimizer_was_run = not self.scaler._cache_found_inf if not optimizer_was_run: scale_before_value = scale_before.cpu().numpy() scale_after_value = scale_after.cpu().numpy() @@ -1094,20 +1107,29 @@ def _inner_training_loop( if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): self._load_best_model_from_peft_checkpoint() else: - weight_name = PADDLE_WEIGHTS_NAME - best_model_path = os.path.join( - self.state.best_model_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix) - ) - if os.path.exists(best_model_path): - # We load the model state dict on the CPU to avoid an OOM error. - state_dict = paddle.load(best_model_path, return_numpy=True) - # If the model is on the GPU, it still works! - self._set_state_dict_in_model(state_dict) + if self.args.unified_checkpoint: + load_unified_checkpoint( + self.args, + self.model, + self.optimizer, + self.state.best_model_checkpoint, + safe_serialization=True, + ) else: - logger.warning( - f"Could not locate the best model at {best_model_path}, if you are running a distributed training " - "on multiple nodes, you should activate `--save_on_each_node`." + weight_name = PADDLE_WEIGHTS_NAME + best_model_path = os.path.join( + self.state.best_model_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix) ) + if os.path.exists(best_model_path): + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = paddle.load(best_model_path, return_numpy=True) + # If the model is on the GPU, it still works! + self._set_state_dict_in_model(state_dict) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step @@ -1127,6 +1149,16 @@ def _inner_training_loop( return TrainOutput(self.state.global_step, train_loss, metrics) def _load_best_model_from_peft_checkpoint(self): + if self.args.unified_checkpoint: + load_unified_checkpoint( + self.args, + self.model, + self.optimizer, + self.state.best_model_checkpoint, + safe_serialization=True, + ) + return + convert_tp = False if isinstance(self.model, LoRAModel): if self.model.quantized or self.args.pipeline_parallel_degree > 1: @@ -1351,7 +1383,7 @@ def _get_eval_sampler(self, eval_dataset: Dataset): if self.args.pipeline_parallel_degree > 1: drop_last = True logger.warning( - "In parallel mode, the bacth_size is strictly checked. set DistributedBatchSampler drop_last=True." + "In parallel mode, the batch_size is strictly checked. set DistributedBatchSampler drop_last=True." 
) return DistributedBatchSampler( @@ -2125,7 +2157,16 @@ def _save_checkpoint(self, model, metrics=None): if self.args.should_save: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") - self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + if self.args.unified_checkpoint: + save_unified_optimizer( + self.args, + self.model, + self.optimizer, + output_dir, + safe_serialization=True, + ) + else: + self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) # FIXME: maybe only save one copy paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -2403,6 +2444,8 @@ def _load_optimizer_and_scheduler(self, checkpoint): opt_state_dict = tmp # broadcast optimizer state in dp group + if self.args.local_rank != -1: + dist.barrier() opt_state_dict = broadcast_dp_optimizer(opt_state_dict) if opt_state_dict is not None: @@ -3031,7 +3074,12 @@ def print_config(self, args=None, key=""): def is_unified_checkpoint(self, resume_from_checkpoint, safe_serialization=True): is_unified_checkpoint_type = False - weights_index_name = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME + if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): + weights_index_name = ( + PADDLE_PEFT_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_PEFT_WEIGHTS_INDEX_NAME + ) + else: + weights_index_name = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME master_weights_index_name = ( PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME ) diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 0c7baf9595b6..804e64c10c23 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -281,6 +281,8 @@ from .rw.configuration import * from .rw.tokenizer import * from .qwen import * +from .mixtral.modeling import * +from .mixtral.configuration import * # For faster tokenizer from ..utils.import_utils import is_fast_tokenizer_available diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index 24e63e8e5fe3..19fbc5ec5105 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -127,6 +127,7 @@ ("Blip", "blip"), ("Bloom", "bloom"), ("QWen", "qwen"), + ("Mixtral", "mixtral"), ] ) diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index 3905bf4f9efe..5f5483bc809e 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -32,6 +32,16 @@ except ImportError: fused_rotary_position_embedding = None +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + from paddlenlp.transformers.conversion_utils import ( StateDictNameMapping, init_name_mappings, @@ -228,10 +238,10 @@ def __init__(self, config, ipp: Optional[int] = None): def forward(self, x): if self.fuse_attention_ffn: - gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) - out = self.down_proj(F.silu(gate_out) * up_out) + x = swiglu(self.gate_up_fused_proj(x)) else: - out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) return 
out diff --git a/paddlenlp/transformers/llama/modeling_auto_static.py b/paddlenlp/transformers/llama/modeling_auto_static.py index c4ee48def480..61bf3daa2529 100644 --- a/paddlenlp/transformers/llama/modeling_auto_static.py +++ b/paddlenlp/transformers/llama/modeling_auto_static.py @@ -31,6 +31,16 @@ except ImportError: fused_rotary_position_embedding = None +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + from paddlenlp.transformers.conversion_utils import ( StateDictNameMapping, init_name_mappings, @@ -242,10 +252,10 @@ def forward(self, x): fleet.auto.shard_tensor(self.down_proj.weight, *get_dist_attr(["mp", None], self.ipp)) if self.fuse_attention_ffn: - gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) - out = self.down_proj(F.silu(gate_out) * up_out) + x = swiglu(self.gate_up_fused_proj(x)) else: - out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) return out diff --git a/paddlenlp/transformers/mixtral/__init__.py b/paddlenlp/transformers/mixtral/__init__.py new file mode 100644 index 000000000000..816c416d6821 --- /dev/null +++ b/paddlenlp/transformers/mixtral/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import MixtralConfig +from .modeling import MixtralForCausalLM diff --git a/paddlenlp/transformers/mixtral/configuration.py b/paddlenlp/transformers/mixtral/configuration.py new file mode 100644 index 000000000000..56f531aac410 --- /dev/null +++ b/paddlenlp/transformers/mixtral/configuration.py @@ -0,0 +1,191 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mixtral model configuration""" + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "MixtralConfig", +] + + +class MixtralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~MixtralModel`]. It is used to instantiate an Mixtral + model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the + defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`~MixtralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention + allows sequences of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per token; can also be interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + use_fused_rope (`bool`, *optional*, defaults to False): + Enable rope fusion or not. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used.
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + Example: + ```python + >>> from paddlenlp.transformer import MixtralModel, MixtralConfig + + >>> # Initializing a Mixtral mixtral-7b style configuration + >>> configuration = MixtralConfig() + + >>> # Initializing a model from the mixtral-7b style configuration + >>> model = MixtralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mixtral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + max_position_embeddings=4096 * 32, + seq_length=2048, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + use_recompute=False, + recompute_granularity="full", + no_recompute_layers=None, + use_flash_attention=False, + attention_dropout=0.0, + use_fused_rope=False, + rope_theta=1e6, + tensor_parallel_output=True, + sequence_parallel=False, + fuse_sequence_parallel_allreduce=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + num_experts_per_tok=2, + num_local_experts=8, + router_aux_loss_coef=0.001, + output_router_logits=False, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.use_recompute = use_recompute + self.recompute_granularity = recompute_granularity + self.no_recompute_layers = no_recompute_layers + self.use_flash_attention = use_flash_attention + self.tensor_parallel_output = tensor_parallel_output + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce + + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_fused_rope = use_fused_rope + self.rope_theta = rope_theta + + # ----------------- Experts -------------------- # + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.output_router_logits = output_router_logits + + self.sliding_window = sliding_window + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + tensor_parallel_output=tensor_parallel_output, + **kwargs, + ) diff --git a/paddlenlp/transformers/mixtral/modeling.py b/paddlenlp/transformers/mixtral/modeling.py new file mode 100644 index 000000000000..e591e30a9b91 --- /dev/null +++ b/paddlenlp/transformers/mixtral/modeling.py @@ -0,0 +1,1531 @@ +# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Mixtral model""" +from __future__ import annotations + +import math +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.model_outputs import ( + MoECausalLMOutputWithPast, + MoEModelOutputWithPast, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model +from paddlenlp.utils.log import logger + +from ..activations import ACT2FN +from ..sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + GatherOp, + RowSequenceParallelLinear, + ScatterOp, + mark_as_sequence_parallel_parameter, +) +from .configuration import MixtralConfig + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +__all__ = [ + "MixtralModel", + "MixtralPretrainedModel", + "MixtralForCausalLM", + "MixtralPretrainingCriterion", +] + + +def load_balancing_loss_func(gate_logits, num_experts, top_k=2, attention_mask=None): + """ + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Paddle. + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + Args: + gate_logits (Union[`paddle.Tensor`, Tuple[paddle.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top k experts to be considered for the loss computation. + attention_mask (`paddle.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + Returns: + The auxiliary loss. 
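+ Example (an illustrative sketch only; the layer count, token count and expert count below are assumed rather than taken from a real checkpoint): + ```python + >>> import paddle + >>> # gate logits from 2 layers, each routing 4 tokens over 8 experts + >>> gate_logits = tuple(paddle.randn([4, 8]) for _ in range(2)) + >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2) + >>> # perfectly uniform routing would give a value close to top_k + ```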
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + concatenated_gate_logits = paddle.concat( + gate_logits, axis=0 + ) # [num_hidden_layers X batch_size X sequence_length, num_experts] + + routing_weights = F.softmax(concatenated_gate_logits, axis=-1) + _, selected_experts = paddle.topk(routing_weights, top_k, axis=-1) + expert_mask = F.one_hot( + selected_experts, num_classes=num_experts + ) # [num_hidden_layers X batch_size X sequence_length, top_k, num_experts] + + if attention_mask is None or len(attention_mask.shape) == 4: + # Only intokens strategy has 4-D attention_mask, we currently do not support excluding padding tokens. + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.mean(expert_mask.astype("float32"), axis=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.mean(routing_weights, axis=0) + else: + # Exclude the load balancing loss of padding tokens. + if len(attention_mask.shape) == 2: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape([-1, top_k, num_experts]) + ) # [num_hidden_layers * batch_size * sequence_length, top_k, num_experts] + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = paddle.sum(expert_mask.astype("float32") * expert_attention_mask, axis=0) / paddle.sum( + expert_attention_mask, axis=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape([-1, num_experts]) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = paddle.sum( + routing_weights * router_per_expert_attention_mask, axis=0 + ) / paddle.sum(router_per_expert_attention_mask, axis=0) + + overall_loss = paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + 
assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card gets only 1 head. + else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list + + + def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + + def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + training=True, + sequence_parallel=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + causal=True, + return_softmax=output_attentions, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=config.attention_dropout if training else 0.0, + training=training, + ) + attn_weights = None + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next transpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # matmul and divide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + +
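+ # The mask is additive: positions that may be attended to contribute 0 and masked positions contribute + # the most negative value representable in the dtype, so they vanish after the softmax below.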
attn_weights = attn_weights + attention_mask + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + + attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class MixtralRMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) + + +class MixtralRotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + cos = self.cos_cached[:, :seq_len, :, :] + sin = self.sin_cached[:, :seq_len, :, :] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + + if position_ids is None: + # Note: Only for MixtralForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MixtralMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.tensor_parallel_degree = config.tensor_parallel_degree + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.w1 = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.w3 = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.w2 = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.w2 = 
nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + x = self.act_fn(self.w1(x)) * self.w3(x) + x = self.w2(x) + return x + + + class MixtralSparseMoeBlock(nn.Layer): + def __init__(self, config: MixtralConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias_attr=False) + self.experts = nn.LayerList([MixtralMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states): + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape([-1, hidden_dim]) + # router_logits: [batch_size * seq_len, num_experts] + router_logits = self.gate(hidden_states) + + with paddle.amp.auto_cast(False): + routing_weights = F.softmax(router_logits.astype("float32"), axis=1) + routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1) + routing_weights /= routing_weights.sum(axis=-1, keepdim=True) + # we cast back to input dtype + routing_weights = routing_weights.astype(hidden_states.dtype) + + final_hidden_states = paddle.zeros( + [batch_size * seq_len, hidden_dim], + dtype=hidden_states.dtype, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be solicited. + # shape: [num_experts, top_k, batch_size * seq_len] + expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) + + # Loop over all available experts in the model and perform the computation on each expert.
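+ # For each expert: look up the (top_k slot, token) pairs routed to it in expert_mask, gather those token + # states, run the expert's MLP, scale the result by the matching routing weight, and scatter it back + # into final_hidden_states with index_add_.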
+ for expert_id in range(self.num_experts): + expert_layer = self.experts[expert_id] + idx, top_x = paddle.where(expert_mask[expert_id]) + + if top_x.shape[0] == 0: + continue + + current_state = paddle.gather(hidden_states, top_x.squeeze()) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx] + + top_x = top_x.squeeze() + if top_x.shape == []: + top_x = paddle.to_tensor([top_x.item()]) + final_hidden_states.index_add_(top_x, 0, current_hidden_states.astype(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape([batch_size, seq_len, hidden_dim]) + return final_hidden_states, router_logits + + +class MixtralAttention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MixtralConfig, layerwise_recompute: bool = False): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + self.rope_theta = config.rope_theta + self.max_position_embeddings = config.max_position_embeddings + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." 
+ ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if config.tensor_parallel_degree > 1: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + gather_output=False, + ) + self.k_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + if config.tensor_parallel_degree > 1: + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + input_is_parallel=True, + ) + else: + self.o_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.config = config + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states 
= paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + self.training, + self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class MixtralDecoderLayer(nn.Layer): + def __init__(self, config, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = MixtralAttention(config, layerwise_recompute) + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config) + self.post_attention_layernorm = MixtralRMSNorm(config) + self.sequence_parallel = config.sequence_parallel + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. 
+ use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). + cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class MixtralPretrainedModel(PretrainedModel): + config_class = MixtralConfig + base_model_prefix = "mixtral" + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: MixtralConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + for expert_idx in range(config.num_local_experts): + expert_mappings = [ + [f"layers.{layer_index}.block_sparse_moe.experts.{expert_idx}.w1.weight", None, "transpose"], + [f"layers.{layer_index}.block_sparse_moe.experts.{expert_idx}.w2.weight", None, "transpose"], + [f"layers.{layer_index}.block_sparse_moe.experts.{expert_idx}.w3.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.block_sparse_moe.gate.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "MixtralModel" + if "MixtralModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "mixtral." 
+ mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: MixtralConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers, num_local_experts): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # Add tp split for expert params. + base_actions = { + "layers.0.block_sparse_moe.experts.0.w1.weight": partial(fn, is_column=True), + "layers.0.block_sparse_moe.experts.0.w2.weight": partial(fn, is_column=False), + "layers.0.block_sparse_moe.experts.0.w3.weight": partial(fn, is_column=True), + } + for key, action in base_actions.items(): + for i in range(num_layers): + newkey = key.replace("layers.0.", f"layers.{i}.") + for j in range(num_local_experts): + newkey2 = newkey.replace("experts.0.", f"experts.{j}.") + final_actions[newkey2] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, config.num_local_experts) + + return mappings + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.ColumnParallelLinear, + mpu.RowParallelLinear, + MixtralLMHead, + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. 
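+ # Weights that are split across tensor-parallel ranks are re-initialized under the model-parallel RNG + # state tracker; replicated weights are reset with the default random state.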
+ if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.mixtral.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.mixtral.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, MixtralMLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.w2.weight.scale_(factor) + if isinstance(layer, MixtralAttention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class MixtralModel(MixtralPretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList( + [MixtralDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)] + ) + self.norm = MixtralRMSNorm(config) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + 
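+ # True (attendable) positions map to 0.0 and False (masked) positions map to the most negative + # value of `dtype`.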
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + output_router_logits: bool, + past_key_value: Tensor, + use_cache: bool, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + output_router_logits: Optional[bool] = None, + return_dict=False, + **kwargs, + ): + if self.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + if attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + is_casual = is_casual_mask(attention_mask) + if is_casual: 
+ attention_mask = None + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + output_router_logits, + past_key_value, + use_cache, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class MixtralPretrainingCriterion(paddle.nn.Layer): + """ + Criterion for Mixtral. + It calculates the final loss. 
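+ When `tensor_parallel_degree > 1` and `tensor_parallel_output` is enabled, the vocabulary dimension of + the logits is expected to be split across ranks and `ParallelCrossEntropy` is used; otherwise a plain + `CrossEntropyLoss` with `ignore_index` is applied.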
+ """ + + def __init__(self, config): + + super(MixtralPretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + # skip ignore_index which loss == 0 + masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + loss = paddle.mean(masked_lm_loss) + + return loss + + +class MixtralLMHead(nn.Layer): + def __init__(self, config: MixtralConfig): + super(MixtralLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + return logits + + +class MixtralForCausalLM(MixtralPretrainedModel): + enable_to_static_method = True + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.mixtral = MixtralModel(config) + self.lm_head = MixtralLMHead(config) + self.criterion = MixtralPretrainingCriterion(config) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + if config.sliding_window is not None: + logger.warning("We do not support sliding window attention for now.") + + def get_input_embeddings(self): + return self.mixtral.embed_tokens + + def set_input_embeddings(self, value): + self.mixtral.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.mixtral = decoder + + def get_decoder(self): + return self.mixtral + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + inputs_embeds=None, + output_router_logits=False, + 
**kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, MoECausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + output_router_logits: Optional[bool] = None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mixtral( + input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + 
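+        # Under sequence parallelism, MixtralLMHead.forward re-gathers the sequence-sharded
+        # hidden states (GatherOp) before the vocabulary projection, so no extra gather is needed here.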
logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/paddlenlp/transformers/model_outputs.py b/paddlenlp/transformers/model_outputs.py index 4ffee9ebef24..5700522746ab 100644 --- a/paddlenlp/transformers/model_outputs.py +++ b/paddlenlp/transformers/model_outputs.py @@ -1427,3 +1427,94 @@ class Seq2SeqSpectrogramOutput(ModelOutput): encoder_last_hidden_state: Optional[paddle.Tensor] = None encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None encoder_attentions: Optional[Tuple[paddle.Tensor]] = None + + +@dataclass +class MoEModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(paddle.Tensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. 
+
+            Raw router logits (post-softmax) computed by the MoE routers; these are used to compute the auxiliary
+            loss for Mixture of Experts models.
+    """
+
+    last_hidden_state: paddle.Tensor = None
+    past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
+    hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    attentions: Optional[Tuple[paddle.Tensor]] = None
+    router_logits: Optional[Tuple[paddle.Tensor]] = None
+
+
+@dataclass
+class MoECausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+    Args:
+        loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+
+        logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+        aux_loss (`paddle.Tensor`, *optional*, returned when `labels` is provided):
+            aux_loss for the sparse modules.
+
+        router_logits (`tuple(paddle.Tensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router logits (post-softmax) computed by the MoE routers; these are used to compute the auxiliary
+            loss for Mixture of Experts models.
+
+        past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+ """ + + loss: Optional[paddle.Tensor] = None + aux_loss: Optional[paddle.Tensor] = None + logits: paddle.Tensor = None + past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + router_logits: Optional[Tuple[paddle.Tensor]] = None diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index 8acae5ab61fe..f617ff760ad1 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -77,6 +77,7 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: PREFIX_CONFIG_NAME = "prefix_config.json" PREFIX_WEIGHTS_NAME = "prefix_model_state.pdparams" +PADDLE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.pdparams.index.json" PAST_KEY_VALUES_FILE_NAME = "pre_caches.npy" @@ -100,3 +101,6 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: SAFE_MASTER_WEIGHTS_NAME = "master_weights.safetensors" SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json" + +SAFE_PEFT_WEIGHTS_NAME = "peft_model.safetensors" +SAFE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.safetensors.index.json" diff --git a/pipelines/examples/agents/react_example.py b/pipelines/examples/agents/react_example.py index 4496ad66eda6..75279eb5eb5a 100644 --- a/pipelines/examples/agents/react_example.py +++ b/pipelines/examples/agents/react_example.py @@ -82,7 +82,7 @@ # yapf: disable parser = argparse.ArgumentParser() -parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.") parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models ") parser.add_argument("--api_key", default=None, type=str, help="The API Key.") args = parser.parse_args() diff --git a/pipelines/examples/agents/react_example_cn.py b/pipelines/examples/agents/react_example_cn.py index 967381e0e104..5d249d010c2a 100644 --- a/pipelines/examples/agents/react_example_cn.py +++ b/pipelines/examples/agents/react_example_cn.py @@ -60,7 +60,7 @@ parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.") parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.") parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.") -parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI'], default="dense", help="The type of Retriever.") +parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI', 'SearchApi'], default="dense", help="The type of Retriever.") parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.") parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.") parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.") @@ -68,7 +68,7 @@ parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path") parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path") 
parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index") -parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.") +parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.") parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding") parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types") parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models ") diff --git a/pipelines/pipelines/nodes/search_engine/providers.py b/pipelines/pipelines/nodes/search_engine/providers.py index 9e8833968bbe..2f2405bc5c8b 100644 --- a/pipelines/pipelines/nodes/search_engine/providers.py +++ b/pipelines/pipelines/nodes/search_engine/providers.py @@ -239,3 +239,110 @@ def search(self, query: str, **kwargs) -> List[Document]: logger.debug("Serper.dev API returned %s documents for the query '%s'", len(documents), query) result_docs = documents[:top_k] return self.score_results(result_docs, len(answer_box) > 0) + + +class SearchApi(SearchEngine): + """ + SearchApi is a real-time search engine that provides an API to access search results from Google, Google Scholar, YouTube, + YouTube transcripts and more. See the [SearchApi website](https://www.searchapi.io/) for more details. + """ + + def __init__( + self, + api_key: str, + top_k: Optional[int] = 10, + engine: Optional[str] = "google", + search_engine_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param api_key: API key for SearchApi. + :param top_k: Number of results to return. + :param engine: Search engine to use, for example google, google_scholar, youtube, youtube_transcripts. + See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported engines. + :param search_engine_kwargs: Additional parameters passed to the SearchApi. + See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported parameters. + """ + super().__init__() + self.params_dict: Dict[str, Union[str, int, float]] = {} + self.api_key = api_key + self.kwargs = search_engine_kwargs if search_engine_kwargs else {} + self.engine = engine + self.top_k = top_k + + def search(self, query: str, **kwargs) -> List[Document]: + """ + :param query: Query string. + :param kwargs: Additional parameters passed to the SearchApi. For example, you can set 'location' to 'New York,United States' + to localize search to the specific location. 
+ :return: List[Document] + """ + kwargs = {**self.kwargs, **kwargs} + top_k = kwargs.pop("top_k", self.top_k) + url = "https://www.searchapi.io/api/v1/search" + + params = {"q": query, **kwargs} + headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "PaddleNLP"} + + if self.engine: + params["engine"] = self.engine + response = requests.get(url, params=params, headers=headers, timeout=90) + + if response.status_code != 200: + raise Exception(f"Error while querying {self.__class__.__name__}: {response.text}") + + json_content = json.loads(response.text) + documents = [] + has_answer_box = False + + if json_content.get("answer_box"): + if json_content["answer_box"].get("organic_result"): + title = json_content["answer_box"].get("organic_result").get("title", "") + link = json_content["answer_box"].get("organic_result").get("link", "") + if json_content["answer_box"].get("type") == "population_graph": + title = json_content["answer_box"].get("place", "") + link = json_content["answer_box"].get("explore_more_link", "") + + title = json_content["answer_box"].get("title", "") + link = json_content["answer_box"].get("link") + content = json_content["answer_box"].get("answer") or json_content["answer_box"].get("snippet") + + if link and content: + has_answer_box = True + documents.append(Document.from_dict({"title": title, "content": content, "link": link})) + + if json_content.get("knowledge_graph"): + if json_content["knowledge_graph"].get("source"): + link = json_content["knowledge_graph"].get("source").get("link", "") + + link = json_content["knowledge_graph"].get("website", "") + content = json_content["knowledge_graph"].get("description") + + if link and content: + documents.append( + Document.from_dict( + {"title": json_content["knowledge_graph"].get("title", ""), "content": content, "link": link} + ) + ) + + documents += [ + Document.from_dict({"title": c["title"], "content": c.get("snippet", ""), "link": c["link"]}) + for c in json_content["organic_results"] + ] + + if json_content.get("related_questions"): + for question in json_content["related_questions"]: + if question.get("source"): + link = question.get("source").get("link", "") + else: + link = "" + + content = question.get("answer", "") + + if link and content: + documents.append( + Document.from_dict({"title": question.get("question", ""), "content": content, "link": link}) + ) + + logger.debug("SearchApi returned %s documents for the query '%s'", len(documents), query) + result_docs = documents[:top_k] + return self.score_results(result_docs, has_answer_box) diff --git a/pipelines/pipelines/nodes/search_engine/web.py b/pipelines/pipelines/nodes/search_engine/web.py index 573756f58527..b0c62df6fefb 100644 --- a/pipelines/pipelines/nodes/search_engine/web.py +++ b/pipelines/pipelines/nodes/search_engine/web.py @@ -28,6 +28,7 @@ class WebSearch(BaseComponent): WebSerach currently supports the following search engines providers (bridges): - SerperDev (default) + - SearchApi - SerpAPI - BingAPI diff --git a/tests/llm/test_prefix_tuning.py b/tests/llm/test_finetune_prefix_tuning.py similarity index 94% rename from tests/llm/test_prefix_tuning.py rename to tests/llm/test_finetune_prefix_tuning.py index 9074c564bd6b..4f70a10d1c0e 100644 --- a/tests/llm/test_prefix_tuning.py +++ b/tests/llm/test_finetune_prefix_tuning.py @@ -19,7 +19,11 @@ from parameterized import parameterized_class -from tests.testing_utils import argv_context_guard, load_test_config +from tests.testing_utils import ( + argv_context_guard, + 
load_test_config, + skip_for_none_ce_case, +) from .testing_utils import LLMTest @@ -50,6 +54,7 @@ def tearDown(self) -> None: LLMTest.tearDown(self) sys.path.remove(self.model_codes_dir) + @skip_for_none_ce_case def test_prefix_tuning(self): prefix_tuning_config = load_test_config(self.config_path, "prefix_tuning", self.model_dir) diff --git a/tests/test_tipc/auto_tuner/autoconfig/check_mem_usage.sh b/tests/test_tipc/auto_tuner/autoconfig/check_mem_usage.sh new file mode 100644 index 000000000000..58e315ac905f --- /dev/null +++ b/tests/test_tipc/auto_tuner/autoconfig/check_mem_usage.sh @@ -0,0 +1,37 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +autoconfig_json_file="$1" # autoconfig/llama7b_pretrain_buffer.json +autoconfig_json_file_name=$(basename "$1") +model_name=${autoconfig_json_file_name%.*} +auto_log_file=./autoconfig/${model_name}_auto_tuner.log + +log="./llama7b_pretrain_auto_tuner.log" +launch_best_cfg=$(sed -n "s/.*Launch best cfg: \(.*\)}/\1/p" "$auto_log_file") +cfg_max_mem_usage=$(echo "$launch_best_cfg" | awk -F"max_mem_usage': " '{print $2}' | awk -F, '{print $1}') + +buffer=$(sed -n 's/.*"buffer":\([^,}]*\).*/\1/p' $autoconfig_json_file | awk '{print $1}') +max_mem_usage=$(sed -n 's/.*"max_mem_usage":\([^,}]*\).*/\1/p' $autoconfig_json_file | awk '{print $1}') +result=`expr $max_mem_usage - $buffer` + +if [ $cfg_max_mem_usage -le $result ] +then + echo "Autotuner buffer预留成功" + exit 0 +else + echo "Autotuner buffer预留失败" + echo "Autotuner 预设 max_mem_usgae: $max_mem_usage buffer: $buffer, 可用显存为: $result" + echo "Autotuner best_cfg 实际使用显存为: $cfg_max_mem_usage" + exit -1 +fi \ No newline at end of file diff --git a/tests/test_tipc/auto_tuner/autoconfig/llama7b_pretrain_buffer.json b/tests/test_tipc/auto_tuner/autoconfig/llama7b_pretrain_buffer.json new file mode 100644 index 000000000000..1c06a6094928 --- /dev/null +++ b/tests/test_tipc/auto_tuner/autoconfig/llama7b_pretrain_buffer.json @@ -0,0 +1,89 @@ +{ + "dp_degree": "auto", + "max_search_time": 900, + "max_time_per_task": 400, + "buffer":17408, + "max_mem_usage":40960, + "metric_cfg": { + "OptimizationDirection": "Maximize", + "name": "interval_samples_per_second" + }, + "micro_batch_size": "auto", + "model_cfg": { + "global_batch_size": 8, + "hidden_size": 5120, + "num_attention_heads": 40, + "num_layers": 40, + "vocab_size": 32000 + }, + "mp_degree": "auto", + "pp_degree": "auto", + "run_cmd": { + "gradient_accumulation_steps": [ + "./autoconfig/llama7b_pretrain_params.json", + "gradient_accumulation_steps" + ], + "micro_batch_size": [ + "./autoconfig/llama7b_pretrain_params.json", + "per_device_train_batch_size" + ], + "mp_degree": [ + "./autoconfig/llama7b_pretrain_params.json", + "tensor_parallel_degree" + ], + "pp_degree": [ + "./autoconfig/llama7b_pretrain_params.json", + "pipeline_parallel_degree" + ], + "run_best_stage": { + "continue_training": [ + "./autoconfig/llama7b_pretrain_params.json", + "continue_training", + 0 + ], + 
"autotuner_benchmark": [ + "./autoconfig/llama7b_pretrain_params.json", + "autotuner_benchmark", + 0 + ] + }, + "search_stage": { + "continue_training": [ + "./autoconfig/llama7b_pretrain_params.json", + "continue_training", + 0 + ], + "autotuner_benchmark": [ + "./autoconfig/llama7b_pretrain_params.json", + "autotuner_benchmark", + 1 + ] + }, + "sharding_degree": [ + "./autoconfig/llama7b_pretrain_params.json", + "sharding_parallel_degree" + ], + "sharding_stage": [ + "./autoconfig/llama7b_pretrain_params.json", + "sharding", + "stage" + ], + "use_recompute": [ + "./autoconfig/llama7b_pretrain_params.json", + "recompute" + ], + "recompute_granularity": [ + "./autoconfig/llama7b_pretrain_params.json", + "recompute_granularity" + ] + }, + "sharding_degree": "auto", + "sharding_stage": "auto", + "task_limit": 2000, + "use_recompute": "auto", + "recompute_granularity": "auto", + "invalid_strategy": ["stage3_mp*"], + "schedule_prior": ["mp4"], + "need_baseline": true, + "mode": "Pretrain" + } \ No newline at end of file diff --git a/tests/test_tipc/auto_tuner/llama_pretrain/N1C8/CE_buffer_autotuner_llama7b_bs8_bf16_pretrain.sh b/tests/test_tipc/auto_tuner/llama_pretrain/N1C8/CE_buffer_autotuner_llama7b_bs8_bf16_pretrain.sh new file mode 100644 index 000000000000..e20c5539925c --- /dev/null +++ b/tests/test_tipc/auto_tuner/llama_pretrain/N1C8/CE_buffer_autotuner_llama7b_bs8_bf16_pretrain.sh @@ -0,0 +1,25 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +param="model_item=CE_buffer_autotuner_llama7b " +param+="run_mode=pretrain " +param+="device_num=N1C8 " +param+="global_batch_size=8 " +param+="autoconfig_json_file=autoconfig/llama7b_pretrain_buffer.json " +param+="modle_json_file=autoconfig/llama7b_pretrain_params.json " + +cd ./tests +bash ./test_tipc/auto_tuner/llama_pretrain/benchmark_common/prepare.sh + +bash -c "${param} bash ./test_tipc/auto_tuner/llama_pretrain/benchmark_common/run_benchmark.sh" diff --git a/tests/test_tipc/auto_tuner/llama_pretrain/benchmark_common/run_benchmark.sh b/tests/test_tipc/auto_tuner/llama_pretrain/benchmark_common/run_benchmark.sh index 8055fc75932d..c0394a0bd681 100644 --- a/tests/test_tipc/auto_tuner/llama_pretrain/benchmark_common/run_benchmark.sh +++ b/tests/test_tipc/auto_tuner/llama_pretrain/benchmark_common/run_benchmark.sh @@ -130,6 +130,16 @@ function _train(){ echo -e "auto_tuner, SUCCESS" >> ${log_file} fi fi + if [[ ${model_item} =~ "buffer" ]];then + bash autoconfig/check_mem_usage.sh ${autoconfig_json_file} >> ${log_file} 2>&1 + if [ $? 
-ne 0 ];then + echo -e "${model_name}, mem_usage buffer check FAIL" >> ${log_file} + sed '/ips/d' "$log_file" > "$log_file.tmp" + mv "$log_file.tmp" "$log_file" + else + echo -e "${model_name}, mem_usage buffer check SUCCESS" >> ${log_file} + fi + fi #kill -9 `ps -ef|grep 'python'|awk '{print $2}'` if [ ${device_num} != "N1C1" -a -d ./autoconfig/best_cfg ]; then case_path=$PWD && cd - && mkdir -p mylog # PaddleNLP/tests/mylog diff --git a/tests/trainer/test_lora_unified_checkpoint.py b/tests/trainer/test_lora_unified_checkpoint.py new file mode 100644 index 000000000000..b46c14db7e8c --- /dev/null +++ b/tests/trainer/test_lora_unified_checkpoint.py @@ -0,0 +1,478 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +import numpy as np +import pytest + +from paddlenlp.utils.downloader import get_path_from_url_with_filelock +from tests.parallel_launch import TestMultipleGpus +from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case +from tests.trainer.trainer_utils import get_pretrain_arguments + +environment_variables = { + "NCCL_ALGO": "Tree", + "NVIDIA_TF32_OVERRIDE": "0", + "NCCL_IB_TIMEOUT": "22", + "NCCL_DEBUG": "INFO", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + "Flags_mp_aysnc_allreduce": "1", + "Flags_skip_mp_c_identity": "1", + "FLAGS_shard_norm_align_dp": "0", + "FLAGS_shard_use_reduce": "1", + "test_ci_no_save_model": "1", +} + +lora_arguments = { + "model_name_or_path": "__internal_testing__/unified-ckpt-llama-170m-for-peft", + "dataset_name_or_path": "./unified_checkpoint/peft_input/data/", + "output_dir": "./unified_checkpoint/checkpoints/llama_lora_ckpts", + "per_device_train_batch_size": 4, + "gradient_accumulation_steps": 8, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps": 16, + "learning_rate": 3e-04, + "max_steps": 15, + "save_steps": 10, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "steps", + "src_length": 1024, + "max_length": 2048, + "fp16": "true", + "fp16_opt_level": "O2", + "do_train": "true", + "do_eval": "false", + "disable_tqdm": "true", + "eval_with_do_generation": "false", + "recompute": "true", + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "lora": "true", + "zero_padding": "false", + "use_flash_attention": "false", + "unified_checkpoint": 1, +} + + +# convert from N1C8 to N2C4 or N2C4 to N1C8 +MAX_CONVERT_CONFIGS = 1 # max: 16, min: 1 + +seed = 2024 + +rng = np.random.default_rng(seed=seed) + + +def random_sample(keys, k): + return rng.permutation(list(keys))[0:k].tolist() + + +def check_acc(log_dir="log"): + file_path = os.path.join(log_dir, "workerlog.n0.c0") + cmd = "grep -a 'global_step: 15' " + file_path + " | awk -F ',' '{print $2}' | awk '{print $6}'" + import subprocess + + res = subprocess.check_output(cmd, shell=True, text=True) + res = [float(x) for x in res.split()] + + return res + + +def 
remove_logs(log_dir="log"): + if os.path.exists(log_dir): + shutil.rmtree(log_dir) + + +def remove_ckpt(ckpt_dir): + if os.path.exists(ckpt_dir): + shutil.rmtree(ckpt_dir) + + +class TestUnifiedCheckpointSingle(TestMultipleGpus): + def setUp(self): + self.config = lora_arguments + os.environ.update(environment_variables) + + file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz" + input_dir = "unified_checkpoint/lora_input/" # unified_checkpoint/lora/data + os.makedirs(input_dir, exist_ok=True) + file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz") + if not os.path.exists(file_path): + get_path_from_url_with_filelock(file_, root_dir=input_dir) + + self.need_allclose = True + self.rtol = 1e-7 + self.run_lora_file = "llm/finetune_generation.py" + self.num_nodes = 1 + + def runfirst(self, train_args): + self.run_1gpu(self.run_lora_file, **train_args) + + def rerun(self, train_args): + self.run_1gpu(self.run_lora_file, **train_args) + + @skip_for_none_ce_case + def testDP1(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + self.runfirst(self.config) + self.rerun(self.config) + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + +# Test Unified Checkpoint Hybrid Parallel Strategy on N1C8 and N2C4 +class TestUnifiedCheckpointBase(TestMultipleGpus): + @classmethod + @property + def __test__(cls): + return cls != TestUnifiedCheckpointBase + + def setUp(self): + """ + 1. update runfrist and rerun to run defined different config + 2. update need_allclose to True if you want to check the result + 3. update rtol to the relative value you want to check + """ + + self.configs = get_pretrain_arguments(lora_arguments) + os.environ.update(environment_variables) + + file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz" + input_dir = "unified_checkpoint/peft_input/" # unified_checkpoint/lora/data + os.makedirs(input_dir, exist_ok=True) + file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz") + if not os.path.exists(file_path): + get_path_from_url_with_filelock(file_, root_dir=input_dir) + + self.need_allclose = True + self.rtol = 1e-7 + + self.run_lora_file = "llm/finetune_generation.py" + + def runfrist(self, train_args): + self.run_n1c8(self.run_lora_file, **train_args) + + def rerun(self, train_args): + self.run_n1c8(self.run_lora_file, **train_args) + + @require_paddle_at_least_8_gpu + def testTP4PP2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP4PP2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP2Sharding4(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP2Sharding4"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + +class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase): + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP8(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP8"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], 
self.rtol) + + @require_paddle_at_least_8_gpu + def testTP4DP2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP4DP2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP4Sharding2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP4Sharding2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP2PP4(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["TP2PP4"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testPP8(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["PP8"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testPP4DP2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["PP4DP2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testPP4Sharding2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["PP4Sharding2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding8S1(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding8S1"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding8S2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding8S2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding4S1DP2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding4S1DP2"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding4S2DP2(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding4S2DP2"] + self.runfrist(train_args) + 
self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding2S1DP4(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding2S1DP4"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testSharding2S2DP4(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["Sharding2S2DP4"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testDP8(self): + remove_logs() + remove_ckpt(lora_arguments["output_dir"]) + + train_args = self.configs["DP8"] + self.runfrist(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + +class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase): + def setUp(self): + super().setUp() + self.need_allclose = True + self.rtol = 1e-7 + + def runfrist(self, train_args): + self.run_n2c4(self.run_lora_file, **train_args) + + def rerun(self, train_args): + self.run_n2c4(self.run_lora_file, **train_args) + + +class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase): + def setUp(self): + super().setUp() + + self.need_allclose = True + self.rtol = 1e-7 + + def runfrist(self, train_args): + train_args["unified_checkpoint"] = 0 + self.run_n1c8(self.run_lora_file, **train_args) + + def rerun(self, train_args): + train_args["unified_checkpoint"] = 1 + self.run_n1c8(self.run_lora_file, **train_args) + + +class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase): + def setUp(self): + super().setUp() + + self.need_allclose = True + self.rtol = 1e-7 + + def runfrist(self, train_args): + train_args["unified_checkpoint"] = 0 + self.run_n1c8(self.run_lora_file, **train_args) + + def rerun(self, train_args): + train_args["unified_checkpoint"] = 0 + self.run_n1c8(self.run_lora_file, **train_args) + + +@pytest.mark.skipif(True, reason="Skip for None CE") +class TestUnifiedCheckpointOnN2C4CheckpointCompatible(TestUnifiedCheckpointBase): + def setUp(self): + super().setUp() + + self.need_allclose = True + self.rtol = 1e-7 + + def runfrist(self, train_args): + train_args["unified_checkpoint"] = 0 + self.run_n2c4(self.run_lora_file, **train_args) + + def rerun(self, train_args): + train_args["unified_checkpoint"] = 1 + self.run_n2c4(self.run_lora_file, **train_args) diff --git a/tests/transformers/mixtral/__init__.py b/tests/transformers/mixtral/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/transformers/mixtral/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/mixtral/test_modeling.py b/tests/transformers/mixtral/test_modeling.py new file mode 100644 index 000000000000..2ae98a6fb326 --- /dev/null +++ b/tests/transformers/mixtral/test_modeling.py @@ -0,0 +1,318 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import paddle + +from paddlenlp.transformers import MixtralConfig, MixtralForCausalLM, MixtralModel +from tests.transformers.test_configuration_common import ConfigTester +from tests.transformers.test_generation_utils import GenerationTesterMixin +from tests.transformers.test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class MixtralModelTester: + def __init__( + self, + parent, + vocab_size=32000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + masked_softmax_fusion=True, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + is_training=True, + use_cache=False, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + attention_softmax_in_fp32=True, + pretraining_tp=1, # TP rank used when training with megatron + dtype="bfloat16", + slow_but_exact=False, + batch_size: int = 2, + seq_length: int = 10, + type_sequence_label_size=2, + activation_function="gelu", + num_labels=3, + num_choices=4, + scope=None, + dropout=0.56, + use_input_mask: bool = False, + use_labels: bool = False, + return_dict=False, + ): + self.parent: MixtralModelTest = parent + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.is_training = is_training + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.pretraining_tp = pretraining_tp + self.dtype = dtype + self.slow_but_exact = slow_but_exact + + self.batch_size = batch_size + self.seq_length = seq_length + self.type_sequence_label_size = type_sequence_label_size + self.activation_function = activation_function + 
self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.dropout = dropout + + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.return_dict = return_dict + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self) -> MixtralConfig: + return MixtralConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + masked_softmax_fusion=self.masked_softmax_fusion, + layer_norm_epsilon=self.layer_norm_epsilon, + initializer_range=self.initializer_range, + use_cache=self.use_cache, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, + hidden_dropout=self.hidden_dropout, + attention_dropout=self.attention_dropout, + attention_softmax_in_fp32=self.attention_softmax_in_fp32, + pretraining_tp=self.pretraining_tp, + dtype=self.dtype, + slow_but_exact=self.slow_but_exact, + activation_function=self.activation_function, + ) + + def create_and_check_model( + self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MixtralModel(config) + model.eval() + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + + def create_and_check_model_attention_mask( + self, config: MixtralConfig, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = MixtralModel(config) + model.eval() + attn_mask_2d = random_attention_mask([self.batch_size, self.seq_length]) + result_2d = model(input_ids, attention_mask=attn_mask_2d)[0] + batch, seq_length = input_ids.shape + causal_mask = paddle.tril(paddle.ones((batch, seq_length, seq_length), dtype=attn_mask_2d.dtype)) + attn_mask_3d = causal_mask & attn_mask_2d.unsqueeze(-1) + result_3d = model(input_ids, attention_mask=attn_mask_3d)[0] + attn_mask_4d = attn_mask_3d.unsqueeze(1) + result_4d = model(input_ids, attention_mask=attn_mask_4d)[0] + result_no_attention_mask = model(input_ids, attention_mask=None)[0] + # Assert non-padding tokens have the same logits with different attention_mask shape + self.parent.assertTrue((result_2d[attn_mask_2d] == result_3d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_4d[attn_mask_2d]).all()) + self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all()) + + def create_and_check_model_past_large_inputs( + self, + config: MixtralConfig, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MixtralModel(config) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True, return_dict=self.return_dict) + past_key_values = 
outputs.past_key_values if self.return_dict else outputs[2] + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), self.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = paddle.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = paddle.concat([input_mask, next_mask], axis=-1) + + outputs = model( + next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True, return_dict=self.return_dict + ) + + output_from_no_past = outputs[2][0] + + outputs = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + return_dict=self.return_dict, + ) + + output_from_past = outputs[2][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(paddle.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args): + model = MixtralForCausalLM(config) + model.eval() + + result = model( + input_ids, + use_cache=True, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertIsInstance(result[0].item(), float) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size]) + else: + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def check_model_position_ids(self, config, input_ids, input_mask, *args): + model = MixtralForCausalLM(config) + model.eval() + + result_no_position_id = model( + input_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + batch_size, seq_len = input_ids.shape + position_ids = paddle.arange(seq_len).expand((batch_size, seq_len)) + result_position_id = model( + input_ids, + position_ids, + labels=input_ids if self.parent.use_labels else None, + return_dict=self.parent.return_dict, + ) + if self.parent.use_labels: + self.parent.assertTrue((result_position_id[1] == result_no_position_id[1]).all()) + else: + self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all()) + + +class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + base_model_class = MixtralModel + return_dict = False + use_labels = False + use_test_model_name_list = False + + all_model_classes = (MixtralModel, MixtralForCausalLM) + all_generative_model_classes = {MixtralForCausalLM: (MixtralModel, "mixtral")} + + def setUp(self): + super().setUp() + + self.model_tester = MixtralModelTester(self) + self.config_tester = ConfigTester(self, config_class=MixtralConfig, vocab_size=256, hidden_size=24) + + def _get_input_ids_and_config(self): + config, inputs_dict 
= self.model_tester.prepare_config_and_inputs_for_common() + + input_ids = inputs_dict[self.input_name] + attention_mask = paddle.ones_like(input_ids, dtype=paddle.int64) + + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = attention_mask[:max_batch_size, :sequence_length] + max_length = 3 + + return config, input_ids, attention_mask, max_length + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_attention_mask(*config_and_inputs) + + def test_model_position_ids(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_model_position_ids(*config_and_inputs) + + def test_generate_without_input_ids(self): + # this requires 4-D attention mask logic, which is not supported yet + pass + + def test_mixtral_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + +if __name__ == "__main__": + unittest.main()
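
The tests above drive Mixtral through the shared ModelTesterMixin/GenerationTesterMixin machinery. For a quicker, standalone sanity check of the MoE-specific outputs wired up in this diff, a minimal sketch along the following lines should work; num_local_experts and num_experts_per_tok mirror the attributes referenced in the new modeling code, while intermediate_size and num_key_value_heads are assumed to follow the usual Mixtral config naming and are not confirmed by this patch.

import paddle
from paddlenlp.transformers import MixtralConfig, MixtralForCausalLM

# Tiny configuration so the forward pass runs quickly.
config = MixtralConfig(
    vocab_size=256,
    hidden_size=64,
    intermediate_size=128,          # assumed field name
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=8,          # assumed field name
    num_local_experts=4,
    num_experts_per_tok=2,
)
model = MixtralForCausalLM(config)
model.eval()

input_ids = paddle.randint(0, config.vocab_size, shape=[2, 16])
out = model(input_ids, labels=input_ids, output_router_logits=True, return_dict=True)

print(float(out.loss))           # LM loss plus router_aux_loss_coef * aux_loss
print(float(out.aux_loss))       # unscaled load-balancing term
print(len(out.router_logits))    # one router-logit tensor per decoder layer

With return_dict=True the call returns the MoECausalLMOutputWithPast added in model_outputs.py, so loss, aux_loss and the per-layer router_logits are all addressable by name.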
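The auxiliary loss added in MixtralForCausalLM.forward comes from load_balancing_loss_func, which lives in the new modeling file and is not shown in this diff. For orientation only, a simplified Switch-Transformer-style sketch of the quantity such a function computes (ignoring the attention-mask weighting and the multi-layer concatenation the real implementation handles) is:

import paddle
import paddle.nn.functional as F

def load_balancing_loss_sketch(gate_logits, num_experts, top_k):
    """num_experts * sum_e f_e * P_e, where f_e is the fraction of top-k assignments
    routed to expert e and P_e is the mean router probability of expert e."""
    routing_weights = F.softmax(gate_logits, axis=-1)                   # [tokens, num_experts]
    _, selected_experts = paddle.topk(routing_weights, top_k, axis=-1)  # [tokens, top_k]
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts)  # [tokens, top_k, num_experts]
    tokens_per_expert = expert_mask.mean(axis=0)                        # f_e per top-k slot
    router_prob_per_expert = routing_weights.mean(axis=0)               # P_e
    return num_experts * paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))

# Example: 32 tokens routed over 8 experts with top-2 gating.
aux = load_balancing_loss_sketch(paddle.randn([32, 8]), num_experts=8, top_k=2)

The term is minimized when routing is uniform across experts; forward scales it by config.router_aux_loss_coef before adding it to the language-modeling loss.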
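Separately, the diff registers a SearchApi provider among the pipelines search-engine nodes. A hypothetical call sequence, assuming a valid SearchApi.io key and that the pipelines/pipelines/... layout imports as the pipelines package, would be:

from pipelines.nodes.search_engine.providers import SearchApi

engine = SearchApi(api_key="<your-searchapi.io-key>", top_k=5, engine="google")
# Extra keyword arguments are forwarded to the SearchApi request parameters.
docs = engine.search("PaddleNLP Mixtral support", location="Austin,Texas,United States")
for doc in docs:
    print(doc.content)

The provider also accepts engine values such as google_scholar, youtube or youtube_transcripts, matching the engines listed in its docstring.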