Skip to content

Commit

Permalink
vsov: implement flexible output
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed May 13, 2024
1 parent a167eec commit 325f54f
Showing 1 changed file with 58 additions and 13 deletions.
71 changes: 58 additions & 13 deletions vsov/vs_openvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ static std::optional<std::string> checkNodes(
[[nodiscard]]
static std::optional<std::string> checkIOInfo(
const ov::Output<ov::Node> & info,
bool is_output
bool is_output,
bool flexible_output
) {

if (info.get_element_type() != ov::element::f32) {
Expand All @@ -124,8 +125,8 @@ static std::optional<std::string> checkIOInfo(

if (is_output) {
auto out_channels = dims[1];
if (out_channels != 1 && out_channels != 3) {
return "output dimensions must be 1 or 3";
if (out_channels != 1 && out_channels != 3 && !flexible_output) {
return "output dimensions must be 1 or 3, or enable \"flexible_output\"";
}
}

Expand All @@ -135,15 +136,16 @@ static std::optional<std::string> checkIOInfo(

[[nodiscard]]
static std::optional<std::string> checkNetwork(
const std::shared_ptr<ov::Model> & network
const std::shared_ptr<ov::Model> & network,
bool flexible_output
) {

if (auto num_inputs = std::size(network->inputs()); num_inputs != 1) {
return "network input count must be 1, got " + std::to_string(num_inputs);
}

const auto & input_info = network->input();
if (auto err = checkIOInfo(input_info, false); err.has_value()) {
if (auto err = checkIOInfo(input_info, false, flexible_output); err.has_value()) {
return err.value();
}

Expand All @@ -152,7 +154,7 @@ static std::optional<std::string> checkNetwork(
}

const auto & output_info = network->output();
if (auto err = checkIOInfo(output_info, true); err.has_value()) {
if (auto err = checkIOInfo(output_info, true, flexible_output); err.has_value()) {
return err.value();
}

Expand Down Expand Up @@ -193,7 +195,8 @@ static void setDimensions(
std::unique_ptr<VSVideoInfo> & vi,
const ov::CompiledModel & network,
VSCore * core,
const VSAPI * vsapi
const VSAPI * vsapi,
bool flexible_output
) {

const auto & in_dims = network.input().get_shape();
Expand All @@ -202,7 +205,7 @@ static void setDimensions(
vi->height *= out_dims[2] / in_dims[2];
vi->width *= out_dims[3] / in_dims[3];

if (out_dims[1] == 1) {
if (out_dims[1] == 1 || flexible_output) {
vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core);
} else if (out_dims[1] == 3) {
vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core);
Expand Down Expand Up @@ -273,6 +276,8 @@ struct OVData {
ov::CompiledModel executable_network;
std::unordered_map<std::thread::id, ov::InferRequest> infer_requests;
std::shared_mutex infer_requests_lock;

std::string flexible_output_prop;
};


Expand Down Expand Up @@ -344,6 +349,9 @@ static const VSFrameRef *VS_CC vsOvGetFrame(
d->out_vi->format, d->out_vi->width, d->out_vi->height,
src_frames.front(), core
);

std::vector<VSFrameRef *> dst_frames;

auto dst_stride = vsapi->getStride(dst_frame, 0);
auto dst_bytes = vsapi->getFrameFormat(dst_frame)->bytesPerSample;
auto dst_tile_shape = getShape(d->executable_network, false);
Expand All @@ -352,9 +360,21 @@ static const VSFrameRef *VS_CC vsOvGetFrame(
auto dst_tile_w_bytes = dst_tile_w * dst_bytes;
auto dst_tile_bytes = dst_tile_h * dst_tile_w_bytes;
auto dst_planes = dst_tile_shape[1];
std::array<uint8_t *, 3> dst_ptrs {};
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs[i] = vsapi->getWritePtr(dst_frame, i);

std::vector<uint8_t *> dst_ptrs;
if (d->flexible_output_prop.empty()) {
for (int i = 0; i < dst_planes; ++i) {
dst_ptrs.emplace_back(vsapi->getWritePtr(dst_frame, i));
}
} else {
for (int i = 0; i < dst_planes; ++i) {
auto frame { vsapi->newVideoFrame(
d->out_vi->format, d->out_vi->width, d->out_vi->height,
src_frames[0], core
)};
dst_frames.emplace_back(frame);
dst_ptrs.emplace_back(vsapi->getWritePtr(frame, 0));
}
}

auto h_scale = dst_tile_h / src_tile_h;
Expand All @@ -368,6 +388,10 @@ static const VSFrameRef *VS_CC vsOvGetFrame(

vsapi->freeFrame(dst_frame);

for (const auto & frame : dst_frames) {
vsapi->freeFrame(frame);
}

for (const auto & frame : src_frames) {
vsapi->freeFrame(frame);
}
Expand Down Expand Up @@ -474,6 +498,16 @@ static const VSFrameRef *VS_CC vsOvGetFrame(
vsapi->freeFrame(frame);
}

if (!d->flexible_output_prop.empty()) {
auto prop = vsapi->getFramePropsRW(dst_frame);

for (int i = 0; i < dst_planes; i++) {
auto key { d->flexible_output_prop + std::to_string(i) };
vsapi->propSetFrame(prop, key.c_str(), dst_frames[i], paReplace);
vsapi->freeFrame(dst_frames[i]);
}
}

return dst_frame;
}

Expand Down Expand Up @@ -615,6 +649,11 @@ static void VS_CC vsOvCreate(
path_view = path;
}

auto flexible_output_prop = vsapi->propGetData(in, "flexible_output_prop", 0, &error);
if (!error) {
d->flexible_output_prop = flexible_output_prop;
}

auto result = loadONNX(path_view, tile_w, tile_h, path_is_serialization);
if (std::holds_alternative<std::string>(result)) {
return set_error(std::get<std::string>(result));
Expand Down Expand Up @@ -657,7 +696,7 @@ static void VS_CC vsOvCreate(
return set_error("[Standard exception] ReadNetwork(): "s + e.what());
}

if (auto err = checkNetwork(network); err.has_value()) {
if (auto err = checkNetwork(network, !d->flexible_output_prop.empty()); err.has_value()) {
return set_error(err.value());
}

Expand Down Expand Up @@ -696,13 +735,18 @@ static void VS_CC vsOvCreate(
return set_error(err.value());
}

setDimensions(d->out_vi, d->executable_network, core, vsapi);
setDimensions(d->out_vi, d->executable_network, core, vsapi, !d->flexible_output_prop.empty());

VSCoreInfo core_info;
vsapi->getCoreInfo2(core, &core_info);
d->infer_requests.reserve(core_info.numThreads);
}

if (!d->flexible_output_prop.empty()) {
auto num_planes = d->executable_network.output(0).get_shape()[1];
vsapi->propSetInt(out, "num_planes", static_cast<int>(num_planes), paReplace);
}

vsapi->createFilter(
in, out, "Model",
vsOvInit, vsOvGetFrame, vsOvFree,
Expand Down Expand Up @@ -738,6 +782,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
#ifdef ENABLE_VISUALIZATION
"dot_path:data:opt;"
#endif
"flexible_output_prop:data:opt;"
, vsOvCreate,
nullptr,
plugin
Expand Down

0 comments on commit 325f54f

Please sign in to comment.