Skip to content

Commit

Permalink
vstrt: implement flexible output
Browse files Browse the repository at this point in the history
The flexible output mode (enabled with "flexible_output_prop") can return an arbitrary number of output planes.
In the initialization stage of the filter instance, the output is augmented with the number of output planes.
The actual output frames are stored in the property map of the returned frame.

Therefore,

  prop = "planes"
  output = core.trt.Model(src, engine_path, flexible_output_prop=prop)
  planes = [output["clip"].std.PropToClip(prop=f"{prop}{i}") for i in range(output["num_planes"])]

will return a list of output planes.
  • Loading branch information
WolframRhodium committed May 13, 2024
1 parent 6af7646 commit a167eec
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
12 changes: 7 additions & 5 deletions vstrt/trt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ std::variant<ErrorMessage, InferenceInstance> getInstance(

static inline
std::optional<ErrorMessage> checkEngine(
const std::unique_ptr<nvinfer1::ICudaEngine> & engine
const std::unique_ptr<nvinfer1::ICudaEngine> & engine,
bool flexible_output
) noexcept {

#if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
Expand Down Expand Up @@ -506,8 +507,8 @@ std::optional<ErrorMessage> checkEngine(
}

int out_channels = output_dims.d[1];
if (out_channels != 1 && out_channels != 3) {
return "output dimensions must be 1 or 3";
if (out_channels != 1 && out_channels != 3 && !flexible_output) {
return "output dimensions must be 1 or 3, or enable \"flexible_output\"";
}

int in_height = input_dims.d[2];
Expand Down Expand Up @@ -547,7 +548,8 @@ std::optional<ErrorMessage> checkEngine(
static inline
std::variant<ErrorMessage, std::unique_ptr<nvinfer1::ICudaEngine>> initEngine(
const char * engine_data, size_t engine_nbytes,
const std::unique_ptr<nvinfer1::IRuntime> & runtime
const std::unique_ptr<nvinfer1::IRuntime> & runtime,
bool flexible_output
) noexcept {

const auto set_error = [](const ErrorMessage & error_message) {
Expand All @@ -562,7 +564,7 @@ std::variant<ErrorMessage, std::unique_ptr<nvinfer1::ICudaEngine>> initEngine(
return set_error("engine deserialization failed");
}

if (auto err = checkEngine(engine); err.has_value()) {
if (auto err = checkEngine(engine, flexible_output); err.has_value()) {
return set_error(err.value());
}

Expand Down
5 changes: 3 additions & 2 deletions vstrt/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ void setDimensions(
VSCore * core,
const VSAPI * vsapi,
int sample_type,
int bits_per_sample
int bits_per_sample,
bool flexible_output
) noexcept {

#if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
Expand All @@ -42,7 +43,7 @@ void setDimensions(
vi->height *= out_height / in_height;
vi->width *= out_width / in_width;

if (out_dims.d[1] == 1) {
if (out_dims.d[1] == 1 || flexible_output) {
vi->format = vsapi->registerFormat(cmGray, sample_type, bits_per_sample, 0, 0, core);
} else if (out_dims.d[1] == 3) {
vi->format = vsapi->registerFormat(cmRGB, sample_type, bits_per_sample, 0, 0, core);
Expand Down
59 changes: 53 additions & 6 deletions vstrt/vs_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ struct vsTrtData {
std::mutex instances_lock;
std::vector<InferenceInstance> instances;

std::string flexible_output_prop;

[[nodiscard]]
int acquire() noexcept {
semaphore.acquire();
Expand Down Expand Up @@ -168,6 +170,8 @@ static const VSFrameRef *VS_CC vsTrtGetFrame(
src_frames[0], core
)};

std::vector<VSFrameRef *> dst_frames;

#if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
auto output_name = d->engines[0]->getIOTensorName(1);
const nvinfer1::Dims dst_dim { instance.exec_context->getTensorShape(output_name) };
Expand All @@ -181,8 +185,19 @@ static const VSFrameRef *VS_CC vsTrtGetFrame(

std::vector<uint8_t *> dst_ptrs;
dst_ptrs.reserve(dst_planes);
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i));
if (d->flexible_output_prop.empty()) {
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i));
}
} else {
for (int i = 0; i < dst_planes; ++i) {
auto frame { vsapi->newVideoFrame(
d->out_vi->format, d->out_vi->width, d->out_vi->height,
src_frames[0], core
)};
dst_frames.emplace_back(frame);
dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0));
}
}

const int h_scale = dst_tile_h / src_tile_h;
Expand Down Expand Up @@ -225,10 +240,24 @@ static const VSFrameRef *VS_CC vsTrtGetFrame(
frameCtx
);

for (const auto & frame : dst_frames) {
vsapi->freeFrame(frame);
}

vsapi->freeFrame(dst_frame);

return nullptr;
}

if (!d->flexible_output_prop.empty()) {
auto prop = vsapi->getFramePropsRW(dst_frame);

for (int i = 0; i < dst_planes; i++) {
auto key { d->flexible_output_prop + std::to_string(i) };
vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace);
vsapi->freeFrame(dst_frames[i]);
}
}

return dst_frame;
}
Expand Down Expand Up @@ -365,6 +394,11 @@ static void VS_CC vsTrtCreate(
}
d->logger.set_verbosity(static_cast<nvinfer1::ILogger::Severity>(verbosity));

auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error);
if (!error) {
d->flexible_output_prop = flexible_output_prop;
}

#ifdef USE_NVINFER_PLUGIN
// related to https://github.com/AmusementClub/vs-mlrt/discussions/65, for unknown reason
#if !(NV_TENSORRT_MAJOR == 9 && defined(_WIN32))
Expand All @@ -391,7 +425,7 @@ static void VS_CC vsTrtCreate(
engine_stream.read(engine_data.get(), engine_nbytes);

d->runtime.reset(nvinfer1::createInferRuntime(d->logger));
auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime);
auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime, !d->flexible_output_prop.empty());
if (std::holds_alternative<std::unique_ptr<nvinfer1::ICudaEngine>>(maybe_engine)) {
d->engines.push_back(std::move(std::get<std::unique_ptr<nvinfer1::ICudaEngine>>(maybe_engine)));
} else {
Expand All @@ -417,7 +451,7 @@ static void VS_CC vsTrtCreate(
// https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-821/developer-guide/index.html#perform-inference
// each optimization profile can only have one execution context when using dynamic shapes
if (is_dynamic && i < d->num_streams - 1) {
auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime);
auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime, !d->flexible_output_prop.empty());
if (std::holds_alternative<std::unique_ptr<nvinfer1::ICudaEngine>>(maybe_engine)) {
d->engines.push_back(std::move(std::get<std::unique_ptr<nvinfer1::ICudaEngine>>(maybe_engine)));
} else {
Expand Down Expand Up @@ -490,9 +524,21 @@ static void VS_CC vsTrtCreate(

setDimensions(
d->out_vi, d->instances[0].exec_context, core, vsapi,
output_sample_type, output_bits_per_sample
output_sample_type, output_bits_per_sample,
!d->flexible_output_prop.empty()
);

if (!d->flexible_output_prop.empty()) {
const auto & exec_context = d->instances[0].exec_context;
#if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
auto output_name = exec_context->getEngine().getIOTensorName(1);
const nvinfer1::Dims & out_dims = exec_context->getTensorShape(output_name);
#else // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
const nvinfer1::Dims & out_dims = exec_context->getBindingDimensions(1);
#endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85
vsapi->propSetInt(out, "num_planes", out_dims.d[1], paReplace);
}

vsapi->createFilter(
in, out, "Model",
vsTrtInit, vsTrtGetFrame, vsTrtFree,
Expand Down Expand Up @@ -557,7 +603,8 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
"device_id:int:opt;"
"use_cuda_graph:int:opt;"
"num_streams:int:opt;"
"verbosity:int:opt;",
"verbosity:int:opt;"
"flexible_output_prop:data:opt;",
vsTrtCreate,
nullptr,
plugin
Expand Down

0 comments on commit a167eec

Please sign in to comment.