Skip to content

Commit

Permalink
fix(mha): pass dropout_p as 0 in mha eval
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Jan 29, 2024
1 parent 1229e2d commit af74d9f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deeplink_ext/internlm_ops/mha/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
# padded
return DeepLinkMultiHeadAttentionQKVPackedFunc.apply(
qkv,
self.dropout_p,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
False,
Expand All @@ -59,7 +59,7 @@ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
qkv,
cu_seqlens,
max_seqlen,
self.dropout_p,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
False,
Expand Down

0 comments on commit af74d9f

Please sign in to comment.