
Commit 0eaa93e

[TensorRT] Upgrade TF-TRT version to TF2's implementation.
Signed-off-by: 泊霆 <[email protected]>
1 parent 5eabe5f commit 0eaa93e

File tree

100 files changed: +26474, -8636 lines


tensorflow/compiler/tf2tensorrt/BUILD
+559, -103 (large diff not rendered by default)
@@ -0,0 +1,19 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

def get_linked_tensorrt_version() -> tuple[int, int, int]: ...
def get_loaded_tensorrt_version() -> tuple[int, int, int]: ...
def get_registered_op_converters() -> list[str]: ...
def is_tensorrt_enabled() -> bool: ...
@@ -0,0 +1,38 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_DATAVEC_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_DATAVEC_H_

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensorrt {

// Input/output data format for OpConverterTest::BuildAndRun().
struct InputOutputData {
  size_t TotalBytes() const { return tensor.TotalBytes(); }

  string name;
  Tensor tensor;
};

using DataVec = std::vector<InputOutputData>;

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_DATAVEC_H_
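For context, a minimal sketch of how a test might populate a DataVec before calling OpConverterTest::BuildAndRun(). The helper name MakeExampleInputs, the tensor name "input_0", and the shape and values are illustrative only, not taken from this commit:

#include "tensorflow/compiler/tf2tensorrt/common/datavec.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace tensorrt {

// Hypothetical helper: builds one float32 input named "input_0" of shape
// [1, 3] to feed BuildAndRun().
DataVec MakeExampleInputs() {
  Tensor t(DT_FLOAT, TensorShape({1, 3}));
  auto flat = t.flat<float>();
  flat(0) = 1.0f;
  flat(1) = 2.0f;
  flat(2) = 3.0f;
  // InputOutputData is an aggregate; members initialize as {name, tensor}.
  return DataVec{{"input_0", t}};
}

}  // namespace tensorrt
}  // namespace tensorflow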
@@ -0,0 +1,242 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"

#include <tuple>

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "third_party/tensorrt/NvInferPlugin.h"

#endif

namespace tensorflow {
namespace tensorrt {

std::tuple<int, int, int> GetLinkedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
                                   NV_TENSORRT_PATCH};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

std::tuple<int, int, int> GetLoadedTensorRTVersion() {
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  int ver = getInferLibVersion();
  int major = ver / 1000;
  ver = ver - major * 1000;
  int minor = ver / 100;
  int patch = ver - minor * 100;
  return std::tuple<int, int, int>{major, minor, patch};
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}
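// Worked example of the decoding above (comment added for clarity, not in
// the original commit): for TensorRT 8.6.1, getInferLibVersion() returns
// 1000 * 8 + 100 * 6 + 1 = 8601, which decodes back as
//   major = 8601 / 1000         = 8
//   minor = (8601 - 8000) / 100 = 6
//   patch = 8601 - 8000 - 600   = 1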

}  // namespace tensorrt
}  // namespace tensorflow

#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {

Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  tensorflow::profiler::TraceMe activity(
      "GetTrtBindingIndex", tensorflow::profiler::TraceMeLevel::kInfo);
  // If the engine has been built for K profiles, the first getNbBindings() / K
  // bindings are used by profile number 0, the following getNbBindings() / K
  // bindings are used by profile number 1 etc.
  //
  // GetBindingIndex(tensor_name) returns the binding index for profile 0.
  // We can also consider it as a "binding_index_within_profile".
  *binding_index = cuda_engine->getBindingIndex(tensor_name);
  if (*binding_index == -1) {
    const string msg = absl::StrCat("Input node ", tensor_name, " not found");
    return errors::NotFound(msg);
  }
  int n_profiles = cuda_engine->getNbOptimizationProfiles();
  // If we have more than one optimization profile, then we need to shift the
  // binding index according to the following formula:
  //   binding_index_within_engine = binding_index_within_profile +
  //                                 profile_index * bindings_per_profile
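  //
  // Illustrative example (comment added for clarity, not in the original
  // commit): an engine built with 3 optimization profiles and
  // getNbBindings() == 6 has bindings_per_profile == 2; for profile_index 1,
  // a tensor with binding_index_within_profile 0 maps to engine-wide
  // binding index 0 + 1 * 2 == 2.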
  const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles;
  *binding_index = *binding_index + profile_index * bindings_per_profile;
  return Status::OK();
}

Status GetTrtBindingIndex(int network_input_index, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  const string input_name =
      absl::StrCat(IONamePrefixes::kInputPHName, network_input_index);
  return GetTrtBindingIndex(input_name.c_str(), profile_index, cuda_engine,
                            binding_index);
}

namespace {

void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
#if defined(PLATFORM_WINDOWS)
  LOG_WARNING_WITH_PREFIX
      << "Windows support is provided experimentally. No guarantee is made "
         "regarding functionality or engineering support. Use at your own "
         "risk.";
#endif
  LOG(INFO) << "Linked TensorRT version: "
            << absl::StrJoin(GetLinkedTensorRTVersion(), ".");
  LOG(INFO) << "Loaded TensorRT version: "
            << absl::StrJoin(GetLoadedTensorRTVersion(), ".");

  bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
  if (!plugin_initialized) {
    LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
                  "fail later.";
  }

  int num_trt_plugins = 0;
  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
  if (!trt_plugin_creator_list) {
    LOG_WARNING_WITH_PREFIX << "Cannot find any TensorRT plugins in registry.";
  } else {
    VLOG(1) << "Found the following " << num_trt_plugins
            << " TensorRT plugins in registry:";
    for (int i = 0; i < num_trt_plugins; ++i) {
      if (!trt_plugin_creator_list[i]) {
        LOG_WARNING_WITH_PREFIX
            << "TensorRT plugin at index " << i
            << " is not accessible (null pointer returned by "
               "getPluginCreatorList for this plugin)";
      } else {
        VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName();
      }
    }
  }
}

}  // namespace

void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
  static absl::once_flag once;
  absl::call_once(once, InitializeTrtPlugins, trt_logger);
}

}  // namespace tensorrt
}  // namespace tensorflow

namespace nvinfer1 {
std::ostream& operator<<(std::ostream& os,
                         const nvinfer1::TensorFormat& format) {
  os << "nvinfer1::TensorFormat::";
  switch (format) {
    case nvinfer1::TensorFormat::kLINEAR:
      os << "kLINEAR";
      break;

    case nvinfer1::TensorFormat::kCHW2:
      os << "kCHW2";
      break;

    case nvinfer1::TensorFormat::kHWC8:
      os << "kHWC8";
      break;

    case nvinfer1::TensorFormat::kCHW4:
      os << "kCHW4";
      break;

    case nvinfer1::TensorFormat::kCHW16:
      os << "kCHW16";
      break;

    case nvinfer1::TensorFormat::kCHW32:
      os << "kCHW32";
      break;

#if IS_TRT_VERSION_GE(8, 0, 0, 0)
    case nvinfer1::TensorFormat::kDHWC8:
      os << "kDHWC8";
      break;

    case nvinfer1::TensorFormat::kCDHW32:
      os << "kCDHW32";
      break;

    case nvinfer1::TensorFormat::kHWC:
      os << "kHWC";
      break;

    case nvinfer1::TensorFormat::kDLA_LINEAR:
      os << "kDLA_LINEAR";
      break;

    case nvinfer1::TensorFormat::kDLA_HWC4:
      os << "kDLA_HWC4";
      break;

    case nvinfer1::TensorFormat::kHWC16:
      os << "kHWC16";
      break;
#endif

    default:
      os << "unknown format";
  }
  return os;
}

std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& v) {
  os << "nvinfer1::DataType::";
  switch (v) {
    case nvinfer1::DataType::kFLOAT:
      os << "kFLOAT";
      break;
    case nvinfer1::DataType::kHALF:
      os << "kHALF";
      break;
#if IS_TRT_VERSION_GE(8, 6, 0, 0)
    case nvinfer1::DataType::kFP8:
      os << "kFP8";
      break;
#endif
    case nvinfer1::DataType::kINT8:
      os << "kINT8";
      break;
    case nvinfer1::DataType::kINT32:
      os << "kINT32";
      break;
    case nvinfer1::DataType::kBOOL:
      os << "kBOOL";
      break;
#if IS_TRT_VERSION_GE(8, 5, 0, 0)
    case nvinfer1::DataType::kUINT8:
      os << "kUINT8";
      break;
#endif
  }
  return os;
}
}  // namespace nvinfer1

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
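For orientation, a minimal sketch (not part of the commit) of how these helpers are typically used together: log the linked versus loaded TensorRT versions, then register plugins exactly once before building or deserializing engines. The StderrLogger class and the main() driver below are hypothetical stand-ins; any nvinfer1::ILogger implementation works.

#include <iostream>
#include <tuple>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "third_party/tensorrt/NvInfer.h"

namespace {
// Hypothetical minimal logger that forwards every TensorRT message to stderr.
class StderrLogger : public nvinfer1::ILogger {
 public:
  void log(Severity severity, const char* msg) noexcept override {
    std::cerr << msg << "\n";
  }
};
}  // namespace

int main() {
  // Compiled-against vs. dynamically loaded TensorRT versions; a major-version
  // mismatch usually indicates an environment problem.
  const auto linked = tensorflow::tensorrt::GetLinkedTensorRTVersion();
  const auto loaded = tensorflow::tensorrt::GetLoadedTensorRTVersion();
  std::cout << "linked: " << std::get<0>(linked) << "." << std::get<1>(linked)
            << "." << std::get<2>(linked) << "\n";
  std::cout << "loaded: " << std::get<0>(loaded) << "." << std::get<1>(loaded)
            << "." << std::get<2>(loaded) << "\n";

  // Safe to call from multiple threads: absl::call_once inside guarantees the
  // underlying plugin registration runs only once per process.
  StderrLogger logger;
  tensorflow::tensorrt::MaybeInitializeTrtPlugins(&logger);
  return 0;
}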
