[Unified Checkpoint] Fix expert parallel #9821

Open · wants to merge 3 commits into develop
paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

@@ -305,7 +305,11 @@
             )
         )
         if has_master_weights:
-            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                # for moe gate with float32 dtype.
+                key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
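Note on the hunk above: when master weights are enabled, optimizer-state keys normally carry the FP32_MASTER infix, but a parameter that is already float32 (here, the MoE gate) gets no separate master copy, so its key must omit the infix. A minimal sketch of the naming rule; the FP32_MASTER value and the `moment1_0` suffix are illustrative assumptions, not copied from the PaddleNLP source:

    import paddle

    FP32_MASTER = "fp32_master_0"  # assumed value of the infix constant

    def optimizer_key(static_name, suffix, param, has_master_weights):
        # float32 params (e.g. the MoE gate) keep no master-weight copy,
        # so their optimizer-state keys carry no FP32_MASTER infix.
        if has_master_weights and param.dtype != paddle.float32:
            return "_".join([static_name, FP32_MASTER, suffix])
        return "_".join([static_name, suffix])

    bf16_param = paddle.ones([2], dtype="bfloat16")
    gate_param = paddle.ones([2], dtype="float32")  # MoE gate stays float32
    assert optimizer_key("linear_0.w_0", "moment1_0", bf16_param, True) \
        == "linear_0.w_0_fp32_master_0_moment1_0"
    assert optimizer_key("gate_0.w_0", "moment1_0", gate_param, True) \
        == "gate_0.w_0_moment1_0"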
30 changes: 27 additions & 3 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -516,6 +516,15 @@

     config_to_save = copy.deepcopy(model_to_save.config)

+    if args.use_expert_parallel:
+        # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_group = hcg.get_data_parallel_group()
+        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+        for key in list(state_dict.keys()):
+            if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
+                state_dict.pop(key)
+
     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
             tp_actions = model_to_save._get_tensor_parallel_convert_actions(
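In words: under expert parallel, expert-local tensors are flagged `no_sync=True` and differ across data-parallel ranks, so every rank must save its own copy; all other tensors are replicated and only dp_rank 0 needs to keep them. A toy illustration of the filter, with made-up parameter names and a stand-in object instead of a real Paddle tensor:

    class FakeTensor:
        """Stand-in for a Paddle tensor carrying the `no_sync` attribute."""

        def __init__(self, no_sync=False):
            self.no_sync = no_sync

    state_dict = {
        "gate.weight": FakeTensor(),              # replicated across dp ranks
        "expert_0.w1": FakeTensor(no_sync=True),  # expert-local, rank-specific
    }
    dp_rank = 1  # any data-parallel rank other than 0
    for key in list(state_dict.keys()):
        if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
            state_dict.pop(key)
    assert list(state_dict) == ["expert_0.w1"]  # rank 1 saves only its experts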
@@ -622,8 +631,25 @@
     filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True)
     filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True)

-    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+    hcg = fleet.get_hybrid_communicate_group()
+    tp_group = hcg.get_model_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
     tp_size = tp_group.nranks
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

+    if args.use_expert_parallel:
+        no_sync_kname = []
+        for k, v in state_dict.items():
+            if getattr(state_dict[k], "no_sync", False):
+                no_sync_kname.append(k)
+        for key in list(optim_state_dict.keys()):
+            model_key = key.split("/")[0]
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                optim_state_dict.pop(key)
+        if master_weights is not None:
+            for key in list(master_weights.keys()):
+                if dp_rank > 0 and key not in no_sync_kname:
+                    master_weights.pop(key)
+
     if tp_size > 1:
         # get tp_actions

@@ -643,7 +669,6 @@
             optim_state_dict,
             tp_actions,
             filter_optim_keys,
-            state_dict if args.use_expert_parallel else None,
         )
         paddle.device.cuda.empty_cache()

@@ -653,7 +678,6 @@
             master_weights,
             tp_actions,
             filter_master_keys,
-            state_dict if args.use_expert_parallel else None,
         )
         paddle.device.cuda.empty_cache()
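The hunks above apply the same dp-rank rule to optimizer state and master weights before the tensor-parallel merge, which is why the trailing `state_dict if args.use_expert_parallel else None` argument to the merge helpers is removed (see utils.py below). The filter relies on the optimizer-key format visible in the diff, `<param_name>/<state_suffix>`; a toy run-through with made-up names:

    no_sync_kname = ["expert_0.w1"]  # expert-local params collected from state_dict
    optim_state_dict = {
        "gate.weight/moment1_0": 0.0,  # state of a replicated param
        "expert_0.w1/moment1_0": 0.0,  # state of an expert-local param
    }
    dp_rank = 1
    for key in list(optim_state_dict.keys()):
        model_key = key.split("/")[0]  # recover the owning parameter's name
        if dp_rank > 0 and model_key not in no_sync_kname:
            optim_state_dict.pop(key)
    assert list(optim_state_dict) == ["expert_0.w1/moment1_0"]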
22 changes: 5 additions & 17 deletions paddlenlp/trainer/unified_checkpoint/utils.py
@@ -354,9 +354,7 @@
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

     # filter actions for pipeline mode
     if hcg.get_pipe_parallel_group().nranks > 1:

@@ -373,10 +371,9 @@
             if i > len(filter_keys) - 1:
                 continue
             key = filter_keys[i]
-            tensor = state_dict[key]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and not getattr(tensor, "no_sync", False):
+            if key not in state_dict:
                 continue
+            tensor = state_dict[key]
             if key in tp_actions:
                 # Get tensor size
                 tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks

@@ -405,21 +402,13 @@
     return state_dict_to_save


-def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None):
+def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
     """
     Merge tensor parallel according to tp_actions, used for master_weight and optimizer weight.
     """
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
     tp_rank = tp_group.rank
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

-    no_sync_kname = []
-    if model_state_dict is not None:
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
-
     state_dict_to_save = {}
     max_key_len = max([len(_) for _ in all_filter_keys])

@@ -430,10 +419,9 @@
                 continue
             # get base model key
             model_key = filter_keys[i].split("/")[0]
-            tensor = state_dict[filter_keys[i]]
-            # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0.
-            if dp_rank > 0 and model_key not in no_sync_kname:
+            if filter_keys[i] not in state_dict:
                 continue
+            tensor = state_dict[filter_keys[i]]
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
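Net effect in utils.py: the dp-rank/no_sync checks move out of the merge helpers and into the caller shown above, so the helpers only need the new `key not in state_dict` guard. `filter_keys` is still built from the global key list, so on dp_rank > 0 a key can be legitimately absent from the locally pre-filtered dict; a minimal sketch with made-up keys:

    filter_keys = ["gate.weight", "expert_0.w1"]  # global view of keys to merge
    state_dict = {"expert_0.w1": object()}        # local view, pre-filtered for dp_rank > 0
    merged = {}
    for key in filter_keys:
        if key not in state_dict:
            continue  # dropped by the expert-parallel pre-filter upstream
        merged[key] = state_dict[key]
    assert list(merged) == ["expert_0.w1"]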
175 changes: 175 additions & 0 deletions tests/trainer/test_moe_unified_checkpoint.py
@@ -0,0 +1,175 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import pytest

from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from tests.parallel_launch import TestMultipleGpus
from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case
from tests.trainer.test_unified_checkpoint import remove_ckpt, remove_logs
from tests.trainer.trainer_utils import get_pretrain_arguments

environment_variables = {
    "NCCL_ALGO": "Tree",
    "NVIDIA_TF32_OVERRIDE": "0",
    "NCCL_IB_TIMEOUT": "22",
    "NCCL_DEBUG": "INFO",
    "FLAGS_embedding_deterministic": "1",
    "FLAGS_cudnn_deterministic": "1",
    "Flags_mp_aysnc_allreduce": "1",
    "Flags_skip_mp_c_identity": "1",
    "FLAGS_shard_norm_align_dp": "0",
    "FLAGS_shard_use_reduce": "1",
    "test_ci_no_save_model": "1",
}

moe_arguments = {
    "model_name_or_path": "./tests/trainer/unified-ckpt-qwen2moe",
    "dataset_name_or_path": "./unified_checkpoint/peft_input/data/",
    "output_dir": "./unified_checkpoint/checkpoints/qwen2moe_sft_ckpts",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "learning_rate": 3e-04,
    "max_steps": 10,
    "save_steps": 6,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "steps",
    "src_length": 1024,
    "max_length": 2048,
    "bf16": "true",
    "fp16_opt_level": "O2",
    "do_train": "true",
    "do_eval": "false",
    "disable_tqdm": "true",
    "eval_with_do_generation": "false",
    "recompute": "true",
    "recompute_granularity": "full",
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "sharding": "",
    "lora": "false",
    "zero_padding": "false",
    "use_flash_attention": "false",
    "unified_checkpoint": 1,
    "continue_training": 0,
    "sequence_parallel": 0,
}


def check_acc(log_dir="log"):
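    # Extract the metric logged at `global_step: 10` from the worker log; the
    # first run and the rerun each log it once, so two values are expected.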
    file_path = os.path.join(log_dir, "workerlog.n0.c0")
    cmd = "grep -a 'global_step: 10' " + file_path + " | awk -F ',' '{print $2}' | awk '{print $6}'"
    import subprocess

    res = subprocess.check_output(cmd, shell=True, text=True)
    res = [float(x) for x in res.split()]

    return res


seed = 2024

rng = np.random.default_rng(seed=seed)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointBase(TestMultipleGpus):
    @classmethod
    @property
    def __test__(cls):
        return cls != TestUnifiedCheckpointBase

    def setUp(self):
        """
        1. Update `runfirst` and `rerun` to run the configurations under test.
        2. Set `need_allclose` to True to check that the two runs match.
        3. Set `rtol` to the relative tolerance used for that check.
        """

        self.configs = get_pretrain_arguments(moe_arguments)
        os.environ.update(environment_variables)

        file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz"
        input_dir = "unified_checkpoint/peft_input/"
        os.makedirs(input_dir, exist_ok=True)
        file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz")
        if not os.path.exists(file_path):
            get_path_from_url_with_filelock(file_, root_dir=input_dir)

        self.need_allclose = True
        self.rtol = 1e-7

        self.run_file = "llm/run_finetune.py"

    def runfirst(self, train_args):
        self.run_n1c8(self.run_file, **train_args)

    def rerun(self, train_args):
        self.run_n1c8(self.run_file, **train_args)

    @require_paddle_at_least_8_gpu
    def testTP4DP2(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP4DP2"]
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)

    @require_paddle_at_least_8_gpu
    def testTP2Sharding4(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP2Sharding4"]
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)


@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase):
    @skip_for_none_ce_case
    @require_paddle_at_least_8_gpu
    def testTP2Sharding4V2(self):
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP2Sharding4"]
        train_args.update({"sharding_parallel_config": "split_param"})
        train_args.update({"amp_master_grad": True})
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], self.rtol)
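For context, all three tests appear to follow the pattern of the existing unified-checkpoint suite: the first run saves a unified checkpoint at step 6 (save_steps=6, max_steps=10), the rerun is expected to pick it up from output_dir, and the step-10 metric must match between the two runs. Sketched as comments (method names from the classes above):

    # self.runfirst(train_args)  # train 10 steps, saving a checkpoint at step 6
    # self.rerun(train_args)     # second run over the same output_dir
    # res = check_acc()          # [metric_run1, metric_run2] at global_step 10
    # np.testing.assert_allclose(res[0], res[1], self.rtol)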
8 changes: 8 additions & 0 deletions tests/trainer/trainer_utils.py
@@ -141,6 +141,14 @@ def get_pretrain_arguments(pretrain_arguments):
train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 8
configs["DP8"] = train_args

train_args = copy.deepcopy(pretrain_arguments)
train_args["tensor_parallel_degree"] = 2
train_args["pipeline_parallel_degree"] = 1
train_args["sharding_parallel_degree"] = 2
train_args["sharding"] = "stage1"
train_args["gradient_accumulation_steps"] = train_args["gradient_accumulation_steps"] // 4
configs["TP2DP2Sharding2"] = train_args

return configs
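A quick sanity check of the new entry: on the 8-GPU runs used by the tests above, the remaining data-parallel degree and the `// 4` on gradient accumulation fall out of the named degrees (a sketch; the dp degree is inferred from the GPU count, not set explicitly in the config):

    gpus = 8
    tp, pp, sharding = 2, 1, 2         # degrees set in the config above
    dp = gpus // (tp * pp * sharding)  # -> 2, hence the name TP2DP2Sharding2
    assert dp == 2
    # Sharding stage1 splits the batch like plain dp, so the batch is spread
    # over dp * sharding = 4 replicas; gradient accumulation is divided by 4
    # to keep the global batch size equal to the single-replica baseline.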


34 changes: 34 additions & 0 deletions tests/trainer/unified-ckpt-qwen2moe/config.json
@@ -0,0 +1,34 @@
{
  "architectures": [
    "Qwen2MoeForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "decoder_sparse_step": 1,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "model_type": "qwen2_moe",
  "moe_intermediate_size": 2560,
  "norm_topk_prob": false,
  "num_attention_heads": 28,
  "num_experts": 8,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 8,
  "num_key_value_heads": 4,
  "output_router_logits": false,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.001,
  "shared_expert_intermediate_size": 20480,
  "sliding_window": 131072,
  "tie_word_embeddings": false,
  "dtype": "bfloat16",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
8 changes: 8 additions & 0 deletions tests/trainer/unified-ckpt-qwen2moe/generation_config.json
@@ -0,0 +1,8 @@
{
  "bos_token_id": 151643,
  "pad_token_id": 151643,
  "eos_token_id": [
    151645,
    151643
  ]
}