Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Feb 24, 2025
1 parent 08d58f8 commit e885d08
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def should_parallelize_layer_predicate_func(layer):
"kv_size_multiplier": None,
"fuse_qkv": None,
"q_output_size_per_partition": None,
"kv_output_size_per_partition": None ,
"kv_output_size_per_partition": None,
}
for mod in model.modules():
if isinstance(mod, OptimumGQAQKVColumnParallelLinear):
Expand Down
12 changes: 10 additions & 2 deletions optimum/neuron/distributed/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,17 @@ def consolidate_tensor_parallel_checkpoints(
if weight_name == "weight_q":
s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"])
elif weight_name == "weight_k":
s = slice(gqa_qkv_metadata["q_output_size_per_partition"], gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"])
s = slice(
gqa_qkv_metadata["q_output_size_per_partition"],
gqa_qkv_metadata["q_output_size_per_partition"]
+ gqa_qkv_metadata["kv_output_size_per_partition"],
)
elif weight_name == "weight_v":
s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None)
s = slice(
gqa_qkv_metadata["q_output_size_per_partition"]
+ gqa_qkv_metadata["kv_output_size_per_partition"],
None,
)
else:
s = slice(None, None)
else:
Expand Down
5 changes: 4 additions & 1 deletion optimum/neuron/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
if proj_name == "q":
s = slice(0, layer.q_output_size_per_partition)
elif proj_name == "k":
s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition)
s = slice(
layer.q_output_size_per_partition,
layer.q_output_size_per_partition + layer.kv_output_size_per_partition,
)
else:
s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None)
weight[s, :] = weight_data
Expand Down

0 comments on commit e885d08

Please sign in to comment.