diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 618d3c2fab12c..7384cace21a67 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -362,8 +362,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face # v_nodes_4 is for LLaMA-2 70B model + # v_nodes_5 is for Phi-2 DirectML past_v, present_v, past_seq_len = "", "", "" v_nodes = None + add_v = None v_nodes_1 = self.model.match_parent_path( matmul_qkv, ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"], @@ -491,6 +493,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ], output_name_to_node=None, ) + v_nodes_5 = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, 1], + ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 @@ -521,6 +528,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_nodes = v_nodes_4 past_v = concat_v.input[0] present_v = concat_v.output[0] + elif v_nodes_5 is not None: + concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5 + matmul_v = add_v + v_nodes = v_nodes_5 + past_v = concat_v.input[0] + present_v = concat_v.output[0] else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -607,6 +620,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None + slice_k = None + concat_k_half = None k_nodes_1 = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"], @@ -790,6 +805,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ], output_name_to_node=None, ) + k_nodes_5 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, 0, 0, 0, 1], + ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 @@ -823,13 +843,21 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_4 past_k = concat_k.input[0] present_k = concat_k.output[0] + elif k_nodes_5 is not None: + _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5 + k_nodes = k_nodes_5 + past_k = concat_k.input[0] + present_k = concat_k.output[0] else: logger.debug("fuse_rotary_attention: failed to match k nodes") return # q_nodes_1 is for LLaMA-2 Microsoft # q_nodes_2 is for LLaMA-2 Hugging Face + # q_nodes_3 is for Phi-2 DirectML q_nodes = None + slice_q = None + concat_q_half = None q_nodes_1 = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"], @@ -840,12 +868,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0], ) + q_nodes_3 = self.model.match_parent_path( + matmul_qk, + ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 0, 0, 1], + ) if q_nodes_1 is not None: reshape_q_2, _, rotary_q, matmul_q = q_nodes_1 q_nodes = q_nodes_1 elif q_nodes_2 is not None: rotary_q, _, reshape_q, matmul_q = q_nodes_2 q_nodes = q_nodes_2 + elif q_nodes_3 is not None: + concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3 + q_nodes = q_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match q nodes") return @@ -885,15 +921,132 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Rename inputs of rotary_q/k so it connects with output of matmul_q/k # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding # After: MatMul --> RotaryEmbedding - rotary_q.input[0] = matmul_q.output[0] - rotary_k.input[0] = matmul_k.output[0] + rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0] + rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0] # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) - rotary_k.output[0] = rotary_k.name + "_output_0" + if concat_q_half is None: + rotary_k.output[0] = rotary_k.name + "_output_0" if qkv_nodes == qkv_nodes_3: qkv_nodes = qkv_nodes[1:] + def create_hidden_size_concat_node(reshape_q): + """Detect num_heads and hidden_size for ONNX model from phi-2 + Args: + reshape_q (NodeProto): reshape node for q + Returns: + hidden_size_concat_node(NodeProto): Concat node to be used by reshape + """ + concat = self.model.match_parent(reshape_q, "Concat", 1) + + if concat is None: + logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q") + return None + + # The shape is a tensor like [?, ?, num_heads, head_size] + num_head_constant_node = self.model.get_constant_value(concat.input[2]) + head_size_constant_node = self.model.get_constant_value(concat.input[3]) + + if num_head_constant_node is None or head_size_constant_node is None: + logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size") + return None + + num_head_value = num_head_constant_node[0] + head_size_value = head_size_constant_node[0] + + hidden_size = num_head_value * head_size_value + + hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size") + if self.model.get_initializer(hidden_size_initilizer) is None: + self.add_initializer( + name=hidden_size_initilizer, + data_type=TensorProto.INT64, + dims=[1], + vals=[hidden_size], + raw=False, + ) + + hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat") + + hidden_size_concat_node = helper.make_node( + "Concat", + inputs=[ + concat.input[0], + concat.input[1], + hidden_size_initilizer, + ], + outputs=[hidden_size_reshape_node_name + "output_0"], + name=hidden_size_reshape_node_name, + ) + hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)]) + + return hidden_size_concat_node + + # Add Tranpose and Reshape nodes for patial rotary embedding applied in phi-2 before passing into MHA + if concat_q_half and concat_k_half: + # Transpose the key output of rotary Embedding + k_transpose_node_name = self.model.create_node_name("Transpose") + k_tranpose_output_name = k_transpose_node_name + "_output_0" + k_transpose_node = helper.make_node( + "Transpose", + inputs=[concat_k_half.output[0]], + outputs=[k_tranpose_output_name], + name=k_transpose_node_name, + ) + + k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])]) + + # Transpose the query output of rotary Embedding + q_transpose_node_name = self.model.create_node_name("Transpose") + q_tranpose_output_name = q_transpose_node_name + "_output_0" + q_transpose_node = helper.make_node( + "Transpose", + inputs=[concat_q_half.output[0]], + outputs=[q_tranpose_output_name], + name=q_transpose_node_name, + ) + + q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])]) + + hidden_size_concat_node = create_hidden_size_concat_node(reshape_k) + if hidden_size_concat_node is None: + logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node") + return + + # Reshape the Rotary Embedding output for key for 4D to 3D + concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half") + concat_k_reshape_node = helper.make_node( + "Reshape", + inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]], + outputs=[concat_k_reshape_node_name + "_output_0"], + name=concat_k_reshape_node_name, + ) + + # Reshape the Rotary Embedding output for query from 4D to 3D + concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half") + concat_q_reshape_node = helper.make_node( + "Reshape", + inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]], + outputs=[concat_q_reshape_node_name + "_output_0"], + name=concat_q_reshape_node_name, + ) + + rotary_k = concat_k_reshape_node + rotary_q = concat_q_reshape_node + + self.nodes_to_add.append(hidden_size_concat_node) + self.nodes_to_add.append(k_transpose_node) + self.nodes_to_add.append(q_transpose_node) + self.nodes_to_add.append(concat_k_reshape_node) + self.nodes_to_add.append(concat_q_reshape_node) + + self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name + self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name + self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name + self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name + self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name + new_node = self.create_mha_node( matmul_q.input[0], root_output, @@ -917,7 +1070,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend(qkv_nodes[1:]) if v_nodes != v_nodes_4: - self.nodes_to_remove.extend(v_nodes[:-1]) + self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2]) else: nodes_to_keep = [v_nodes[0][-1]] for temp_path in v_nodes: @@ -936,6 +1089,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_5: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[1]) elif k_nodes == k_nodes_4: nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] for temp_path in k_nodes: @@ -946,7 +1102,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): elif q_nodes == q_nodes_2: self.nodes_to_remove.append(q_nodes[1]) self.nodes_to_remove.append(q_nodes[2]) - self.prune_graph = True @@ -1167,30 +1322,66 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # return x_embed # Check paths for rotate_half(x) - rotate_half_x2_path_1 = self.model.match_parent_path( + rotate_half_x2_path_1_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Transpose"], [1, 0, 0, 0, 0], ) - rotate_half_x2_path_2 = self.model.match_parent_path( + + rotate_half_x2_path_1_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Slice"], + [1, 0, 0, 0, 0], + ) + + rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2 + + rotate_half_x2_path_2_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], [1, 0, 0, 0, 1, 0, 0, 0, 0], ) + + rotate_half_x2_path_2_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"], + [1, 0, 0, 0, 1, 0, 0, 0, 0], + ) + + rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2 + if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None: logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half") return - rotate_half_x1_path_1 = self.model.match_parent_path( + rotate_half_x1_path_1_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Transpose"], [1, 0, 1, 0], ) - rotate_half_x1_path_2 = self.model.match_parent_path( + + rotate_half_x1_path_1_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Slice"], + [1, 0, 1, 0], + ) + + rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2 + + rotate_half_x1_path_2_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], [1, 0, 1, 2, 0, 0, 0, 0], ) + + rotate_half_x1_path_2_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"], + [1, 0, 1, 2, 0, 0, 0, 0], + ) + + rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2 + if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None: logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half") return @@ -1205,11 +1396,20 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return # Check path for x - x_path = self.model.match_parent_path( + x_path_1 = self.model.match_parent_path( node, ["Mul", "Transpose"], [0, 0], ) + + x_path_2 = self.model.match_parent_path( + node, + ["Mul", "Slice"], + [0, 0], + ) + + x_path = x_path_1 or x_path_2 + if x_path is None: logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half") return