diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py index 19e88b402bcf..9b162d4a88c1 100644 --- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -305,7 +305,11 @@ def load_resolved_archive_file( ) ) if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != paddle.float32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + # for parameters with float32 dtype, no need to have fp32 master weights. + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py index cd999f1dba46..41ba54972efb 100644 --- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py +++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py @@ -67,6 +67,7 @@ FP32_MASTER, UnifiedCheckpointOption, filter_params, + filter_sync_parameters, gather_sharded_object, generate_base_static_name, get_expected_state_dict, @@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) - no_sync_kname = [] - model_state_dict = get_expected_state_dict(model) - for k, v in model_state_dict.items(): - if getattr(v, "no_sync", False): - no_sync_kname.append(k) - - hcg = fleet.get_hybrid_communicate_group() - dp_group = hcg.get_data_parallel_group() - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 if self.args.use_expert_parallel: - for k in list(optim_state_dict.keys()): - model_k = k.split("/")[0] - if dp_rank > 0 and model_k not in no_sync_kname: - optim_state_dict.pop(k) - if master_weights is not None: - for k in list(master_weights.keys()): - model_k = k.split("/")[0] - if dp_rank > 0 and model_k not in no_sync_kname: - master_weights.pop(k) + model_state_dict = get_expected_state_dict(model) + filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False) optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) @@ -516,6 +501,10 @@ def unified_checkpoint_into_shards( config_to_save = copy.deepcopy(model_to_save.config) + if args.use_expert_parallel: + # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0. 
+ filter_sync_parameters(state_dict, is_model_weight=True) + if config_to_save.tensor_parallel_degree > 1: if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): tp_actions = model_to_save._get_tensor_parallel_convert_actions( @@ -625,6 +614,9 @@ def unified_optimizer_into_shards( tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() tp_size = tp_group.nranks + if args.use_expert_parallel: + filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False) + if tp_size > 1: # get tp_actions model_keys = [] @@ -643,7 +635,6 @@ def unified_optimizer_into_shards( optim_state_dict, tp_actions, filter_optim_keys, - state_dict if args.use_expert_parallel else None, ) empty_device_cache() @@ -653,7 +644,6 @@ def unified_optimizer_into_shards( master_weights, tp_actions, filter_master_keys, - state_dict if args.use_expert_parallel else None, ) empty_device_cache() diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index bbb49ae14820..413ca7c47210 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -354,9 +354,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): """ hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() - dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 # filter actions for pipeline mode if hcg.get_pipe_parallel_group().nranks > 1: @@ -373,10 +371,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): if i > len(filter_keys) - 1: continue key = filter_keys[i] - tensor = state_dict[key] - # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. - if dp_rank > 0 and not getattr(tensor, "no_sync", False): + if key not in state_dict: continue + tensor = state_dict[key] if key in tp_actions: # Get tensor size tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks @@ -405,21 +402,13 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): return state_dict_to_save -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): """ Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight. """ hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() - dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 - - no_sync_kname = [] - if model_state_dict is not None: - for k, v in model_state_dict.items(): - if getattr(v, "no_sync", False): - no_sync_kname.append(k) state_dict_to_save = {} max_key_len = max([len(_) for _ in all_filter_keys]) @@ -430,10 +419,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, continue # get base model key model_key = filter_keys[i].split("/")[0] - tensor = state_dict[filter_keys[i]] - # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
- if dp_rank > 0 and model_key not in no_sync_kname: + if filter_keys[i] not in state_dict: continue + tensor = state_dict[filter_keys[i]] if model_key in tp_actions: # for example: beta1, beta2 if tensor.numel().item() == 1: @@ -770,3 +758,31 @@ def save_config(model_to_save): # save generation config if model_to_save.can_generate(): model_to_save.generation_config.save_pretrained(save_directory) + + +def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True): + """Filter sync parameters under expert parallel mode.""" + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + if is_model_weight: + for key in list(model_state_dict.keys()): + if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False): + model_state_dict.pop(key) + else: + no_sync_kname = [] + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) + + for key in list(optim_state_dict.keys()): + model_key = key.split("/")[0] + if dp_rank > 0 and model_key not in no_sync_kname: + optim_state_dict.pop(key) + + if master_weights is not None: + for key in list(master_weights.keys()): + if dp_rank > 0 and key not in no_sync_kname: + master_weights.pop(key) diff --git a/tests/trainer/test_moe_unified_checkpoint.py b/tests/trainer/test_moe_unified_checkpoint.py new file mode 100644 index 000000000000..618e2b2f3daf --- /dev/null +++ b/tests/trainer/test_moe_unified_checkpoint.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import numpy as np +import pytest + +from paddlenlp.utils.downloader import get_path_from_url_with_filelock +from tests.parallel_launch import TestMultipleGpus +from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case +from tests.trainer.test_unified_checkpoint import remove_ckpt, remove_logs +from tests.trainer.trainer_utils import get_pretrain_arguments + +environment_variables = { + "NCCL_ALGO": "Tree", + "NVIDIA_TF32_OVERRIDE": "0", + "NCCL_IB_TIMEOUT": "22", + "NCCL_DEBUG": "INFO", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + "Flags_mp_aysnc_allreduce": "1", + "Flags_skip_mp_c_identity": "1", + "FLAGS_shard_norm_align_dp": "0", + "FLAGS_shard_use_reduce": "1", + "test_ci_no_save_model": "1", +} + +moe_arguments = { + "model_name_or_path": "__internal_testing__/unified-ckpt-qwen2moe", + "dataset_name_or_path": "./unified_checkpoint/peft_input/data/", + "output_dir": "./unified_checkpoint/checkpoints/qwen2moe_sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 8, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps": 16, + "learning_rate": 3e-04, + "max_steps": 10, + "save_steps": 6, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "steps", + "src_length": 1024, + "max_length": 2048, + "bf16": "true", + "fp16_opt_level": "O2", + "do_train": "true", + "do_eval": "false", + "disable_tqdm": "true", + "eval_with_do_generation": "false", + "recompute": "true", + "recompute_granularity": "full", + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "", + "lora": "false", + "zero_padding": "false", + "use_flash_attention": "false", + "unified_checkpoint": 1, + "continue_training": 0, + "sequence_parallel": 0, +} + + +def check_acc(log_dir="log"): + file_path = os.path.join(log_dir, "workerlog.n0.c0") + cmd = "grep -a 'global_step: 10' " + file_path + " | awk -F ',' '{print $2}' | awk '{print $6}'" + import subprocess + + res = subprocess.check_output(cmd, shell=True, text=True) + res = [float(x) for x in res.split()] + + return res + + +seed = 2024 + +rng = np.random.default_rng(seed=seed) + + +@pytest.mark.xdist_group(name="UC") +class TestUnifiedCheckpointBase(TestMultipleGpus): + @classmethod + @property + def __test__(cls): + return cls != TestUnifiedCheckpointBase + + def setUp(self): + """ + 1. update runfirst and rerun to run defined different config + 2. update need_allclose to True if you want to check the result + 3. 
update rtol to the relative value you want to check + """ + + self.configs = get_pretrain_arguments(moe_arguments) + os.environ.update(environment_variables) + + file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz" + input_dir = "unified_checkpoint/peft_input/" + os.makedirs(input_dir, exist_ok=True) + file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz") + if not os.path.exists(file_path): + get_path_from_url_with_filelock(file_, root_dir=input_dir) + + self.need_allclose = True + self.rtol = 1e-7 + + self.run_file = "llm/run_finetune.py" + + def runfirst(self, train_args): + self.run_n1c8(self.run_file, **train_args) + + def rerun(self, train_args): + self.run_n1c8(self.run_file, **train_args) + + @require_paddle_at_least_8_gpu + def testTP4DP2(self): + remove_logs() + remove_ckpt(moe_arguments["output_dir"]) + + train_args = self.configs["TP4DP2"] + self.runfirst(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP2Sharding4(self): + remove_logs() + remove_ckpt(moe_arguments["output_dir"]) + + train_args = self.configs["TP2Sharding4"] + self.runfirst(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) + + +@pytest.mark.xdist_group(name="UC") +class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase): + @skip_for_none_ce_case + @require_paddle_at_least_8_gpu + def testTP2Sharding4V2(self): + remove_logs() + remove_ckpt(moe_arguments["output_dir"]) + + train_args = self.configs["TP2Sharding4"] + train_args.update({"sharding_parallel_config": "split_param"}) + train_args.update({"amp_master_grad": True}) + self.runfirst(train_args) + self.rerun(train_args) + + if self.need_allclose: + res = check_acc() + assert len(res) == 2 + np.testing.assert_allclose(res[0], res[1], self.rtol) diff --git a/tests/trainer/trainer_utils.py b/tests/trainer/trainer_utils.py index ae9a40e61d59..cda374ce1c6a 100644 --- a/tests/trainer/trainer_utils.py +++ b/tests/trainer/trainer_utils.py @@ -141,6 +141,14 @@ def get_pretrain_arguments(pretrain_arguments): train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 8 configs["DP8"] = train_args + train_args = copy.deepcopy(pretrain_arguments) + train_args["tensor_parallel_degree"] = 2 + train_args["pipeline_parallel_degree"] = 1 + train_args["sharding_parallel_degree"] = 2 + train_args["sharding"] = "stage1" + train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 4 + configs["TP2DP2Sharding2"] = train_args + return configs
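
Note on the new `filter_sync_parameters` helper: the sketch below is a self-contained toy that mirrors the filtering rule this diff adds to `paddlenlp/trainer/unified_checkpoint/utils.py`. The `FakeTensor` class, the explicit `dp_rank` argument and the `_sketch` suffix are illustrative assumptions only; the real helper takes no `dp_rank` argument and instead reads the rank from `fleet.get_hybrid_communicate_group()`.

```python
class FakeTensor:
    """Stand-in for a paddle.Tensor that may carry the `no_sync` attribute."""

    def __init__(self, no_sync=False):
        self.no_sync = no_sync


def filter_sync_parameters_sketch(model_state_dict, optim_state_dict=None,
                                  master_weights=None, dp_rank=0,
                                  is_model_weight=True):
    """Drop entries that are replicated across data-parallel ranks (no_sync=False)
    so only dp_rank 0 saves them; expert-parallel (no_sync=True) tensors are kept
    on every rank because each rank owns different experts."""
    if is_model_weight:
        for key in list(model_state_dict.keys()):
            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
                model_state_dict.pop(key)
    else:
        # Optimizer and master-weight keys are matched back to the owning
        # model parameter: "param_name/moment1_0" -> "param_name".
        no_sync_kname = [k for k, v in model_state_dict.items()
                         if getattr(v, "no_sync", False)]
        for key in list(optim_state_dict.keys()):
            if dp_rank > 0 and key.split("/")[0] not in no_sync_kname:
                optim_state_dict.pop(key)
        if master_weights is not None:
            for key in list(master_weights.keys()):
                if dp_rank > 0 and key not in no_sync_kname:
                    master_weights.pop(key)


# Toy usage: on dp_rank 1 only the expert tensors survive the filter.
model_sd = {"shared.weight": FakeTensor(no_sync=False),
            "expert.0.weight": FakeTensor(no_sync=True)}
optim_sd = {"shared.weight/moment1_0": 0.1, "expert.0.weight/moment1_0": 0.2}
master = {"shared.weight": 1.0, "expert.0.weight": 2.0}

filter_sync_parameters_sketch(model_sd, optim_sd, master, dp_rank=1, is_model_weight=False)
assert list(optim_sd) == ["expert.0.weight/moment1_0"]
assert list(master) == ["expert.0.weight"]
```

Because this filtering now happens before `merge_tensor_parallel_with_shard` and `merge_tensor_parallel_for_optimizer` are called, those helpers no longer need the `model_state_dict`/`dp_rank` bookkeeping removed in this diff and simply skip keys that are absent from the already-filtered state dicts.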