[BUG] pipeline parallelism+fp16+moe isn't working #6714

Open
NeferpitouS3 opened this issue Nov 5, 2024 · 3 comments

Comments

@NeferpitouS3

Describe the bug
My model uses DeepSpeed PipelineModule(num_stages=4) to split into 4 pipeline stages, and the deepspeed.moe.layer.MoE layer is only placed in pipeline stage 1. When the model runs train_batch, the program gets stuck; the hang occurs specifically in the FP16_Optimizer step.

Here is our DeepSpeed config:

{
   "train_batch_size": 4,
   "train_micro_batch_size_per_gpu" 1,
   "fp16": {
      "enabled": true,
      "auto_cast": true
   },
   "optimizer": {
      "type": "AdamW",
      "params": {
         "lr": 0.001,
         "betas": [
            0.9,
            0.95
         ],
         "weight_decay": 0.05
      }
   },
   "zero_optimization": {
      "stage": 0
   }
}

Source code with issues
My pipeline_parallel_world_size is 4, so the code enters the following branch. But since my MoE layer is only placed in pipeline stage 1, the all_reduce makes the program hang. If I delete this code, it runs successfully.

elif bwc_pipeline_parallel_world_size(mpu) > 1:
    dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))

I don't know why the all_reduce needs to be done here; it doesn't seem meaningful.
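For what it's worth, dist.all_reduce is a collective call: every rank in the group passed via group= has to call it, otherwise the ranks that do call it block forever. Below is a minimal sketch (not DeepSpeed code; two CPU processes on the gloo backend, and the name has_moe_layer_on_this_rank is made up for illustration) that reproduces the same blocking behavior when only one rank enters the branch:

import os
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    # Short timeout so the demo fails fast instead of hanging indefinitely.
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            timeout=timedelta(seconds=10))

    # Mimics has_moe_layers being True only on the pipeline stage that owns the MoE layer.
    has_moe_layer_on_this_rank = (rank == 0)
    device_total_norm = torch.tensor([float(rank + 1)])

    if has_moe_layer_on_this_rank:
        # Only rank 0 reaches this collective, so it blocks waiting for the other rank
        # and eventually fails (timeout or connection error). This is the mismatch that
        # makes the real run hang inside the FP16_Optimizer step.
        dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM)

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

If every rank enters the branch (or none does), the collective either completes normally or is skipped consistently, and there is no hang.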

@ranzhejiang
Contributor

Can you provide the whole script to reproduce it?

@NeferpitouS3
Author

Here is a simple example adapted from DeepspeedExamples.training.cifar.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import deepspeed
from deepspeed.pipe import PipelineModule

class net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = layer1()
        self.layer2 = layer2()
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        return x

class layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = DeepSpeedMoEMlp()
        self.fc3 = nn.Linear(120, 10)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

class DeepSpeedMoEMlp(nn.Module):
    def __init__(self):
        super(DeepSpeedMoEMlp, self).__init__()
        self.fc2 = nn.Linear(120, 120)
        self._moe_layer = deepspeed.moe.layer.MoE(
            hidden_size=120,
            expert=self.fc2,
            num_experts=4,
            k=1,
            capacity_factor=1.25,
            use_tutel=True
        )
    def forward(self, x):
        x, _, _ = self._moe_layer(x)
        return x

if __name__ == "__main__":
    deepspeed.init_distributed()
    model = net()
    layer = torch.nn.Sequential(
        model.layer1,
        model.layer2
    )
    criterion = nn.CrossEntropyLoss()
    pipeline_model = PipelineModule(layers=layer, loss_fn=criterion, num_stages=2, partition_method="uniform")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=False,
                                            transform=transform)
    engine, _, data_loader, _ = deepspeed.initialize(
        model=pipeline_model,
        model_parameters=pipeline_model.parameters(),
        config=get_ds_config(),
        training_data=trainset
    )
    for epoch in range(3):
        for i in range(len(trainset)):
            loss = engine.train_batch()
            print(f"step{i}, loss: {loss}")

Running this code with deepspeed --include="localhost:0,1" test.py --deepspeed reproduces the problem I described above.
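One note in case anyone wants to run this: get_ds_config() is not defined in the snippet above. A minimal sketch, assuming it simply returns the DeepSpeed config posted at the top of this issue as a dict:

def get_ds_config():
    # Same config as posted above, expressed as a Python dict.
    return {
        "train_batch_size": 4,
        "train_micro_batch_size_per_gpu": 1,
        "fp16": {"enabled": True, "auto_cast": True},
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": 0.001, "betas": [0.9, 0.95], "weight_decay": 0.05},
        },
        "zero_optimization": {"stage": 0},
    }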

@traincheck-team

traincheck-team commented Nov 23, 2024

Hi @ranzhejiang @NeferpitouS3, we tried your example, ran it through our bug diagnosis tool, and found out that the bug is caused by inconsistent optimizer states across the workers.

The all_reduce that @NeferpitouS3 referred to is invoked inside get_global_norm_of_tensors in deepspeed/runtime/utils.py.

The invocation of this get_global_norm_of_tensors API is gated by optimizer.has_moe_layers in the FP16 optimizer.

[screenshot: the FP16 optimizer code where the get_global_norm_of_tensors call is gated on self.has_moe_layers]

Not surprisingly, we found that get_global_norm_of_tensors was never called on any worker except the first one, i.e. the stuck worker, indicating that self.has_moe_layers is only True on the stuck worker!

We also confirmed this by printing self.has_moe_layers inside each step: it is indeed True on worker rank 0 and False on the other workers.

Suggested Next Steps:

Inconsistency of this self.has_moe_layers flag points to two possible root-cause locations:

  1. DeepSpeed didn't initialize FP16_Optimizer properly.
  2. The attribute self.has_moe_layers is accidentally modified somewhere.
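
A quick way to narrow this down is to compare the flag across ranks right after deepspeed.initialize(...). The sketch below is only a diagnostic, and it assumes the engine exposes the wrapped FP16 optimizer as engine.optimizer (adjust for your DeepSpeed version):

import torch.distributed as dist

def check_has_moe_layers(engine):
    # All ranks must call this helper together, since all_gather_object is itself a collective.
    flag = bool(getattr(engine.optimizer, "has_moe_layers", False))
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, (dist.get_rank(), flag))
    if dist.get_rank() == 0:
        print("has_moe_layers per rank:", sorted(gathered))
        if len({f for _, f in gathered}) > 1:
            print("WARNING: has_moe_layers differs across ranks; "
                  "the norm all_reduce in the FP16 optimizer step will not line up.")

If the flags already differ right after initialization, that points to case 1 (FP16_Optimizer setup); if they start out consistent and diverge later, that points to case 2 (the attribute being modified during training).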
