Skip to content

Commit

Permalink
Enrich logic to fuse rotaryembedding with bias and support partial RE…
Browse files Browse the repository at this point in the history
… fusion into GQA (microsoft#20300)

### Description
This PR mainly focuses on adding two functionalities:
1. Fuse RotaryEmbedding op taking output from previous layers with bias
enabled.

> Matmul->RotaryEmbedding    ----->  Matmul->Add->RotaryEmbedding

2. Fuse GQA op for partial RotaryEmbedding applied in phi-2.

> # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
# [batch_size, seq_length, num_heads, head_dim //
config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin,
position_ids)

        # [batch_size, seq_length, num_heads, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

# Optimized graph

![image](https://github.com/microsoft/onnxruntime/assets/17421593/76fd8576-7e60-41af-9a4f-48d205fc6b56)
  • Loading branch information
zhangxiang1993 authored Apr 15, 2024
1 parent 7ec51f0 commit bf72f99
Showing 1 changed file with 210 additions and 10 deletions.
220 changes: 210 additions & 10 deletions onnxruntime/python/tools/transformers/fusion_rotary_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# v_nodes_1 is for LLaMA-2 Microsoft
# v_nodes_3 is for LLaMA-2 Hugging Face
# v_nodes_4 is for LLaMA-2 70B model
# v_nodes_5 is for Phi-2 DirectML
past_v, present_v, past_seq_len = "", "", ""
v_nodes = None
add_v = None
v_nodes_1 = self.model.match_parent_path(
matmul_qkv,
["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
Expand Down Expand Up @@ -491,6 +493,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
],
output_name_to_node=None,
)
v_nodes_5 = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 1, 0, 0, 1],
)
if v_nodes_1 is not None:
reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
v_nodes = v_nodes_1
Expand Down Expand Up @@ -521,6 +528,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
v_nodes = v_nodes_4
past_v = concat_v.input[0]
present_v = concat_v.output[0]
elif v_nodes_5 is not None:
concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
matmul_v = add_v
v_nodes = v_nodes_5
past_v = concat_v.input[0]
present_v = concat_v.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match v path")
return
Expand Down Expand Up @@ -607,6 +620,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# k_nodes_4 is for LLaMA-2 70B Hugging Face
past_k, present_k = "", ""
k_nodes = None
slice_k = None
concat_k_half = None
k_nodes_1 = self.model.match_parent_path(
matmul_qk,
["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
Expand Down Expand Up @@ -790,6 +805,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
],
output_name_to_node=None,
)
k_nodes_5 = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 1, 0, 0, 0, 0, 0, 1],
)
if k_nodes_1 is not None:
reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
k_nodes = k_nodes_1
Expand Down Expand Up @@ -823,13 +843,21 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
k_nodes = k_nodes_4
past_k = concat_k.input[0]
present_k = concat_k.output[0]
elif k_nodes_5 is not None:
_, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
k_nodes = k_nodes_5
past_k = concat_k.input[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match k nodes")
return

# q_nodes_1 is for LLaMA-2 Microsoft
# q_nodes_2 is for LLaMA-2 Hugging Face
# q_nodes_3 is for Phi-2 DirectML
q_nodes = None
slice_q = None
concat_q_half = None
q_nodes_1 = self.model.match_parent_path(
matmul_qk,
["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
Expand All @@ -840,12 +868,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
[0, 0, 0, 0],
)
q_nodes_3 = self.model.match_parent_path(
matmul_qk,
["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 0, 0, 1],
)
if q_nodes_1 is not None:
reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
q_nodes = q_nodes_1
elif q_nodes_2 is not None:
rotary_q, _, reshape_q, matmul_q = q_nodes_2
q_nodes = q_nodes_2
elif q_nodes_3 is not None:
concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
q_nodes = q_nodes_3
else:
logger.debug("fuse_rotary_attention: failed to match q nodes")
return
Expand Down Expand Up @@ -885,15 +921,132 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Rename inputs of rotary_q/k so it connects with output of matmul_q/k
# Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
# After: MatMul --> RotaryEmbedding
rotary_q.input[0] = matmul_q.output[0]
rotary_k.input[0] = matmul_k.output[0]
rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]

# Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
rotary_k.output[0] = rotary_k.name + "_output_0"
if concat_q_half is None:
rotary_k.output[0] = rotary_k.name + "_output_0"

if qkv_nodes == qkv_nodes_3:
qkv_nodes = qkv_nodes[1:]

def create_hidden_size_concat_node(reshape_q):
"""Detect num_heads and hidden_size for ONNX model from phi-2
Args:
reshape_q (NodeProto): reshape node for q
Returns:
hidden_size_concat_node(NodeProto): Concat node to be used by reshape
"""
concat = self.model.match_parent(reshape_q, "Concat", 1)

if concat is None:
logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
return None

# The shape is a tensor like [?, ?, num_heads, head_size]
num_head_constant_node = self.model.get_constant_value(concat.input[2])
head_size_constant_node = self.model.get_constant_value(concat.input[3])

if num_head_constant_node is None or head_size_constant_node is None:
logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
return None

num_head_value = num_head_constant_node[0]
head_size_value = head_size_constant_node[0]

hidden_size = num_head_value * head_size_value

hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
if self.model.get_initializer(hidden_size_initilizer) is None:
self.add_initializer(
name=hidden_size_initilizer,
data_type=TensorProto.INT64,
dims=[1],
vals=[hidden_size],
raw=False,
)

hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")

hidden_size_concat_node = helper.make_node(
"Concat",
inputs=[
concat.input[0],
concat.input[1],
hidden_size_initilizer,
],
outputs=[hidden_size_reshape_node_name + "output_0"],
name=hidden_size_reshape_node_name,
)
hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])

