From 60754efeb46e3c15689287aee1319a6171299ef9 Mon Sep 17 00:00:00 2001
From: oschdi
Date: Tue, 3 Dec 2024 18:10:32 +0100
Subject: [PATCH] MSE Node created

---
 .../nodes/mse_error/mse_error.hpp             |  69 ++++++++++
 src/merian-nodes/graph/node_registry.cpp      |   4 +
 src/merian-nodes/nodes/meson.build            |   1 +
 src/merian-nodes/nodes/mse_error/layout.glsl  |  21 ++++
 src/merian-nodes/nodes/mse_error/meson.build  |  10 ++
 .../nodes/mse_error/mse_error.cpp             | 118 ++++++++++++++++++
 .../nodes/mse_error/mse_image_to_buffer.comp  |  39 ++++++
 .../nodes/mse_error/mse_reduce_buffer.comp    |  39 ++++++
 8 files changed, 301 insertions(+)
 create mode 100644 include/merian-nodes/nodes/mse_error/mse_error.hpp
 create mode 100644 src/merian-nodes/nodes/mse_error/layout.glsl
 create mode 100644 src/merian-nodes/nodes/mse_error/meson.build
 create mode 100644 src/merian-nodes/nodes/mse_error/mse_error.cpp
 create mode 100644 src/merian-nodes/nodes/mse_error/mse_image_to_buffer.comp
 create mode 100644 src/merian-nodes/nodes/mse_error/mse_reduce_buffer.comp

diff --git a/include/merian-nodes/nodes/mse_error/mse_error.hpp b/include/merian-nodes/nodes/mse_error/mse_error.hpp
new file mode 100644
index 00000000..91607033
--- /dev/null
+++ b/include/merian-nodes/nodes/mse_error/mse_error.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include "merian-nodes/connectors/managed_vk_buffer_out.hpp"
+#include "merian-nodes/connectors/managed_vk_image_in.hpp"
+#include "merian-nodes/graph/node.hpp"
+
+#include "merian/vk/memory/resource_allocator.hpp"
+#include "merian/vk/pipeline/pipeline.hpp"
+#include "merian/vk/shader/shader_module.hpp"
+
+namespace merian_nodes {
+
+/**
+ * @brief Computes the per-channel mean squared error between two input images.
+ *
+ * A first compute pass writes one partial sum per workgroup into the "mean" output buffer;
+ * follow-up passes reduce those partial sums in place until element 0 holds the result.
+ */
+class ErrorToBuffer : public Node {
+  private:
+    // Workgroup dimensions of the image pass; the reduce passes run workgroup_size x 1.
+    static constexpr uint32_t local_size_x = 16;
+    static constexpr uint32_t local_size_y = 16;
+    static constexpr uint32_t workgroup_size = local_size_x * local_size_y;
+
+    // Mirrors the PushStruct push-constant block declared in layout.glsl.
+    struct PushConstant {
+        uint32_t divisor; // pixel count (width * height) of the inputs
+
+        int size;   // number of partial sums written by the image pass
+        int offset; // stride between partial sums that are still live in a reduce pass
+        int count;  // partial sums left to reduce (host-side loop counter)
+    };
+
+  public:
+    ErrorToBuffer(const ContextHandle context);
+
+    ~ErrorToBuffer();
+
+    std::vector<InputConnectorHandle> describe_inputs() override;
+
+    // Throws graph_errors::node_error if the extents of src1 and src2 differ.
+    std::vector<OutputConnectorHandle> describe_outputs(const NodeIOLayout& io_layout) override;
+
+    NodeStatusFlags on_connected([[maybe_unused]] const NodeIOLayout& io_layout,
+                                 const DescriptorSetLayoutHandle& descriptor_set_layout) override;
+
+    void process(GraphRun& run,
+                 const vk::CommandBuffer& cmd,
+                 const DescriptorSetHandle& descriptor_set,
+                 const NodeIO& io) override;
+
+  private:
+    const ContextHandle context;
+
+    ManagedVkImageInHandle con_src1 = ManagedVkImageIn::compute_read("src1");
+    ManagedVkImageInHandle con_src2 = ManagedVkImageIn::compute_read("src2");
+    ManagedVkBufferOutHandle con_mean;
+
+    PushConstant pc{}; // zero-initialized so no indeterminate member is ever pushed
+
+    ShaderModuleHandle image_to_buffer_shader;
+    ShaderModuleHandle reduce_buffer_shader;
+
+    PipelineHandle image_to_buffer;
+    PipelineHandle reduce_buffer;
+};
+
+} // namespace merian_nodes
diff --git a/src/merian-nodes/graph/node_registry.cpp b/src/merian-nodes/graph/node_registry.cpp
index 2996e5ae..be028d53 100644
--- a/src/merian-nodes/graph/node_registry.cpp
+++ b/src/merian-nodes/graph/node_registry.cpp
@@ -14,6 +14,7 @@
 #include "merian-nodes/nodes/image_write/image_write.hpp"
 #include "merian-nodes/nodes/mean/mean.hpp"
 #include "merian-nodes/nodes/median_approx/median.hpp"
