Skip to content

Commit

Permalink
MSE Node created
Browse files Browse the repository at this point in the history
  • Loading branch information
arabbitplays committed Dec 3, 2024
1 parent 41c1ba7 commit 60754ef
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 0 deletions.
60 changes: 60 additions & 0 deletions include/merian-nodes/nodes/mse_error/mse_error.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include "merian-nodes/connectors/managed_vk_buffer_out.hpp"
#include "merian-nodes/connectors/managed_vk_image_in.hpp"
#include "merian-nodes/graph/node.hpp"

#include "merian/vk/memory/resource_allocator.hpp"
#include "merian/vk/pipeline/pipeline.hpp"
#include "merian/vk/shader/shader_module.hpp"

namespace merian_nodes {

/**
 * @brief Graph node that reduces the squared difference of two images into a buffer.
 *
 * Two compute passes are recorded in process():
 *  - "image to buffer": one workgroup of local_size_x * local_size_y invocations per image
 *    tile writes one vec4 partial sum into the output buffer.
 *  - "reduce buffer": repeatedly sums up to workgroup_size partial sums per workgroup until
 *    a single element is left.
 *
 * The output buffer connector is (re)created in describe_outputs() because its size depends
 * on the extent of the connected inputs.
 */
class ErrorToBuffer : public Node {
  private:
    // Workgroup dimensions of the image pass; handed to the shaders as specialization constants.
    static constexpr uint32_t local_size_x = 16;
    static constexpr uint32_t local_size_y = 16;
    // Invocations per workgroup; also the fan-in of one reduce pass.
    static constexpr uint32_t workgroup_size = local_size_x * local_size_y;

    // Host-side mirror of the `PushStruct` push-constant block in layout.glsl
    // (same member order and types).
    //
    // All members carry default initializers so that the struct never holds indeterminate
    // values when it is pushed to the GPU.
    struct PushConstant {
        // Value every squared difference is divided by (set to width * height in process()).
        uint32_t divisor = 1;

        // Number of partial sums produced by the image pass.
        int size = 0;
        // Stride between the partial sums that are still live in the current reduce pass.
        int offset = 0;
        // Number of partial sums that remain to be reduced.
        int count = 0;
    };

  public:
    /// Loads the compute shader modules used by this node.
    ErrorToBuffer(const ContextHandle context);

    ~ErrorToBuffer();

    /// Returns the two image inputs ("src1", "src2").
    std::vector<InputConnectorHandle> describe_inputs() override;

    /// Creates the output buffer connector sized for the connected inputs.
    /// Throws graph_errors::node_error if the input extents differ.
    std::vector<OutputConnectorHandle> describe_outputs(const NodeIOLayout& io_layout) override;

    /// Builds the two compute pipelines on first connection.
    NodeStatusFlags on_connected([[maybe_unused]] const NodeIOLayout& io_layout,
                                 const DescriptorSetLayoutHandle& descriptor_set_layout) override;

    /// Records the image pass followed by the reduce passes into `cmd`.
    void process(GraphRun& run,
                 const vk::CommandBuffer& cmd,
                 const DescriptorSetHandle& descriptor_set,
                 const NodeIO& io) override;

  private:
    const ContextHandle context;

    // The two images that are compared.
    ManagedVkImageInHandle con_src1 = ManagedVkImageIn::compute_read("src1");
    ManagedVkImageInHandle con_src2 = ManagedVkImageIn::compute_read("src2");
    // Output buffer named "mean"; created in describe_outputs().
    ManagedVkBufferOutHandle con_mean;

    // Reused between dispatches; updated in process().
    PushConstant pc;

    ShaderModuleHandle image_to_buffer_shader;
    ShaderModuleHandle reduce_buffer_shader;