return hidden_size_concat_node

        # Add Transpose and Reshape nodes for partial rotary embedding applied in phi-2 before passing into MHA
if concat_q_half and concat_k_half:
# Transpose the key output of rotary Embedding
k_transpose_node_name = self.model.create_node_name("Transpose")
k_tranpose_output_name = k_transpose_node_name + "_output_0"
k_transpose_node = helper.make_node(
"Transpose",
inputs=[concat_k_half.output[0]],
outputs=[k_tranpose_output_name],
name=k_transpose_node_name,
)

k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

# Transpose the query output of rotary Embedding
q_transpose_node_name = self.model.create_node_name("Transpose")
q_tranpose_output_name = q_transpose_node_name + "_output_0"
q_transpose_node = helper.make_node(
"Transpose",
inputs=[concat_q_half.output[0]],
outputs=[q_tranpose_output_name],
name=q_transpose_node_name,
)

q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
if hidden_size_concat_node is None:
logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
return

# Reshape the Rotary Embedding output for key for 4D to 3D
concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
concat_k_reshape_node = helper.make_node(
"Reshape",
inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
outputs=[concat_k_reshape_node_name + "_output_0"],
name=concat_k_reshape_node_name,
)

# Reshape the Rotary Embedding output for query from 4D to 3D
concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
concat_q_reshape_node = helper.make_node(
"Reshape",
inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
outputs=[concat_q_reshape_node_name + "_output_0"],
name=concat_q_reshape_node_name,
)

rotary_k = concat_k_reshape_node
rotary_q = concat_q_reshape_node

self.nodes_to_add.append(hidden_size_concat_node)
self.nodes_to_add.append(k_transpose_node)
self.nodes_to_add.append(q_transpose_node)
self.nodes_to_add.append(concat_k_reshape_node)
self.nodes_to_add.append(concat_q_reshape_node)

self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name

new_node = self.create_mha_node(
matmul_q.input[0],
root_output,
Expand All @@ -917,7 +1070,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
self.nodes_to_remove.extend(qkv_nodes[1:])

if v_nodes != v_nodes_4:
self.nodes_to_remove.extend(v_nodes[:-1])
self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
else:
nodes_to_keep = [v_nodes[0][-1]]
for temp_path in v_nodes:
Expand All @@ -936,6 +1089,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
self.nodes_to_remove.append(k_nodes[1])
self.nodes_to_remove.append(k_nodes[3])
self.nodes_to_remove.append(k_nodes[4])
elif k_nodes == k_nodes_5:
self.nodes_to_remove.append(k_nodes[0])
self.nodes_to_remove.append(k_nodes[1])
elif k_nodes == k_nodes_4:
nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
for temp_path in k_nodes:
Expand All @@ -946,7 +1102,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
elif q_nodes == q_nodes_2:
self.nodes_to_remove.append(q_nodes[1])
self.nodes_to_remove.append(q_nodes[2])

self.prune_graph = True


Expand Down Expand Up @@ -1167,30 +1322,66 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
# return x_embed

# Check paths for rotate_half(x)
rotate_half_x2_path_1 = self.model.match_parent_path(
rotate_half_x2_path_1_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Transpose"],
[1, 0, 0, 0, 0],
)
rotate_half_x2_path_2 = self.model.match_parent_path(

rotate_half_x2_path_1_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Slice"],
[1, 0, 0, 0, 0],
)

rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2

rotate_half_x2_path_2_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
[1, 0, 0, 0, 1, 0, 0, 0, 0],
)

rotate_half_x2_path_2_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
[1, 0, 0, 0, 1, 0, 0, 0, 0],
)

rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2

if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
return

rotate_half_x1_path_1 = self.model.match_parent_path(
rotate_half_x1_path_1_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Transpose"],
[1, 0, 1, 0],
)
rotate_half_x1_path_2 = self.model.match_parent_path(

rotate_half_x1_path_1_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Slice"],
[1, 0, 1, 0],
)

rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2

rotate_half_x1_path_2_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
[1, 0, 1, 2, 0, 0, 0, 0],
)

rotate_half_x1_path_2_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
[1, 0, 1, 2, 0, 0, 0, 0],
)

rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2

if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
return
Expand All @@ -1205,11 +1396,20 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
return

# Check path for x
x_path = self.model.match_parent_path(
x_path_1 = self.model.match_parent_path(
node,
["Mul", "Transpose"],
[0, 0],
)

x_path_2 = self.model.match_parent_path(
node,
["Mul", "Slice"],
[0, 0],
)

x_path = x_path_1 or x_path_2

if x_path is None:
logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
return
Expand Down

0 comments on commit bf72f99

Please sign in to comment.