+#include "merian-nodes/nodes/mse_error/mse_error.hpp"
 #include "merian-nodes/nodes/svgf/svgf.hpp"
 #include "merian-nodes/nodes/taa/taa.hpp"
 #include "merian-nodes/nodes/tonemap/tonemap.hpp"
@@ -66,6 +67,9 @@ NodeRegistry::NodeRegistry(const ContextHandle& context, const ResourceAllocator
     register_node(NodeInfo{
         "Median (Approximation)", "Computes an approximation of the median of a component.",
         [=]() { return std::make_shared(context); }});
+    register_node(
+        NodeInfo{"MSE Error", "Computes the mean square error of two images and outputs it as a single buffer element.",
+                 [=]() { return std::make_shared<ErrorToBuffer>(context); }});
     register_node(NodeInfo{"Shadertoy", "Execute Shadertoy-like shaders (Limited implementation).",
                            [=]() { return std::make_shared(context); }});
 
diff --git a/src/merian-nodes/nodes/meson.build b/src/merian-nodes/nodes/meson.build
index 7effd76a..fd8867d3 100644
--- a/src/merian-nodes/nodes/meson.build
+++ b/src/merian-nodes/nodes/meson.build
@@ -12,6 +12,7 @@ subdir('image_read')
 subdir('image_write')
 subdir('mean')
 subdir('median_approx')
+subdir('mse_error')
 subdir('shadertoy')
 subdir('svgf')
 subdir('taa')
diff --git a/src/merian-nodes/nodes/mse_error/layout.glsl b/src/merian-nodes/nodes/mse_error/layout.glsl
new file mode 100644
index 00000000..526f0908
--- /dev/null
+++ b/src/merian-nodes/nodes/mse_error/layout.glsl
@@ -0,0 +1,21 @@
+#extension GL_EXT_scalar_block_layout : enable
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
+
+layout (constant_id = 2) const int SUBGROUP_SIZE = 0; // always specialized by the host
+
+// Bindings and push constants shared by both MSE compute shaders.
+layout(set = 0, binding = 0) uniform sampler2D img_src1;
+layout(set = 0, binding = 1) uniform sampler2D img_src2;
+
+layout(set = 0, binding = 2, scalar) buffer restrict buf_result {
+    vec4 result[]; // partial sums; element 0 ends up holding the final result
+};
+
+layout(push_constant) uniform PushStruct {
+    uint divisor; // pixel count of the inputs
+
+    int size;   // number of partial sums written by the image pass
+    int offset; // stride between partial sums that are still live
+    int count;  // host-side loop counter, not read by the shaders
+} params;
diff --git a/src/merian-nodes/nodes/mse_error/meson.build b/src/merian-nodes/nodes/mse_error/meson.build
new file mode 100644
index 00000000..36e8969a
--- /dev/null
+++ b/src/merian-nodes/nodes/mse_error/meson.build
@@ -0,0 +1,10 @@
+shaders = [
+    'mse_image_to_buffer.comp',
+    'mse_reduce_buffer.comp',
+]
+
+foreach s : shaders
+    merian_nodes_src += shader_generator.process(s)
+endforeach
+
+merian_nodes_src += files('mse_error.cpp')
diff --git a/src/merian-nodes/nodes/mse_error/mse_error.cpp b/src/merian-nodes/nodes/mse_error/mse_error.cpp
new file mode 100644
index 00000000..63c5bc69
--- /dev/null
+++ b/src/merian-nodes/nodes/mse_error/mse_error.cpp
@@ -0,0 +1,118 @@
+#include "merian-nodes/nodes/mse_error/mse_error.hpp"
+
+#include "merian/vk/pipeline/pipeline_compute.hpp"
+#include "merian/vk/pipeline/pipeline_layout_builder.hpp"
+#include "merian/vk/pipeline/specialization_info_builder.hpp"
+#include "merian-nodes/graph/errors.hpp"
+
+#include "mse_image_to_buffer.comp.spv.h"
+#include "mse_reduce_buffer.comp.spv.h"
+
+namespace merian_nodes {
+
+ErrorToBuffer::ErrorToBuffer(const ContextHandle context) : Node(), context(context) {
+    // Load the SPIR-V generated from this directory's mse_*.comp shaders (see meson.build).
+    image_to_buffer_shader = std::make_shared<ShaderModule>(
+        context, merian_mse_image_to_buffer_comp_spv_size(), merian_mse_image_to_buffer_comp_spv());
+    reduce_buffer_shader = std::make_shared<ShaderModule>(
+        context, merian_mse_reduce_buffer_comp_spv_size(), merian_mse_reduce_buffer_comp_spv());
+}
+
+ErrorToBuffer::~ErrorToBuffer() {}
+
+std::vector<InputConnectorHandle> ErrorToBuffer::describe_inputs() {
+    return {con_src1, con_src2};
+}
+
+// Creates the "mean" buffer with one vec4 per image-pass workgroup; throws on extent mismatch.
+std::vector<OutputConnectorHandle> ErrorToBuffer::describe_outputs(const NodeIOLayout& io_layout) {
+    if (io_layout[con_src1]->create_info.extent != io_layout[con_src2]->create_info.extent) {
+        throw graph_errors::node_error{"image extents mismatch!"};
+    }
+    vk::Extent3D extent = io_layout[con_src1]->create_info.extent;
+
+    const auto group_count_x = (extent.width + local_size_x - 1) / local_size_x;
+    const auto group_count_y = (extent.height + local_size_y - 1) / local_size_y;
+    const std::size_t buffer_size = group_count_x * group_count_y;
+
+    con_mean = std::make_shared<ManagedVkBufferOut>(
+        "mean", vk::AccessFlagBits2::eShaderRead | vk::AccessFlagBits2::eShaderWrite,
+        vk::PipelineStageFlagBits2::eComputeShader, vk::ShaderStageFlagBits::eCompute,
+        vk::BufferCreateInfo({}, buffer_size * sizeof(glm::vec4),
+                             vk::BufferUsageFlagBits::eStorageBuffer));
+
+    return {con_mean};
+}
+
+// Builds both compute pipelines once; they share one layout (descriptor set + push constant).
+ErrorToBuffer::NodeStatusFlags
+ErrorToBuffer::on_connected([[maybe_unused]] const NodeIOLayout& io_layout,
+                            const DescriptorSetLayoutHandle& descriptor_set_layout) {
+    if (!image_to_buffer) {
+        auto pipe_layout = PipelineLayoutBuilder(context)
+                               .add_descriptor_set_layout(descriptor_set_layout)
+                               .add_push_constant<PushConstant>()
+                               .build_pipeline_layout();
+        auto image_to_buffer_spec_builder = SpecializationInfoBuilder();
+        image_to_buffer_spec_builder.add_entry(
+            local_size_x, local_size_y,
+            context->physical_device.physical_device_subgroup_properties.subgroupSize);
+        SpecializationInfoHandle spec = image_to_buffer_spec_builder.build();
+        image_to_buffer =
+            std::make_shared<ComputePipeline>(pipe_layout, image_to_buffer_shader, spec);
+
+        auto reduce_buffer_spec_builder = SpecializationInfoBuilder();
+        reduce_buffer_spec_builder.add_entry(
+            local_size_x * local_size_y, 1,
+            context->physical_device.physical_device_subgroup_properties.subgroupSize);
+        spec = reduce_buffer_spec_builder.build();
+        reduce_buffer = std::make_shared<ComputePipeline>(pipe_layout, reduce_buffer_shader, spec);
+    }
+
+    return {};
+}
+
+// Records the image pass, then in-place reduce passes (each shrinking the number of partial sums
+// by a factor of workgroup_size) until a single sum remains in element 0 of the "mean" buffer.
+void ErrorToBuffer::process([[maybe_unused]] GraphRun& run,
+                            const vk::CommandBuffer& cmd,
+                            const DescriptorSetHandle& descriptor_set,
+                            const NodeIO& io) {
+    const auto group_count_x = (io[con_src1]->get_extent().width + local_size_x - 1) / local_size_x;
+    const auto group_count_y = (io[con_src1]->get_extent().height + local_size_y - 1) / local_size_y;
+
+    // Fill every push-constant field before the first push so no indeterminate data is uploaded.
+    pc.divisor = io[con_src1]->get_extent().width * io[con_src1]->get_extent().height;
+    pc.size = group_count_x * group_count_y;
+    pc.offset = 1;
+    pc.count = group_count_x * group_count_y;
+
+    {
+        MERIAN_PROFILE_SCOPE_GPU(run.get_profiler(), cmd, "images to buffer");
+        image_to_buffer->bind(cmd);
+        image_to_buffer->bind_descriptor_set(cmd, descriptor_set);
+        image_to_buffer->push_constant(cmd, pc);
+        cmd.dispatch(group_count_x, group_count_y, 1);
+    }
+
+    while (pc.count > 1) {
+        MERIAN_PROFILE_SCOPE_GPU(run.get_profiler(), cmd,
+                                 fmt::format("reduce {} elements", pc.count));
+        // Make the previous pass's writes to the buffer visible before reducing them.
+        auto bar = io[con_mean]->buffer_barrier(
+            vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite,
+            vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite);
+        cmd.pipelineBarrier(vk::PipelineStageFlagBits::eComputeShader,
+                            vk::PipelineStageFlagBits::eComputeShader, {}, {}, bar, {});
+
+        reduce_buffer->bind(cmd);
+        reduce_buffer->bind_descriptor_set(cmd, descriptor_set);
+        reduce_buffer->push_constant(cmd, pc);
+        cmd.dispatch((pc.count + workgroup_size - 1) / workgroup_size, 1, 1);
+
+        pc.count = (pc.count + workgroup_size - 1) / workgroup_size;
+        pc.offset *= workgroup_size;
+    }
+}
+
+} // namespace merian_nodes
diff --git a/src/merian-nodes/nodes/mse_error/mse_image_to_buffer.comp b/src/merian-nodes/nodes/mse_error/mse_image_to_buffer.comp
new file mode 100644
index 00000000..5a8ec432
--- /dev/null
+++ b/src/merian-nodes/nodes/mse_error/mse_image_to_buffer.comp
@@ -0,0 +1,39 @@
+#version 460
+#extension GL_GOOGLE_include_directive : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+#include "layout.glsl"
+
+// One partial sum per subgroup of this workgroup.
+shared vec4 subgroup_sums[gl_WorkGroupSize.x * gl_WorkGroupSize.y / SUBGROUP_SIZE + 1];
+
+// Writes the sum of (src1 - src2)^2 / divisor over this workgroup's pixels to result[].
+void main() {
+    const ivec2 ipos = ivec2(gl_GlobalInvocationID);
+
+    vec4 v;
+    if (any(greaterThanEqual(ipos, textureSize(img_src1, 0))) || any(greaterThanEqual(ipos, textureSize(img_src2, 0)))) {
+        v = vec4(0); // invocations outside the image contribute nothing
+    } else {
+        v = (texelFetch(img_src1, ipos, 0) - texelFetch(img_src2, ipos, 0));
+        v = v * v / params.divisor;
+    }
+
+    // Component-wise sum across the subgroup.
+    v = subgroupAdd(v);
+
+    if (subgroupElect()) {
+        subgroup_sums[gl_SubgroupID] = v;
+    }
+
+    barrier();
+
+    if (gl_LocalInvocationIndex == 0) {
+        vec4 sum = vec4(0);
+        for (int i = 0; i < gl_NumSubgroups; i++) {
+            sum += subgroup_sums[i];
+        }
+
+        result[gl_WorkGroupID.y + gl_NumWorkGroups.y * gl_WorkGroupID.x] = sum;
+    }
+}
diff --git a/src/merian-nodes/nodes/mse_error/mse_reduce_buffer.comp b/src/merian-nodes/nodes/mse_error/mse_reduce_buffer.comp
new file mode 100644
index 00000000..f9bb3ef2
--- /dev/null
+++ b/src/merian-nodes/nodes/mse_error/mse_reduce_buffer.comp
@@ -0,0 +1,39 @@
+#version 460
+#extension GL_GOOGLE_include_directive : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+#include "layout.glsl"
+
+// One partial sum per subgroup of this workgroup.
+shared vec4 subgroup_sums[gl_WorkGroupSize.x * gl_WorkGroupSize.y / SUBGROUP_SIZE + 1];
+
+// Sums up to one workgroup's worth of partial sums spaced params.offset apart, in place.
+void main() {
+    uint index = params.offset * (gl_LocalInvocationIndex + gl_WorkGroupID.x * gl_WorkGroupSize.x);
+
+    vec4 v;
+    if (index < params.size) {
+        v = result[index];
+    } else {
+        v = vec4(0);
+    }
+
+    // Component-wise sum across the subgroup.
+    v = subgroupAdd(v);
+
+    if (subgroupElect()) {
+        subgroup_sums[gl_SubgroupID] = v;
+    }
+
+    barrier();
+
+    if (gl_LocalInvocationIndex == 0) {
+        vec4 sum = vec4(0);
+        for (int i = 0; i < gl_NumSubgroups; i++) {
+            sum += subgroup_sums[i];
+        }
+
+        // For invocation 0, index is the first element this workgroup read.
+        result[index] = sum;
+    }
+}