Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPU] GW prefill merge qkv #12410

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 51 additions & 36 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,45 +141,60 @@ def attention(self,
if self.n_splits_linear != 1:
hidden_states = self.unsqueeze(hidden_states, axis=0)

query_states = self.linear(
hidden_states,
num_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)
if mode == "prefill":
concat_linear = self.linear(hidden_states,
num_key_value_heads * head_dim * 2 + num_heads * head_dim,
hidden_size,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill"))
if q_bias is not None:
concat_linear = concat_linear + q_bias
query_states, key_states, value_states = self.variadic_split(
concat_linear, 2,
[num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim]
)
else:
query_states = self.linear(
hidden_states,
num_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)

key_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)
key_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)

value_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)
value_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
n_splits=self.n_splits_linear,
scale_factor=(self.group_size == 0),
is_prefill=(mode == "prefill")
)

if q_bias is not None:
query_states = query_states + q_bias
if k_bias is not None:
key_states = key_states + k_bias
if v_bias is not None:
value_states = value_states + v_bias
if q_bias is not None:
query_states = query_states + q_bias
if k_bias is not None:
key_states = key_states + k_bias
if v_bias is not None:
value_states = value_states + v_bias

query_states = self.reshape(
query_states, [1, seq_len, num_heads, head_dim]
Expand Down
89 changes: 65 additions & 24 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,18 @@ def __init__(
post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]

if q_biases is None:
q_biases = []
k_biases = []
v_biases = []
for i in range(num_layers):
q_biases.append(self.create_input_op((self.num_heads * self.head_dim,)))
k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
if mode == "prefill":
q_biases = []
for i in range(num_layers):
q_biases.append(self.create_input_op((self.num_heads * self.head_dim + self.num_key_value_heads * self.head_dim * 2,)))
else:
q_biases = []
k_biases = []
v_biases = []
for i in range(num_layers):
q_biases.append(self.create_input_op((self.num_heads * self.head_dim,)))
k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
else:
q_biases = [self.constant(w) for w in q_biases]
k_biases = [self.constant(w) for w in k_biases]
Expand Down Expand Up @@ -217,8 +222,8 @@ def __init__(
input_layernorm_weight=input_layernorm_weights[i],
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
q_bias=q_biases[i],
k_bias=k_biases[i],
v_bias=v_biases[i],
k_bias=k_biases[i] if mode == "decode" else None,
v_bias=v_biases[i] if mode == "decode" else None,
past_key=past_keys[i],
past_value=past_values[i],
)
Expand All @@ -241,6 +246,11 @@ def __init__(
else:
self.compile()
print(f"{mode} end compiling")
qwen_size = "7b" if self.hidden_size == 3584 else "1.5b"
xml_path = f"gw/qwen-{qwen_size}-npu-qkv-split-{mode}-{num_layers}-{n_splits_linear}-{n_splits_down_proj}.xml"

if not os.path.exists(xml_path):
self.save(xml_path)

def build_decoder(
self,
Expand Down Expand Up @@ -524,8 +534,7 @@ def forward(
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.float16),
position_ids.to(torch.int64))
inputs += (self.layer_norm_0, self.layer_norm_1)
inputs += (self.q_bias, self.k_bias, self.v_bias)
inputs += (self.layer_norm_0, self.layer_norm_1, self.q_bias)
hidden_states, past_key, past_value = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
)
Expand Down Expand Up @@ -815,23 +824,52 @@ def run_prefill(
mlp_layer = curr_layer.mlp

weights = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
if n_splits_linear == 1:
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
else:
qkv_weights = []
qkv_scales = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
qkv_weights.append(torch.stack(l_weights, axis=0))
qkv_scales.append(torch.stack(scales, axis=0))

weights.append((torch.cat(qkv_weights, dim=1), torch.cat(qkv_scales, dim=1)))

for layer_list in [attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

merge_bias = torch.cat([attn_layer.q_proj_dq_list.q_proj_dq_0.bias,
attn_layer.k_proj_dq_list.k_proj_dq_0.bias,
attn_layer.v_proj_dq_list.v_proj_dq_0.bias]).to(torch.float16)

new_decoderlayer = FusedQwenLowBitDecoderlayer(
weights,
num_heads=num_heads,
Expand All @@ -840,9 +878,12 @@ def run_prefill(
cached_sin=cached_sin,
layer_norm_0=layer_norm_0,
layer_norm_1=layer_norm_1,
q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
# q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
# k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
# v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
q_bias=merge_bias,
k_bias=None,
v_bias=None,
layer_idx=layer_idx,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
Expand Down
Loading