diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 08fd09b6..ffc8409e 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -257,7 +257,8 @@ def get_megatron_flops_mla( Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf """ - checkpoint_activations_factor = 4 if checkpoint else 3 + # checkpoint_activations_factor = 4 if checkpoint else 3 + checkpoint_activations_factor = 3 if use_swiglu: mlp_ratio = mlp_ratio * 3 / 2 @@ -321,6 +322,7 @@ def get_megatron_flops_mla( attn_flops = 4 * global_batch_size * seq_len**2 * attn_hidden_size else: attn_flops = 4 * global_batch_size * seq_len**2 * hidden_size + attn_flops = attn_flops / 2 # vocab vocab_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size