[BUG] [Fix-Suggested] Checkpoint Inconsistency When Freezing Model Parameters Before deepspeed.initialize #6771

traincheck-team commented Nov 20, 2024

Related Issue: #5489

Describe the Bug

If the model has frozen parameters (requires_grad == False) before calling deepspeed.initialize, the frozen layers are excluded from the parameters managed by the DeepSpeed optimizer. However, the original (wrapped) optimizer still retains these parameters.

As a result, when save_checkpoint is called on the DeepSpeed engine, the optimizer state checkpoint (zero_pp_rank_0_mp_rank_00_optim_states.pt) contains only the parameters managed by the DeepSpeed optimizer, leading to incomplete or inconsistent checkpointing.

Suspected Root Cause

When model parameters are modified (e.g., setting requires_grad=False to freeze layers) before deepspeed.initialize, the DeepSpeed optimizer manages only the parameters that are still trainable. The frozen parameters are no longer part of the DeepSpeed-managed parameter groups but remain covered by the original optimizer. Consequently, during checkpointing only the active parameters are saved, resulting in a partial optimizer state and potential errors when attempting to resume training.
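
The mismatch can be observed with plain PyTorch, before DeepSpeed is even involved. The following is only a minimal sketch using the same toy model as the reproduction script below; the variable names are just for illustration:

import torch

model = torch.nn.Linear(10, 1)
# The optimizer is created while everything is still trainable, so it holds all parameters
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# Freeze the parameters afterwards
for p in model.parameters():
    p.requires_grad = False

params_in_optimizer = sum(len(g["params"]) for g in optimizer.param_groups)
trainable_params = sum(1 for p in model.parameters() if p.requires_grad)
print(params_in_optimizer, trainable_params)  # prints "2 0": the two sets no longer agree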

Expected Behavior / Suggested Fix

To address this issue, we propose the following fixes:

  1. Parameter Modification Enforcement: Ensure that deepspeed.initialize raises an error or warning if modifications to parameters (e.g., a changed .requires_grad flag) are detected at initialization time (a minimal sketch of such a check follows this list).
  2. Consistent Checkpointing: Add a validation step during checkpoint saving that verifies that the parameters in the optimizer match the model parameters, thereby ensuring consistency and emitting a clear warning if a discrepancy is found.
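
As a minimal sketch of the check in fix 1 (the same comparison could also run at checkpoint-save time for fix 2), something along the following lines could be added. The helper name and warning text are hypothetical; the real check would live inside deepspeed.initialize:

import warnings
import torch

def check_trainable_params_match_optimizer(model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    # Parameters the user's optimizer actually holds
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    # Parameters that are currently trainable on the model
    trainable = {id(p) for p in model.parameters() if p.requires_grad}
    if in_optimizer != trainable:
        warnings.warn(
            "The trainable model parameters do not match the parameters held by the "
            "optimizer; requires_grad may have been changed after the optimizer was "
            "created, and the saved optimizer state may be incomplete or inconsistent."
        )

Calling such a helper at the top of deepspeed.initialize (or right before save_checkpoint) would surface the frozen-parameter discrepancy in the reproduction below instead of silently writing a partial checkpoint.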

Steps to Reproduce the Bug

Below is an example that reproduces the issue by freezing the model parameters before initializing DeepSpeed:

import torch
import deepspeed

# Define a model
def initialize_model():
    model = torch.nn.Linear(10, 1)
    return model

def freeze_parameters(model):
    for param in model.parameters():
        param.requires_grad = False

# Reproduce the bug
def expose_bug():
    model = initialize_model()
    # The optimizer is created while every parameter is still trainable,
    # so it holds all of the model's parameters.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

    # Freeze model parameters after the optimizer has been created
    freeze_parameters(model)

    # Initialize DeepSpeed
    ds_config_fp16 = {
        "train_micro_batch_size_per_gpu": 1,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2}
    }
    model_engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config_fp16)

    # Save a checkpoint; the optimizer state it writes is incomplete
    print("Checkpointing...")
    model_engine.save_checkpoint("./checkpoint_dir")

expose_bug()

To run the code, use deepspeed --num_gpus=1 bug.py.

Inspecting the saved checkpoint confirms that it is not consistent with the model: the DeepSpeed optimizer does not account for the frozen model parameters during checkpointing, so the saved optimizer state is incomplete.
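
One way to confirm this is to load the saved optimizer-state shard and inspect what it contains. This is only a rough sketch: the tag directory name (e.g. global_step0) and the exact key layout inside the .pt file vary across DeepSpeed versions, so the glob pattern and printed keys are assumptions.

import glob
import torch

# Locate the ZeRO optimizer-state shard written by save_checkpoint
optim_files = glob.glob("./checkpoint_dir/*/zero_pp_rank_0_mp_rank_00_optim_states.pt")
state = torch.load(optim_files[0], map_location="cpu")

# Print the top-level keys to see which parameter state was actually saved;
# with the parameters frozen before deepspeed.initialize, the optimizer state
# does not cover the model's weight and bias.
print(state.keys())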

ds_report output

collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/xxx/python3.10/site-packages/torch']
torch version .................... 2.2.2+cu121
deepspeed install path ........... ['/home/xxx/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.15.4, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 2.2, cuda 12.1
shared memory (/dev/shm) size .... 31.24 GB

I would be more than happy to contribute the two suggested fixes; let me know what you think!
