Describe the bug
DeepSpeed 0.15.4 can take the uneven-head sequence parallel (SP) path even when the heads divide evenly across the SP ranks, and then raises the following assert:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
This happens because by the second all2all the head dimension is already split across the SP ranks, so num_heads % seq_world_size != 0 can evaluate to true.
The input to the second all2all has shape [B, s, hc/sp, hs], and hc/sp % sp == 0 does not always hold.
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
To Reproduce
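To reproduce the error, set the sequence parallel size equal to the head count (SP = head_count). Each rank then holds hc/sp = 1 head at the second all2all, which fails both the divisibility check and the assert whenever SP > 1.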
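The arithmetic behind this reproduction can be checked without DeepSpeed or a distributed setup. The sketch below is a plain-Python illustration only: `takes_uneven_path` and `assert_would_fire` are hypothetical helper names, and it assumes `get_num_kv_heads()` is still `None` when the second all2all runs, as it would be after an even first call.

```python
def takes_uneven_path(heads_dim: int, sp: int) -> bool:
    """Divisibility check from single_all_to_all, with no cached kv-head state."""
    return heads_dim % sp != 0


def assert_would_fire(heads_dim: int, sp: int) -> bool:
    """True when the uneven path is entered and `num_heads > seq_world_size` fails."""
    return takes_uneven_path(heads_dim, sp) and not heads_dim > sp


head_count = 8
sp = 8  # SP = head_count, as in the reproduction

first_input_heads = head_count         # heads not yet split across ranks
second_input_heads = head_count // sp  # [B, s, hc/sp, hs]: one head per rank

print(assert_would_fire(first_input_heads, sp))   # False
print(assert_would_fire(second_input_heads, sp))  # True
```

The first call sees 8 heads and passes as even; the second sees only the local share of 1 head, so the same check misclassifies an even configuration.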
Fix Suggestion:
Adjust num_heads accordingly:
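As a rough illustration of what such an adjustment could look like (a sketch under stated assumptions, not the patch proposed in this issue and not DeepSpeed API): when the head dimension is already split, the local head count is multiplied back by the SP size before the divisibility check. The helper names and the `heads_already_sharded` flag are hypothetical; how the real code would detect the second all2all is left open here.

```python
def total_heads(local_heads: int, sp: int, heads_already_sharded: bool) -> int:
    """Recover the global head count to use in the divisibility check.

    Hypothetical helper: on the second all2all the tensor holds hc/sp heads
    per rank, so the global count is local_heads * sp (even case only).
    """
    return local_heads * sp if heads_already_sharded else local_heads


def is_uneven(local_heads: int, sp: int, heads_already_sharded: bool) -> bool:
    """Divisibility check applied to the adjusted head count."""
    return total_heads(local_heads, sp, heads_already_sharded) % sp != 0


sp = 8
print(is_uneven(8, sp, heads_already_sharded=False))  # first all2all, hc = 8
print(is_uneven(1, sp, heads_already_sharded=True))   # second all2all, hc/sp = 1
print(is_uneven(6, 4, heads_already_sharded=False))   # genuinely uneven: 6 heads, SP = 4
```

With this adjustment the even SP = head_count case no longer enters the uneven path on either call, while a genuinely uneven configuration is still detected on the first call. Per the comments in the quoted code, later calls in the uneven case re-enter that path through `get_num_kv_heads() is not None`, so the multiplication only needs to be right for the even case.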