
[BUG] clip_grad_norm for zero_optimization mode is not working #6767

Open
chengmengli06 opened this issue Nov 20, 2024 · 3 comments
Labels: bug, training

@chengmengli06

set "gradient_clipping" in deepspeed does not work, look into the source code in deepspeed.runtime.engine.DeepSpeedEngine,in line 2101

    def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
        if self.gradient_clipping() > 0.0:
            # This branch fires only when none of fp16/bf16/amp/ZeRO is enabled.
            if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
                self.clip_fp32_gradients()
            elif self.amp_enabled():
                # AMP's recommended way of doing clipping
                # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                master_params = amp.master_params(self.optimizer)
                clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
        self.optimizer.step()

Thus, when zero_optimization is enabled, neither branch is taken, and gradient clipping appears to do nothing at all!
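For reference, this is a minimal sketch of the configuration in question, assuming ZeRO stage 2 and fp16 (the model, batch size, and values here are illustrative placeholders, not taken from this issue):

    import torch
    import deepspeed

    # Tiny illustrative model; any torch.nn.Module works here.
    model = torch.nn.Linear(10, 10)

    ds_config = {
        "train_batch_size": 16,
        "gradient_clipping": 1.0,           # the setting this issue is about
        "zero_optimization": {"stage": 2},  # ZeRO stage 2
        "fp16": {"enabled": True},
    }

    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )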

@tjruwase
Contributor

@chengmengli06, this is an incorrect reading of the code. Gradient clipping is handled in the respective optimizer implementations (a conceptual sketch of what they do follows the list), such as:

  1. bf16 optimizer
  2. fp16 optimizer
  3. ZeRO optimizer
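All of these apply standard global-norm clipping inside their own step() logic, which is why nothing needs to happen in _take_model_step. Conceptually, the clipping they perform looks like this (a simplified sketch, not DeepSpeed's actual implementation):

    import torch

    def clip_by_global_norm(grads, max_norm, eps=1e-6):
        # Global L2 norm across all gradient tensors.
        total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
        # If the global norm exceeds max_norm, scale every gradient down
        # by max_norm / total_norm so the combined norm equals max_norm.
        clip_coef = max_norm / (total_norm + eps)
        if clip_coef < 1.0:
            for g in grads:
                g.detach().mul_(clip_coef)
        return total_norm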

@chengmengli06
Author

I found it and verified that it does work under ZeRO stage 2 mode. Thanks!

@chengmengli06
Author

chengmengli06 commented Nov 21, 2024

@tjruwase, another question: how can I log the pre-clip and post-clip gradient norms to TensorBoard? Is there an interface to get the gradient norms before and after clipping?
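One possible approach, sketched under two assumptions: that the engine exposes get_global_grad_norm() (present in recent DeepSpeed versions) and that it returns the global norm measured before clipping. With global-norm clipping, the post-clip norm is then simply min(pre_clip_norm, max_norm). The engine and data_loader are assumed to come from the user's own training setup:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("runs/grad_norms")
    max_norm = 1.0  # must match "gradient_clipping" in the DeepSpeed config

    for step, batch in enumerate(data_loader):
        loss = engine(batch)   # forward pass returning the loss
        engine.backward(loss)
        engine.step()

        pre_clip = engine.get_global_grad_norm()  # assumption: pre-clip norm
        if pre_clip is not None:
            writer.add_scalar("grad_norm/pre_clip", pre_clip, step)
            writer.add_scalar("grad_norm/post_clip", min(pre_clip, max_norm), step)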
