diff --git a/vstrt/trt_utils.h b/vstrt/trt_utils.h index 3c0feb9..dcb2d0f 100644 --- a/vstrt/trt_utils.h +++ b/vstrt/trt_utils.h @@ -441,7 +441,8 @@ std::variant getInstance( static inline std::optional checkEngine( - const std::unique_ptr & engine + const std::unique_ptr & engine, + bool flexible_output ) noexcept { #if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 @@ -506,8 +507,8 @@ std::optional checkEngine( } int out_channels = output_dims.d[1]; - if (out_channels != 1 && out_channels != 3) { - return "output dimensions must be 1 or 3"; + if (out_channels != 1 && out_channels != 3 && !flexible_output) { + return "output dimensions must be 1 or 3, or enable \"flexible_output\""; } int in_height = input_dims.d[2]; @@ -547,7 +548,8 @@ std::optional checkEngine( static inline std::variant> initEngine( const char * engine_data, size_t engine_nbytes, - const std::unique_ptr & runtime + const std::unique_ptr & runtime, + bool flexible_output ) noexcept { const auto set_error = [](const ErrorMessage & error_message) { @@ -562,7 +564,7 @@ std::variant> initEngine( return set_error("engine deserialization failed"); } - if (auto err = checkEngine(engine); err.has_value()) { + if (auto err = checkEngine(engine, flexible_output); err.has_value()) { return set_error(err.value()); } diff --git a/vstrt/utils.h b/vstrt/utils.h index 77fd555..b137f3a 100644 --- a/vstrt/utils.h +++ b/vstrt/utils.h @@ -20,7 +20,8 @@ void setDimensions( VSCore * core, const VSAPI * vsapi, int sample_type, - int bits_per_sample + int bits_per_sample, + bool flexible_output ) noexcept { #if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 @@ -42,7 +43,7 @@ void setDimensions( vi->height *= out_height / in_height; vi->width *= out_width / in_width; - if (out_dims.d[1] == 1) { + if (out_dims.d[1] == 1 || flexible_output) { vi->format = vsapi->registerFormat(cmGray, sample_type, bits_per_sample, 0, 0, core); } else if (out_dims.d[1] == 3) { vi->format = vsapi->registerFormat(cmRGB, sample_type, bits_per_sample, 0, 0, core); diff --git a/vstrt/vs_tensorrt.cpp b/vstrt/vs_tensorrt.cpp index 0f22832..2c2d5a9 100644 --- a/vstrt/vs_tensorrt.cpp +++ b/vstrt/vs_tensorrt.cpp @@ -82,6 +82,8 @@ struct vsTrtData { std::mutex instances_lock; std::vector instances; + std::string flexible_output_prop; + [[nodiscard]] int acquire() noexcept { semaphore.acquire(); @@ -168,6 +170,8 @@ static const VSFrameRef *VS_CC vsTrtGetFrame( src_frames[0], core )}; + std::vector dst_frames; + #if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 auto output_name = d->engines[0]->getIOTensorName(1); const nvinfer1::Dims dst_dim { instance.exec_context->getTensorShape(output_name) }; @@ -181,8 +185,19 @@ static const VSFrameRef *VS_CC vsTrtGetFrame( std::vector dst_ptrs; dst_ptrs.reserve(dst_planes); - for (int i = 0; i < dst_planes; ++i) { - dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i)); + if (d->flexible_output_prop.empty()) { + for (int i = 0; i < dst_planes; ++i) { + dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i)); + } + } else { + for (int i = 0; i < dst_planes; ++i) { + auto frame { vsapi->newVideoFrame( + d->out_vi->format, d->out_vi->width, d->out_vi->height, + src_frames[0], core + )}; + dst_frames.emplace_back(frame); + dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0)); + } } const int h_scale = dst_tile_h / src_tile_h; @@ -225,10 +240,24 @@ static const VSFrameRef *VS_CC vsTrtGetFrame( frameCtx ); + for (const auto & frame : dst_frames) { + vsapi->freeFrame(frame); + } + vsapi->freeFrame(dst_frame); return nullptr; } + + if (!d->flexible_output_prop.empty()) { + auto prop = vsapi->getFramePropsRW(dst_frame); + + for (int i = 0; i < dst_planes; i++) { + auto key { d->flexible_output_prop + std::to_string(i) }; + vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace); + vsapi->freeFrame(dst_frames[i]); + } + } return dst_frame; } @@ -365,6 +394,11 @@ static void VS_CC vsTrtCreate( } d->logger.set_verbosity(static_cast(verbosity)); + auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error); + if (!error) { + d->flexible_output_prop = flexible_output_prop; + } + #ifdef USE_NVINFER_PLUGIN // related to https://github.com/AmusementClub/vs-mlrt/discussions/65, for unknown reason #if !(NV_TENSORRT_MAJOR == 9 && defined(_WIN32)) @@ -391,7 +425,7 @@ static void VS_CC vsTrtCreate( engine_stream.read(engine_data.get(), engine_nbytes); d->runtime.reset(nvinfer1::createInferRuntime(d->logger)); - auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime); + auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime, !d->flexible_output_prop.empty()); if (std::holds_alternative>(maybe_engine)) { d->engines.push_back(std::move(std::get>(maybe_engine))); } else { @@ -417,7 +451,7 @@ static void VS_CC vsTrtCreate( // https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-821/developer-guide/index.html#perform-inference // each optimization profile can only have one execution context when using dynamic shapes if (is_dynamic && i < d->num_streams - 1) { - auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime); + auto maybe_engine = initEngine(engine_data.get(), engine_nbytes, d->runtime, !d->flexible_output_prop.empty()); if (std::holds_alternative>(maybe_engine)) { d->engines.push_back(std::move(std::get>(maybe_engine))); } else { @@ -490,9 +524,21 @@ static void VS_CC vsTrtCreate( setDimensions( d->out_vi, d->instances[0].exec_context, core, vsapi, - output_sample_type, output_bits_per_sample + output_sample_type, output_bits_per_sample, + !d->flexible_output_prop.empty() ); + if (!d->flexible_output_prop.empty()) { + const auto & exec_context = d->instances[0].exec_context; + #if NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 + auto output_name = exec_context->getEngine().getIOTensorName(1); + const nvinfer1::Dims & out_dims = exec_context->getTensorShape(output_name); + #else // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 + const nvinfer1::Dims & out_dims = exec_context->getBindingDimensions(1); + #endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85 + vsapi->propSetInt(out, "num_planes", out_dims.d[1], paReplace); + } + vsapi->createFilter( in, out, "Model", vsTrtInit, vsTrtGetFrame, vsTrtFree, @@ -557,7 +603,8 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit( "device_id:int:opt;" "use_cuda_graph:int:opt;" "num_streams:int:opt;" - "verbosity:int:opt;", + "verbosity:int:opt;" + "flexible_output_prop:data:opt;", vsTrtCreate, nullptr, plugin