From e885d08d37ec856a1dfdda4f67883848e3e4bb5e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 24 Feb 2025 16:43:35 +0100 Subject: [PATCH] Fix --- optimum/neuron/distributed/base.py | 2 +- optimum/neuron/distributed/checkpointing.py | 12 ++++++++++-- optimum/neuron/distributed/utils.py | 5 ++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index 6aacbe755..e1312f3ac 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -680,7 +680,7 @@ def should_parallelize_layer_predicate_func(layer): "kv_size_multiplier": None, "fuse_qkv": None, "q_output_size_per_partition": None, - "kv_output_size_per_partition": None , + "kv_output_size_per_partition": None, } for mod in model.modules(): if isinstance(mod, OptimumGQAQKVColumnParallelLinear): diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py index a9e7a7c30..5d07ffa0f 100644 --- a/optimum/neuron/distributed/checkpointing.py +++ b/optimum/neuron/distributed/checkpointing.py @@ -158,9 +158,17 @@ def consolidate_tensor_parallel_checkpoints( if weight_name == "weight_q": s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"]) elif weight_name == "weight_k": - s = slice(gqa_qkv_metadata["q_output_size_per_partition"], gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"]) + s = slice( + gqa_qkv_metadata["q_output_size_per_partition"], + gqa_qkv_metadata["q_output_size_per_partition"] + + gqa_qkv_metadata["kv_output_size_per_partition"], + ) elif weight_name == "weight_v": - s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None) + s = slice( + gqa_qkv_metadata["q_output_size_per_partition"] + + gqa_qkv_metadata["kv_output_size_per_partition"], + None, + ) else: s = slice(None, None) else: diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index 569d14542..fcf258898 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -804,7 +804,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear( if proj_name == "q": s = slice(0, layer.q_output_size_per_partition) elif proj_name == "k": - s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition) + s = slice( + layer.q_output_size_per_partition, + layer.q_output_size_per_partition + layer.kv_output_size_per_partition, + ) else: s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None) weight[s, :] = weight_data