Update TensorrtAPI to TensorRT 10
* delete retrieve_indices_by_name()
* add member SampleUniquePtr<IRuntime> runtime
* replace getBindingDimensions() with getTensorShape()
* replace setBindingDimensions() with setInputShape()
* add link_libraries(stdc++fs) to CMakeLists.txt
* add include_directories("$ENV{TENSORRT_PATH}/samples/") to CMakeLists.txt
QueensGambit committed Sep 18, 2024
1 parent 09b5b5a commit e8f86ab
Showing 5 changed files with 30 additions and 68 deletions.
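The core of this migration is TensorRT 10's removal of the binding-index API (getBindingIndex(), getBindingDimensions(), setBindingDimensions(), enqueueV2()): tensors are now addressed by name. A minimal sketch of the new flow, assuming a built engine with one input, preallocated device buffers, and a CUDA stream; the tensor names and pointer variables below are placeholders, not this repository's identifiers:

    // Sketch of name-based inference in TensorRT 10 (placeholder names and buffers).
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();

    // Fix the dynamic batch dimension by tensor name, not by binding index.
    nvinfer1::Dims inputDims = engine->getTensorShape("input");  // replaces getBindingDimensions()
    inputDims.d[0] = batchSize;
    context->setInputShape("input", inputDims);                  // replaces setBindingDimensions()

    // Register device buffers per tensor name; enqueueV2() took a bindings array instead.
    context->setTensorAddress("input", deviceInput);
    context->setTensorAddress("value_out", deviceValueOutput);

    // enqueueV3() takes only the stream, since all buffers were bound above.
    context->enqueueV3(stream);
    cudaStreamSynchronize(stream);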
engine/CMakeLists.txt (3 changes: 2 additions & 1 deletion)
@@ -402,7 +402,7 @@ include_directories("src/domain/crazyhouse")
include_directories("src/agents")
include_directories("src/agents/config")
include_directories("src/nn")

+link_libraries(stdc++fs)

if (BACKEND_MXNET)
IF(DEFINED ENV{MXNET_PATH})
@@ -487,6 +487,7 @@ if (BACKEND_TENSORRT)
endif()
include_directories("$ENV{TENSORRT_PATH}/include")
include_directories("$ENV{TENSORRT_PATH}/samples/common/")
include_directories("$ENV{TENSORRT_PATH}/samples/")
add_definitions(-DTENSORRT)
endif()

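link_libraries(stdc++fs) is needed because GCC releases before 9 ship std::filesystem support in a separate library; without it, any translation unit that touches the filesystem API fails at link time. A hypothetical example of the kind of call that pulls in stdc++fs (the helper is illustrative, not from this repository):

    // Illustrative only: std::filesystem use that requires -lstdc++fs on GCC < 9.
    #include <filesystem>
    #include <string>

    bool engine_file_exists(const std::string& trtFilePath) {
        return std::filesystem::exists(trtFilePath);
    }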
engine/src/environments/chess_related/chessbatchstream.cpp (2 changes: 1 addition & 1 deletion)
@@ -152,7 +152,7 @@ int ChessBatchStream::getBatchSize() const

nvinfer1::Dims ChessBatchStream::getDims() const
{
-Dims dims;
+nvinfer1::Dims dims;
dims.nbDims = 4;
dims.d[0] = mBatchSize;
dims.d[1] = mDims.d[0];
engine/src/environments/chess_related/chessbatchstream.h (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ class ChessBatchStream : public IBatchStream
int mBatchSize{0};
int mBatchCount{0};
int mMaxBatches{0};
-Dims mDims{};
+nvinfer1::Dims mDims{};
std::vector<float> mData;
std::vector<float> mLabels{};
};
engine/src/nn/tensorrtapi.cpp (79 changes: 23 additions & 56 deletions)
@@ -88,59 +88,15 @@ void TensorrtAPI::load_parameters()
// do nothing
}

-bool TensorrtAPI::retrieve_indices_by_name(bool verbose)
-{
-idxInput = engine->getBindingIndex(nnDesign.inputLayerName.c_str());
-if (idxInput == -1) {
-info_string_important("Layer name '" + nnDesign.inputLayerName + "' not found.");
-return false;
-}
-idxValueOutput = engine->getBindingIndex(nnDesign.valueOutputName.c_str());
-if (idxValueOutput == -1) {
-info_string_important("Layer name '" + nnDesign.valueOutputName + "' not found.");
-return false;
-}
-idxPolicyOutput = engine->getBindingIndex(nnDesign.policySoftmaxOutputName.c_str());
-if (idxPolicyOutput == -1) {
-info_string_important("Layer name '" + nnDesign.policySoftmaxOutputName + "' not found.");
-return false;
-}
-if (nnDesign.hasAuxiliaryOutputs) {
-idxAuxiliaryOutput = engine->getBindingIndex(nnDesign.auxiliaryOutputName.c_str());
-if (idxAuxiliaryOutput == -1) {
-info_string_important("Layer name '" + nnDesign.auxiliaryOutputName + "' not found.");
-return false;
-}
-}
-if (verbose) {
-info_string("Found 'idxInput' at index", idxInput);
-info_string("Found 'idxValueOutput' at index", idxValueOutput);
-info_string("Found 'idxPolicyOutput' at index", idxPolicyOutput);
-if (nnDesign.hasAuxiliaryOutputs) {
-info_string("Found 'idxAuxiliaryOutput' at index", idxAuxiliaryOutput);
-}
-}
-return true;
-}

void TensorrtAPI::init_nn_design()
{
-nnDesign.hasAuxiliaryOutputs = engine->getNbBindings() > 3;
-if (!retrieve_indices_by_name(generatedTrtFromONNX)) {
-info_string_important("Fallback to default indices.");
-idxInput = nnDesign.inputIdx;
-idxValueOutput = nnDesign.valueOutputIdx + nnDesign.nbInputs;
-idxPolicyOutput = nnDesign.policyOutputIdx + nnDesign.nbInputs;
-idxAuxiliaryOutput = nnDesign.auxiliaryOutputIdx + nnDesign.nbInputs;
-}

-set_shape(nnDesign.inputShape, engine->getBindingDimensions(idxInput));
+set_shape(nnDesign.inputShape, engine->getTensorShape(nnDesign.inputLayerName.c_str()));
// make sure that the first dimension is the batch size, otherwise '-1' could cause problems
nnDesign.inputShape.v[0] = batchSize;
-set_shape(nnDesign.valueOutputShape, engine->getBindingDimensions(idxValueOutput));
-set_shape(nnDesign.policyOutputShape, engine->getBindingDimensions(idxPolicyOutput));
+set_shape(nnDesign.valueOutputShape, engine->getTensorShape(nnDesign.valueOutputName.c_str()));
+set_shape(nnDesign.policyOutputShape, engine->getTensorShape(nnDesign.policySoftmaxOutputName.c_str()));
if (nnDesign.hasAuxiliaryOutputs) {
-set_shape(nnDesign.auxiliaryOutputShape, engine->getBindingDimensions(idxAuxiliaryOutput));
+set_shape(nnDesign.auxiliaryOutputShape, engine->getTensorShape(nnDesign.auxiliaryOutputName.c_str()));
}
nnDesign.isPolicyMap = unsigned(nnDesign.policyOutputShape.v[1]) != StateConstants::NB_LABELS();
}
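Since getBindingIndex() no longer exists, init_nn_design() now queries shapes directly by layer name. If an index-free enumeration is ever needed, TensorRT 10 exposes the engine's I/O tensors by name; a sketch (not part of this commit):

    // Sketch: iterate the engine's I/O tensors by name in TensorRT 10.
    for (int32_t i = 0; i < engine->getNbIOTensors(); ++i) {
        const char* name = engine->getIOTensorName(i);
        nvinfer1::Dims dims = engine->getTensorShape(name);  // -1 marks dynamic dimensions
        bool isInput = engine->getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT;
        // match 'name' against the nnDesign layer names here if a lookup table is wanted
    }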
@@ -151,7 +107,7 @@ void TensorrtAPI::bind_executor()
context = SampleUniquePtr<nvinfer1::IExecutionContext>(engine->createExecutionContext());
Dims inputDims;
set_dims(inputDims, nnDesign.inputShape);
-context->setBindingDimensions(0, inputDims);
+context->setInputShape(nnDesign.inputLayerName.c_str(), inputDims);

// create buffers object with respect to the engine and batch size
CHECK(cudaStreamCreate(&stream));
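Note that setInputShape() returns false when the requested shape lies outside the optimization-profile bounds baked into the engine, so the call above could be checked; a defensive variant (a sketch, not the commit's code):

    // Sketch: validate the shape instead of assuming it fits the optimization profile.
    Dims inputDims;
    set_dims(inputDims, nnDesign.inputShape);
    if (!context->setInputShape(nnDesign.inputLayerName.c_str(), inputDims)) {
        info_string_important("setInputShape() rejected the requested input shape.");
    }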
@@ -184,8 +140,19 @@ void TensorrtAPI::predict(float* inputPlanes, float* valueOutput, float* probOutputs, float* auxiliaryOutputs)
CHECK(cudaMemcpyAsync(deviceMemory[idxInput], inputPlanes, memorySizes[idxInput],
cudaMemcpyHostToDevice, stream));

+context->setTensorAddress(nnDesign.inputLayerName.c_str(), deviceMemory[idxInput]);
+context->setTensorAddress(nnDesign.valueOutputName.c_str(), deviceMemory[idxValueOutput]);
+context->setTensorAddress(nnDesign.policySoftmaxOutputName.c_str(), deviceMemory[idxPolicyOutput]);
+#ifdef DYNAMIC_NN_ARCH
+if (has_auxiliary_outputs()) {
+#else
+if (StateConstants::NB_AUXILIARY_OUTPUTS()) {
+#endif
+context->setTensorAddress(nnDesign.auxiliaryOutputName.c_str(), deviceMemory[idxAuxiliaryOutput]);
+}

// run inference for given data
-context->enqueueV2(deviceMemory, stream, nullptr);
+context->enqueueV3(stream);

// copy output from device back to host
CHECK(cudaMemcpyAsync(valueOutput, deviceMemory[idxValueOutput],
@@ -209,7 +176,6 @@ ICudaEngine* TensorrtAPI::create_cuda_engine_from_onnx()
info_string("This may take a few minutes...");
// create an engine builder
SampleUniquePtr<IBuilder> builder = SampleUniquePtr<IBuilder>(createInferBuilder(gLogger.getTRTLogger()));
-builder->setMaxBatchSize(int(batchSize));

// create an ONNX network object
const uint32_t explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
@@ -232,7 +198,7 @@ ICudaEngine* TensorrtAPI::create_cuda_engine_from_onnx()
SampleUniquePtr<nvinfer1::IBuilderConfig> config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
unique_ptr<IInt8Calibrator> calibrator;
unique_ptr<IBatchStream> calibrationStream;
-set_config_settings(config, 1_GiB, calibrator, calibrationStream);
+set_config_settings(config, calibrator, calibrationStream);

IOptimizationProfile* profile = builder->createOptimizationProfile();

@@ -243,12 +209,14 @@ ICudaEngine* TensorrtAPI::create_cuda_engine_from_onnx()
profile->setDimensions(nnDesign.inputLayerName.c_str(), OptProfileSelector::kMAX, inputDims);
config->addOptimizationProfile(profile);

+nnDesign.hasAuxiliaryOutputs = network->getNbOutputs() > 2;

// build an engine from the TensorRT network with a given configuration struct
#ifdef TENSORRT7
return builder->buildEngineWithConfig(*network, *config);
#else
SampleUniquePtr<IHostMemory> serializedModel{builder->buildSerializedNetwork(*network, *config)};
-SampleUniquePtr<IRuntime> runtime{createInferRuntime(sample::gLogger.getTRTLogger())};
+runtime = SampleUniquePtr<IRuntime>(createInferRuntime(sample::gLogger.getTRTLogger()));

// build an engine from the serialized model
return runtime->deserializeCudaEngine(serializedModel->data(), serializedModel->size());;
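Assigning to the new runtime member instead of a function-local smart pointer matters here: TensorRT 10 requires the IRuntime to stay alive for as long as any engine it deserialized, and the previous local unique_ptr was destroyed when this function returned. The intended ownership, roughly:

    // Sketch: the runtime member must outlive the engine it deserializes.
    runtime = SampleUniquePtr<IRuntime>(createInferRuntime(sample::gLogger.getTRTLogger()));
    engine = std::shared_ptr<nvinfer1::ICudaEngine>(
        runtime->deserializeCudaEngine(serializedModel->data(), serializedModel->size()),
        samplesCommon::InferDeleter());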
@@ -263,7 +231,7 @@ ICudaEngine* TensorrtAPI::get_cuda_engine() {
const char* buffer = read_buffer(trtFilePath, bufferSize);
if (buffer) {
info_string("deserialize engine:", trtFilePath);
-unique_ptr<IRuntime, samplesCommon::InferDeleter> runtime{createInferRuntime(gLogger)};
+runtime = unique_ptr<IRuntime, samplesCommon::InferDeleter>{createInferRuntime(gLogger)};
#ifdef TENSORRT7
engine = runtime->deserializeCudaEngine(buffer, bufferSize, nullptr);
#else
@@ -293,10 +261,9 @@ ICudaEngine* TensorrtAPI::get_cuda_engine() {
}

void TensorrtAPI::set_config_settings(SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
-size_t maxWorkspace, unique_ptr<IInt8Calibrator>& calibrator,
+unique_ptr<IInt8Calibrator>& calibrator,
unique_ptr<IBatchStream>& calibrationStream)
{
-config->setMaxWorkspaceSize(maxWorkspace);
switch (precision) {
case float32:
// default: do nothing
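The removed maxWorkspace parameter and setMaxWorkspaceSize() call reflect another TensorRT 10 deletion; by default the builder may now use as much device memory as tactic selection needs. If an explicit cap is still wanted, the TensorRT 10 equivalent would be (a sketch, not part of this commit):

    // Sketch: TensorRT 10 replacement for the removed setMaxWorkspaceSize(1_GiB).
    config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1ULL << 30);  // 1 GiB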
engine/src/nn/tensorrtapi.h (12 changes: 3 additions & 9 deletions)
@@ -44,6 +44,7 @@
#include "BatchStream.h"

using namespace std;
+using namespace nvinfer1;

enum Precision {
float32,
@@ -77,6 +78,7 @@ class TensorrtAPI : public NeuralNetAPI
string trtFilePath;
std::shared_ptr<nvinfer1::ICudaEngine> engine;
SampleUniquePtr<nvinfer1::IExecutionContext> context;
+SampleUniquePtr<IRuntime> runtime;
cudaStream_t stream;
bool generatedTrtFromONNX;
public:
@@ -93,13 +95,6 @@ class TensorrtAPI : public NeuralNetAPI

void predict(float* inputPlanes, float* valueOutput, float* probOutputs, float* auxiliaryOutputs) override;

-/**
- * @brief retrieve_indices_by_name Sets the layer name indices by names.
- * @param verbose If true debug info will be shown
- * @return True if all layer names were found, else false
- */
-bool retrieve_indices_by_name(bool verbose);

private:
void load_model() override;
void load_parameters() override;
@@ -123,12 +118,11 @@ class TensorrtAPI : public NeuralNetAPI
/**
* @brief set_config_settings Sets the configuration object which will be later used to build the engine
* @param config Configuration object
- * @param maxWorkspace Maximum allowable GPU work space for TensorRT tactic selection (e.g. 16_MiB, 1_GiB)
* @param calibrator INT8 calibration object
* @param calibrationStream Calibration stream used for INT8 calibration
*/
void set_config_settings(SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
-size_t maxWorkspace, unique_ptr<IInt8Calibrator>& calibrator,
+unique_ptr<IInt8Calibrator>& calibrator,
unique_ptr<IBatchStream>& calibrationStream);

