From eae8b4f8118a733a2407a6e87e1e2d0f3683da47 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Thu, 2 Jan 2025 12:05:12 +0800 Subject: [PATCH] feat(common.py): update mla flops func --- internlm/utils/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 08fd09b6..ffc8409e 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -257,7 +257,8 @@ def get_megatron_flops_mla( Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf """ - checkpoint_activations_factor = 4 if checkpoint else 3 + # checkpoint_activations_factor = 4 if checkpoint else 3 + checkpoint_activations_factor = 3 if use_swiglu: mlp_ratio = mlp_ratio * 3 / 2 @@ -321,6 +322,7 @@ def get_megatron_flops_mla( attn_flops = 4 * global_batch_size * seq_len**2 * attn_hidden_size else: attn_flops = 4 * global_batch_size * seq_len**2 * hidden_size + attn_flops = attn_flops / 2 # vocab vocab_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size