    // Created lazily in on_connected().
    PipelineHandle image_to_buffer;
    PipelineHandle reduce_buffer;
};

} // namespace merian_nodes
4 changes: 4 additions & 0 deletions src/merian-nodes/graph/node_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "merian-nodes/nodes/image_write/image_write.hpp"
#include "merian-nodes/nodes/mean/mean.hpp"
#include "merian-nodes/nodes/median_approx/median.hpp"
#include "merian-nodes/nodes/mse_error/mse_error.hpp"
#include "merian-nodes/nodes/svgf/svgf.hpp"
#include "merian-nodes/nodes/taa/taa.hpp"
#include "merian-nodes/nodes/tonemap/tonemap.hpp"
Expand Down Expand Up @@ -66,6 +67,9 @@ NodeRegistry::NodeRegistry(const ContextHandle& context, const ResourceAllocator
register_node<MedianApproxNode>(NodeInfo{
"Median (Approximation)", "Computes an approximation of the median of a component.",
[=]() { return std::make_shared<MedianApproxNode>(context); }});
register_node<ErrorToBuffer>(
NodeInfo{"MSE Error", "Computes the mean square error of two images and outputs it as a single buffer element.",
[=]() { return std::make_shared<ErrorToBuffer>(context); }});
register_node<Shadertoy>(NodeInfo{"Shadertoy",
"Execute Shadertoy-like shaders (Limited implementation).",
[=]() { return std::make_shared<Shadertoy>(context); }});
Expand Down
1 change: 1 addition & 0 deletions src/merian-nodes/nodes/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ subdir('image_read')
subdir('image_write')
subdir('mean')
subdir('median_approx')
subdir('mse_error')
subdir('shadertoy')
subdir('svgf')
subdir('taa')
Expand Down
21 changes: 21 additions & 0 deletions src/merian-nodes/nodes/mse_error/layout.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Shared interface of the MSE compute shaders: workgroup size, descriptor bindings and
// push constants. Included by mse_image_to_buffer.comp and mse_reduce_buffer.comp.

// Needed for the `scalar` layout qualifier on the result buffer below.
#extension GL_EXT_scalar_block_layout : enable

// Workgroup size in x/y comes from specialization constants 0 and 1.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;

// Specialization constant 2; used by the shaders to size their shared subgroup_sums array.
// The default of 0 is only a placeholder: the shaders divide by this value, so it has to be
// specialized to a non-zero value.
// NOTE(review): the host appears to pass the device subgroup size here — confirm that
// SpecializationInfoBuilder::add_entry assigns constant ids in argument order.
layout (constant_id = 2) const int SUBGROUP_SIZE = 0;


// The two images that are compared.
layout(set = 0, binding = 0) uniform sampler2D img_src1;
layout(set = 0, binding = 1) uniform sampler2D img_src2;

// Partial sums: written per workgroup by the image pass, then reduced in place.
layout(set = 0, binding = 2, scalar) buffer restrict buf_result {
vec4 result[];
};

// Must match the member order and types of ErrorToBuffer::PushConstant on the host.
layout(push_constant) uniform PushStruct {
// Each squared difference is divided by this value in the image pass.
uint divisor;

// size: bound for valid indices into result[] in the reduce pass.
int size;
// offset: stride that is multiplied onto the invocation index in the reduce pass.
int offset;
// count: not read by the shaders that include this file; the host uses it to drive the
// reduce loop.
int count;
} params;
10 changes: 10 additions & 0 deletions src/merian-nodes/nodes/mse_error/meson.build
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Compute shaders of the MSE node.
shaders = [
'mse_image_to_buffer.comp',
'mse_reduce_buffer.comp',
]

# Run every shader through the shader generator and add the result to the node sources.
foreach s : shaders
merian_nodes_src += shader_generator.process(s)
endforeach

# Host-side implementation of the node.
merian_nodes_src += files('mse_error.cpp')
118 changes: 118 additions & 0 deletions src/merian-nodes/nodes/mse_error/mse_error.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#include "merian-nodes/nodes/mse_error/mse_error.hpp"

#include "merian-nodes/graph/errors.hpp"
#include "merian/vk/pipeline/pipeline_compute.hpp"
#include "merian/vk/pipeline/pipeline_layout_builder.hpp"
#include "merian/vk/pipeline/specialization_info_builder.hpp"

#include "image_to_buffer.comp.spv.h"
#include "mse_image_to_buffer.comp.spv.h"
#include "mse_reduce_buffer.comp.spv.h"
#include "reduce_buffer.comp.spv.h"

namespace merian_nodes {

/**
 * Creates the node and loads the two compute shader modules it dispatches.
 *
 * The SPIR-V is taken from the headers generated for this node's own shaders
 * (mse_image_to_buffer.comp and mse_reduce_buffer.comp, as listed in this directory's
 * meson.build). Previously the un-prefixed image_to_buffer / reduce_buffer blobs were
 * loaded, which are not the shaders built for this node.
 *
 * NOTE(review): the accessor names follow the `merian_<file name with '.' as '_'>_spv`
 * pattern of the previously used accessors — confirm against the shader generator.
 *
 * @param context context the shader modules are created with.
 */
ErrorToBuffer::ErrorToBuffer(const ContextHandle context) : Node(), context(context) {
    image_to_buffer_shader =
        std::make_shared<ShaderModule>(context, merian_mse_image_to_buffer_comp_spv_size(),
                                       merian_mse_image_to_buffer_comp_spv());
    reduce_buffer_shader =
        std::make_shared<ShaderModule>(context, merian_mse_reduce_buffer_comp_spv_size(),
                                       merian_mse_reduce_buffer_comp_spv());
}

// Nothing to release manually: all members are handles that clean up after themselves.
ErrorToBuffer::~ErrorToBuffer() = default;

std::vector<InputConnectorHandle> ErrorToBuffer::describe_inputs() {
return {con_src1, con_src2};
}

std::vector<OutputConnectorHandle> ErrorToBuffer::describe_outputs(const NodeIOLayout& io_layout) {
if (io_layout[con_src1]->create_info.extent != io_layout[con_src2]->create_info.extent) {
throw graph_errors::node_error{"image extents mismatch!"};
}
vk::Extent3D extent = io_layout[con_src1]->create_info.extent;

const auto group_count_x = (extent.width + local_size_x - 1) / local_size_x;
const auto group_count_y = (extent.height + local_size_y - 1) / local_size_y;
const std::size_t buffer_size = group_count_x * group_count_y;

con_mean = std::make_shared<ManagedVkBufferOut>(
"mean", vk::AccessFlagBits2::eShaderRead | vk::AccessFlagBits2::eShaderWrite,
vk::PipelineStageFlagBits2::eComputeShader, vk::ShaderStageFlagBits::eCompute,
vk::BufferCreateInfo({}, buffer_size * sizeof(glm::vec4),
vk::BufferUsageFlagBits::eStorageBuffer));

return {con_mean};
}

/**
 * Builds the two compute pipelines the first time the node is connected.
 *
 * Both pipelines share one pipeline layout (the node's descriptor set layout plus the
 * PushConstant range) and differ only in shader module and specialization constants:
 *  - image pass:  local size local_size_x x local_size_y
 *  - reduce pass: local size (local_size_x * local_size_y) x 1
 * The third specialization entry is the subgroup size reported by the physical device.
 *
 * @param io_layout unused.
 * @param descriptor_set_layout layout of the node's descriptor set.
 * @return empty status flags.
 */
ErrorToBuffer::NodeStatusFlags
ErrorToBuffer::on_connected([[maybe_unused]] const NodeIOLayout& io_layout,
                            const DescriptorSetLayoutHandle& descriptor_set_layout) {
    if (image_to_buffer) {
        // Pipelines already exist from an earlier connect; keep them.
        return {};
    }

    auto shared_layout = PipelineLayoutBuilder(context)
                             .add_descriptor_set_layout(descriptor_set_layout)
                             .add_push_constant<PushConstant>()
                             .build_pipeline_layout();

    // Image pass: 2D workgroup.
    auto image_pass_constants = SpecializationInfoBuilder();
    image_pass_constants.add_entry(
        local_size_x, local_size_y,
        context->physical_device.physical_device_subgroup_properties.subgroupSize);
    SpecializationInfoHandle image_pass_spec = image_pass_constants.build();
    image_to_buffer =
        std::make_shared<ComputePipeline>(shared_layout, image_to_buffer_shader, image_pass_spec);

    // Reduce pass: the same number of invocations arranged as a 1D workgroup.
    auto reduce_pass_constants = SpecializationInfoBuilder();
    reduce_pass_constants.add_entry(
        local_size_x * local_size_y, 1,
        context->physical_device.physical_device_subgroup_properties.subgroupSize);
    SpecializationInfoHandle reduce_pass_spec = reduce_pass_constants.build();
    reduce_buffer =
        std::make_shared<ComputePipeline>(shared_layout, reduce_buffer_shader, reduce_pass_spec);

    return {};
}

/**
 * Records the image pass and the subsequent reduce passes.
 *
 * 1. Image pass: one workgroup per local_size_x x local_size_y tile writes one partial sum
 *    into the output buffer.
 * 2. Reduce passes: while more than one partial sum is left, a compute-to-compute buffer
 *    barrier is recorded and every workgroup sums up to workgroup_size elements that lie
 *    `pc.offset` apart. After each pass the number of live elements shrinks by a factor of
 *    workgroup_size and the stride grows by the same factor.
 *
 * All push constants are assigned before the first dispatch, so no indeterminate values are
 * uploaded to the GPU.
 *
 * @param run graph run; only used to obtain the profiler.
 * @param cmd command buffer the dispatches are recorded into.
 * @param descriptor_set descriptor set bound for both pipelines.
 * @param io accessor for the resources behind the node's connectors.
 */
void ErrorToBuffer::process([[maybe_unused]] GraphRun& run,
                            const vk::CommandBuffer& cmd,
                            const DescriptorSetHandle& descriptor_set,
                            const NodeIO& io) {
    const auto extent = io[con_src1]->get_extent();
    const auto group_count_x = (extent.width + local_size_x - 1) / local_size_x;
    const auto group_count_y = (extent.height + local_size_y - 1) / local_size_y;

    pc.divisor = extent.width * extent.height;
    pc.size = group_count_x * group_count_y;
    pc.offset = 1;
    pc.count = group_count_x * group_count_y;

    {
        MERIAN_PROFILE_SCOPE_GPU(run.get_profiler(), cmd, "images to buffer");
        image_to_buffer->bind(cmd);
        image_to_buffer->bind_descriptor_set(cmd, descriptor_set);
        image_to_buffer->push_constant(cmd, pc);
        cmd.dispatch(group_count_x, group_count_y, 1);
    }

    while (pc.count > 1) {
        MERIAN_PROFILE_SCOPE_GPU(run.get_profiler(), cmd,
                                 fmt::format("reduce {} elements", pc.count));
        // Make the partial sums of the previous dispatch visible to this one.
        auto bar = io[con_mean]->buffer_barrier(
            vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite,
            vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite);
        cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
                            vk::PipelineStageFlagBits::eComputeShader, {}, {}, bar, {});

        // One workgroup consumes up to workgroup_size live elements.
        const uint32_t reduce_group_count = (pc.count + workgroup_size - 1) / workgroup_size;

        reduce_buffer->bind(cmd);
        reduce_buffer->bind_descriptor_set(cmd, descriptor_set);
        reduce_buffer->push_constant(cmd, pc);
        cmd.dispatch(reduce_group_count, 1, 1);

        pc.count = reduce_group_count;
        pc.offset *= workgroup_size;
    }

    // TODO: reading the result back on the host is not implemented. An earlier draft sketched
    // a submit callback that waits for the queue to become idle.
}

} // namespace merian_nodes
39 changes: 39 additions & 0 deletions src/merian-nodes/nodes/mse_error/mse_image_to_buffer.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#version 460
#extension GL_GOOGLE_include_directive : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable

