Skip to content

Commit 19a895e

Browse files
committed
[TensorRT] Add inputs of the very first ConcatOp to exclude list.
Signed-off-by: 泊霆 <[email protected]>
1 parent 023af9c commit 19a895e

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc

+59
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,57 @@ Status BuildNodeMap(const Graph& graph,
8282
return Status::OK();
8383
}
8484

85+
Status FindExtraConcatInput(const Graph& graph,
86+
const std::vector<std::string>& input_output_names,
87+
std::vector<const Node*>* filter_concat_node) {
88+
std::unordered_set<const Node*> candidate_node;
89+
std::unordered_set<Node*> concat_nodes;
90+
for (auto* node : graph.nodes()) {
91+
if (node->type_string() == "ConcatV2") {
92+
concat_nodes.insert(node);
93+
}
94+
}
95+
std::unordered_set<std::string> in_out_names;
96+
for (auto& name : input_output_names) {
97+
in_out_names.insert(name);
98+
}
99+
for (const Node* c_nodes : concat_nodes) {
100+
std::vector<const Node*> in_placeholder;
101+
ReverseDFSFrom(
102+
graph, {c_nodes},
103+
[&in_placeholder, in_out_names](const Node* node) {
104+
if (in_out_names.find(node->name()) != in_out_names.end()) {
105+
in_placeholder.emplace_back(node);
106+
}
107+
},
108+
/*end*/ nullptr);
109+
if (in_placeholder.size() > 1) { // verify node in common sub-graph
110+
DataType t_types;
111+
TF_RETURN_IF_ERROR(GetNodeAttr(c_nodes->attrs(), "T", &t_types));
112+
if (t_types == DT_FLOAT) {
113+
candidate_node.insert(c_nodes);
114+
}
115+
}
116+
}
117+
118+
for (const Node* cnode : candidate_node) {
119+
bool is_admit = true;
120+
ReverseDFSFrom(graph, {cnode},
121+
[&filter_concat_node, &is_admit, candidate_node,
122+
cnode](const Node* node) {
123+
if ((candidate_node.find(node) != candidate_node.end()) &&
124+
(cnode->name() != node->name())) {
125+
is_admit = false;
126+
}
127+
},
128+
/*end*/ nullptr);
129+
if (is_admit) {
130+
filter_concat_node->emplace_back(cnode);
131+
}
132+
}
133+
return Status::OK();
134+
}
135+
85136
EngineInfo::EngineType GetEngineType(
86137
const TRTOptimizationPass::ConversionParams& params) {
87138
return (params.is_dynamic_op || params.use_calibration)
@@ -773,6 +824,14 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
773824
for (const auto& node : input_output_names) {
774825
segment_options.exclude_node_list.insert(node);
775826
}
827+
std::vector<const Node*> filter_concat_node;
828+
TF_RETURN_IF_ERROR(
829+
FindExtraConcatInput(graph, input_output_names, &filter_concat_node));
830+
for (const auto* node : filter_concat_node) {
831+
for (auto* inode : node->in_nodes()) {
832+
segment_options.exclude_node_list.insert(inode->name());
833+
}
834+
}
776835

777836
segment_options.minimum_segment_size = params.minimum_segment_size;
778837
segment_options.use_implicit_batch = params.use_implicit_batch;

0 commit comments

Comments
 (0)