Skip to content

Commit

Permalink
feat(common.py): update mla flops func
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Jan 2, 2025
1 parent 6826338 commit eae8b4f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion internlm/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def get_megatron_flops_mla(
Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
"""

checkpoint_activations_factor = 4 if checkpoint else 3
# checkpoint_activations_factor = 4 if checkpoint else 3
checkpoint_activations_factor = 3

if use_swiglu:
mlp_ratio = mlp_ratio * 3 / 2
Expand Down Expand Up @@ -321,6 +322,7 @@ def get_megatron_flops_mla(
attn_flops = 4 * global_batch_size * seq_len**2 * attn_hidden_size
else:
attn_flops = 4 * global_batch_size * seq_len**2 * hidden_size
attn_flops = attn_flops / 2

# vocab
vocab_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size
Expand Down

0 comments on commit eae8b4f

Please sign in to comment.