#include "layout.glsl"

// One slot per subgroup of the workgroup; the "+ 1" covers the remainder of the integer division.
shared vec4 subgroup_sums[gl_WorkGroupSize.x * gl_WorkGroupSize.y / SUBGROUP_SIZE + 1];

// Image pass: every invocation handles one pixel, every workgroup writes one partial sum.
void main() {
    const ivec2 pixel = ivec2(gl_GlobalInvocationID);

    // Invocations outside of either image contribute zero.
    const bool outside = any(greaterThanEqual(pixel, textureSize(img_src1, 0)))
                      || any(greaterThanEqual(pixel, textureSize(img_src2, 0)));

    vec4 contribution = vec4(0);
    if (!outside) {
        const vec4 difference = texelFetch(img_src1, pixel, 0) - texelFetch(img_src2, pixel, 0);
        // Squared difference, already divided by params.divisor.
        contribution = difference * difference / params.divisor;
    }

    // Component-wise sum over the subgroup.
    contribution = subgroupAdd(contribution);

    if (subgroupElect()) {
        subgroup_sums[gl_SubgroupID] = contribution;
    }

    barrier();

    if (gl_LocalInvocationIndex == 0) {
        // Combine the per-subgroup sums into the sum of the whole workgroup.
        vec4 workgroup_sum = vec4(0);
        for (uint subgroup = 0; subgroup < gl_NumSubgroups; ++subgroup) {
            workgroup_sum += subgroup_sums[subgroup];
        }

        const uint slot = gl_WorkGroupID.y + gl_NumWorkGroups.y * gl_WorkGroupID.x;
        result[slot] = workgroup_sum;
    }
}
39 changes: 39 additions & 0 deletions src/merian-nodes/nodes/mse_error/mse_reduce_buffer.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#version 460
#extension GL_GOOGLE_include_directive : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable

