@@ -82,6 +82,57 @@ Status BuildNodeMap(const Graph& graph,
82
82
return Status::OK ();
83
83
}
84
84
85
+ Status FindExtraConcatInput (const Graph& graph,
86
+ const std::vector<std::string>& input_output_names,
87
+ std::vector<const Node*>* filter_concat_node) {
88
+ std::unordered_set<const Node*> candidate_node;
89
+ std::unordered_set<Node*> concat_nodes;
90
+ for (auto * node : graph.nodes ()) {
91
+ if (node->type_string () == " ConcatV2" ) {
92
+ concat_nodes.insert (node);
93
+ }
94
+ }
95
+ std::unordered_set<std::string> in_out_names;
96
+ for (auto & name : input_output_names) {
97
+ in_out_names.insert (name);
98
+ }
99
+ for (const Node* c_nodes : concat_nodes) {
100
+ std::vector<const Node*> in_placeholder;
101
+ ReverseDFSFrom (
102
+ graph, {c_nodes},
103
+ [&in_placeholder, in_out_names](const Node* node) {
104
+ if (in_out_names.find (node->name ()) != in_out_names.end ()) {
105
+ in_placeholder.emplace_back (node);
106
+ }
107
+ },
108
+ /* end*/ nullptr );
109
+ if (in_placeholder.size () > 1 ) { // verify node in common sub-graph
110
+ DataType t_types;
111
+ TF_RETURN_IF_ERROR (GetNodeAttr (c_nodes->attrs (), " T" , &t_types));
112
+ if (t_types == DT_FLOAT) {
113
+ candidate_node.insert (c_nodes);
114
+ }
115
+ }
116
+ }
117
+
118
+ for (const Node* cnode : candidate_node) {
119
+ bool is_admit = true ;
120
+ ReverseDFSFrom (graph, {cnode},
121
+ [&filter_concat_node, &is_admit, candidate_node,
122
+ cnode](const Node* node) {
123
+ if ((candidate_node.find (node) != candidate_node.end ()) &&
124
+ (cnode->name () != node->name ())) {
125
+ is_admit = false ;
126
+ }
127
+ },
128
+ /* end*/ nullptr );
129
+ if (is_admit) {
130
+ filter_concat_node->emplace_back (cnode);
131
+ }
132
+ }
133
+ return Status::OK ();
134
+ }
135
+
85
136
EngineInfo::EngineType GetEngineType (
86
137
const TRTOptimizationPass::ConversionParams& params) {
87
138
return (params.is_dynamic_op || params.use_calibration )
@@ -773,6 +824,14 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
773
824
for (const auto & node : input_output_names) {
774
825
segment_options.exclude_node_list .insert (node);
775
826
}
827
+ std::vector<const Node*> filter_concat_node;
828
+ TF_RETURN_IF_ERROR (
829
+ FindExtraConcatInput (graph, input_output_names, &filter_concat_node));
830
+ for (const auto * node : filter_concat_node) {
831
+ for (auto * inode : node->in_nodes ()) {
832
+ segment_options.exclude_node_list .insert (inode->name ());
833
+ }
834
+ }
776
835
777
836
segment_options.minimum_segment_size = params.minimum_segment_size ;
778
837
segment_options.use_implicit_batch = params.use_implicit_batch ;
0 commit comments