#include "layout.glsl"

// One slot per subgroup of the workgroup; the "+ 1" covers the remainder of the integer division.
shared vec4 subgroup_sums[gl_WorkGroupSize.x * gl_WorkGroupSize.y / SUBGROUP_SIZE + 1];

// Reduce pass: every workgroup sums the elements it covers and stores the sum at the index
// of its first invocation.
void main() {
    // Elements that are still live lie params.offset apart.
    const uint element = params.offset * (gl_LocalInvocationIndex + gl_WorkGroupID.x * gl_WorkGroupSize.x);

    // Invocations past the end of the buffer contribute zero.
    vec4 partial = (element < params.size) ? result[element] : vec4(0);

    // Component-wise sum over the subgroup.
    partial = subgroupAdd(partial);

    if (subgroupElect()) {
        subgroup_sums[gl_SubgroupID] = partial;
    }

    barrier();

    if (gl_LocalInvocationIndex == 0) {
        // Combine the per-subgroup sums into the sum of the whole workgroup.
        vec4 workgroup_sum = vec4(0);
        for (uint subgroup = 0; subgroup < gl_NumSubgroups; ++subgroup) {
            workgroup_sum += subgroup_sums[subgroup];
        }

        result[element] = workgroup_sum;
    }

}

0 comments on commit 60754ef

Please sign in to comment.