Skip to content

Commit

Permalink
merian-nodes: accum, svgf, taa: Optional mv input
Browse files Browse the repository at this point in the history
  • Loading branch information
LDAP committed Nov 20, 2024
1 parent 9431e64 commit 080077f
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 74 deletions.
3 changes: 2 additions & 1 deletion include/merian-nodes/nodes/accumulate/accumulate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Accumulate : public Node {
ManagedVkImageInHandle con_prev_accum = ManagedVkImageIn::compute_read("prev_accum", 1);
ManagedVkImageInHandle con_prev_moments = ManagedVkImageIn::compute_read("prev_moments", 1);
ManagedVkImageInHandle con_irr_in = ManagedVkImageIn::compute_read("irr");
ManagedVkImageInHandle con_mv = ManagedVkImageIn::compute_read("mv");
ManagedVkImageInHandle con_mv = ManagedVkImageIn::compute_read("mv", 0, true);
ManagedVkImageInHandle con_moments_in = ManagedVkImageIn::compute_read("moments_in");

ManagedVkBufferInHandle con_gbuf = ManagedVkBufferIn::compute_read("gbuf");
Expand Down Expand Up @@ -117,6 +117,7 @@ class Accumulate : public Node {
int filter_mode = 0;
VkBool32 extended_search = VK_TRUE;
VkBool32 reuse_border = VK_FALSE;
bool enable_mv = VK_TRUE;

std::string clear_event_listener_pattern = "/user/clear";
};
Expand Down
3 changes: 2 additions & 1 deletion include/merian-nodes/nodes/add/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Add : public AbstractCompute {
static constexpr uint32_t local_size_y = 32;

public:
Add(const ContextHandle context, const std::optional<vk::Format> output_format = std::nullopt);
Add(const ContextHandle& context,
const std::optional<vk::Format>& output_format = std::nullopt);

~Add();

Expand Down
7 changes: 4 additions & 3 deletions include/merian-nodes/nodes/svgf/svgf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class SVGF : public Node {
};

public:
SVGF(const ContextHandle context,
const ResourceAllocatorHandle allocator,
SVGF(const ContextHandle& context,
const ResourceAllocatorHandle& allocator,
const std::optional<vk::Format> output_format = std::nullopt);

~SVGF();
Expand Down Expand Up @@ -70,7 +70,7 @@ class SVGF : public Node {
ManagedVkImageInHandle con_irr = ManagedVkImageIn::compute_read("irr");
ManagedVkImageInHandle con_moments = ManagedVkImageIn::compute_read("moments");
ManagedVkImageInHandle con_albedo = ManagedVkImageIn::compute_read("albedo");
ManagedVkImageInHandle con_mv = ManagedVkImageIn::compute_read("mv");
ManagedVkImageInHandle con_mv = ManagedVkImageIn::compute_read("mv", 0, true);
ManagedVkBufferInHandle con_gbuffer = ManagedVkBufferIn::compute_read("gbuffer");
ManagedVkBufferInHandle con_prev_gbuffer = ManagedVkBufferIn::compute_read("prev_gbuffer", 1);

Expand Down Expand Up @@ -109,6 +109,7 @@ class SVGF : public Node {
int taa_filter_prev = false;
int taa_clamping = 0;
int taa_mv_sampling = 0;
bool enable_mv = true;
};

} // namespace merian_nodes
2 changes: 2 additions & 0 deletions include/merian-nodes/nodes/taa/taa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class TAA : public AbstractCompute {
// higher value means more temporal reuse
float temporal_alpha;
int clamp_method;
VkBool32 enable_mv;
};

public:
Expand Down Expand Up @@ -42,6 +43,7 @@ class TAA : public AbstractCompute {
SpecializationInfoHandle spec_info;

ManagedVkImageInHandle con_src = ManagedVkImageIn::compute_read("src");
ManagedVkImageInHandle con_mv = ManagedVkImageIn::compute_read("mv", 0, true);

PushConstant pc;
uint32_t width{};
Expand Down
33 changes: 21 additions & 12 deletions src/merian-nodes/nodes/accumulate/accumulate.comp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ layout(constant_id = 3) const uint WG_ROUNDED_IRR_SIZE_Y = 1;
layout (constant_id = 4) const int FILTER_MODE = FILTER_MODE_NEAREST;
layout (constant_id = 5) const bool EXTENDED_SEARCH = true;
layout (constant_id = 6) const bool REUSE_BORDER = false;
layout (constant_id = 7) const bool USE_MOTION_VECTORS = true;

layout(set = 1, binding = 0) uniform sampler2D img_quartiles;

Expand Down Expand Up @@ -160,20 +161,28 @@ void main() {
return;
}

const vec2 mv = texelFetch(img_mv, ipos, 0).rg;

// REPROJECTION
vec2 prev_pos = ipos + mv;
float max_history = pc.accum_max_hist;
if (REUSE_BORDER) {
// Attemp to reuse information at the image border.
// This results in minor smearing but looks a lot better than
// noise / SVGF blotches.
if (reprojection_intersect_border(prev_pos, mv, imageSize(img_accum) - 1)) {
// reset history to converge faster
max_history = 2.0;
vec2 prev_pos;
float max_history;
if (USE_MOTION_VECTORS) {
const vec2 mv = texelFetch(img_mv, ipos, 0).rg;

// REPROJECTION
prev_pos = ipos + mv;
float max_history = pc.accum_max_hist;
if (REUSE_BORDER) {
// Attemp to reuse information at the image border.
// This results in minor smearing but looks a lot better than
// noise / SVGF blotches.
if (reprojection_intersect_border(prev_pos, mv, imageSize(img_accum) - 1)) {
// reset history to converge faster
max_history = 2.0;
}
}
} else {
prev_pos = ipos;
max_history = pc.accum_max_hist;
}


vec4 prev_irr_histlen = vec4(0);
vec2 prev_moments = vec2(0);
Expand Down
25 changes: 13 additions & 12 deletions src/merian-nodes/nodes/accumulate/accumulate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,20 @@ Accumulate::on_connected([[maybe_unused]] const NodeIOLayout& io_layout,
calculate_percentiles =
std::make_shared<ComputePipeline>(quartile_pipe_layout, percentile_module, quartile_spec);

auto filter_pipe_layout = PipelineLayoutBuilder(context)
.add_descriptor_set_layout(graph_layout)
.add_descriptor_set_layout(accumulate_desc_layout)
.add_push_constant<FilterPushConstant>()
.build_pipeline_layout();
auto filter_spec_builder = SpecializationInfoBuilder();
auto accum_pipe_layout = PipelineLayoutBuilder(context)
.add_descriptor_set_layout(graph_layout)
.add_descriptor_set_layout(accumulate_desc_layout)
.add_push_constant<FilterPushConstant>()
.build_pipeline_layout();
auto accum_spec_builder = SpecializationInfoBuilder();
const uint32_t wg_rounded_irr_size_x = percentile_group_count_x * PERCENTILE_LOCAL_SIZE_X;
const uint32_t wg_rounded_irr_size_y = percentile_group_count_y * PERCENTILE_LOCAL_SIZE_Y;
filter_spec_builder.add_entry(FILTER_LOCAL_SIZE_X, FILTER_LOCAL_SIZE_Y, wg_rounded_irr_size_x,
wg_rounded_irr_size_y, filter_mode, extended_search,
reuse_border);
auto filter_spec = filter_spec_builder.build();
accum_spec_builder.add_entry(FILTER_LOCAL_SIZE_X, FILTER_LOCAL_SIZE_Y, wg_rounded_irr_size_x,
wg_rounded_irr_size_y, filter_mode, extended_search, reuse_border,
enable_mv && io_layout.is_connected(con_mv));
const auto accum_spec = accum_spec_builder.build();
accumulate =
std::make_shared<ComputePipeline>(filter_pipe_layout, accumulate_module, filter_spec);
std::make_shared<ComputePipeline>(accum_pipe_layout, accumulate_module, accum_spec);

return {};
}
Expand Down Expand Up @@ -192,7 +192,8 @@ Accumulate::NodeStatusFlags Accumulate::properties(Properties& config) {
needs_rebuild |=
config.config_text("clear event pattern", clear_event_listener_pattern, true,
"Set the event pattern which triggers a clear. Press enter to confirm.");

needs_rebuild |=
config.config_bool("enable motion vectors", enable_mv, "uses motion vectors if connected.");
config.st_separate("Reproject");
float angle = glm::acos(accumulate_pc.normal_reject_cos);
config.config_angle("normal threshold", angle, "Reject points with normals farther apart", 0,
Expand Down
4 changes: 2 additions & 2 deletions src/merian-nodes/nodes/add/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace merian_nodes {

Add::Add(const ContextHandle context, const std::optional<vk::Format> output_format)
Add::Add(const ContextHandle& context, const std::optional<vk::Format>& output_format)
: AbstractCompute(context), output_format(output_format) {
shader =
std::make_shared<ShaderModule>(context, merian_add_comp_spv_size(), merian_add_comp_spv());
Expand Down Expand Up @@ -46,7 +46,7 @@ std::vector<OutputConnectorHandle> Add::describe_outputs(const NodeIOLayout& io_
format = io_layout[input]->create_info.format;
extent = min(extent, io_layout[input]->create_info.extent);
}
spec_builder.add_entry<VkBool32>(io_layout.is_connected(input));
spec_builder.add_entry(io_layout.is_connected(input));
}

if (!at_least_one_input_connected) {
Expand Down
64 changes: 29 additions & 35 deletions src/merian-nodes/nodes/svgf/svgf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ uint32_t get_ve_local_size(const ContextHandle& context) {
if (32 * 32 * VE_SHARED_MEMORY_PER_PIXEL <=
context->physical_device.get_physical_device_limits().maxComputeSharedMemorySize) {
return 32;
} else if (16 * 16 * VE_SHARED_MEMORY_PER_PIXEL <=
context->physical_device.get_physical_device_limits().maxComputeSharedMemorySize) {
}
if (16 * 16 * VE_SHARED_MEMORY_PER_PIXEL <=
context->physical_device.get_physical_device_limits().maxComputeSharedMemorySize) {
return 16;
} else {
throw std::runtime_error{"SVGF: Not enough shared memory for spatial variance estimate."};
}
throw std::runtime_error{"SVGF: Not enough shared memory for spatial variance estimate."};
}

SVGF::SVGF(const ContextHandle context,
const ResourceAllocatorHandle allocator,
SVGF::SVGF(const ContextHandle& context,
const ResourceAllocatorHandle& allocator,
const std::optional<vk::Format> output_format)
: Node(), context(context), allocator(allocator), output_format(output_format),
variance_estimate_local_size_x(get_ve_local_size(context)),
Expand Down Expand Up @@ -135,7 +135,8 @@ SVGF::NodeStatusFlags SVGF::on_connected([[maybe_unused]] const NodeIOLayout& io
{
auto spec_builder = SpecializationInfoBuilder();
spec_builder.add_entry(local_size_x, local_size_y, taa_debug, taa_filter_prev,
taa_clamping, taa_mv_sampling);
taa_clamping, taa_mv_sampling,
enable_mv && io_layout.is_connected(con_mv));
SpecializationInfoHandle taa_spec = spec_builder.build();
taa = std::make_shared<ComputePipeline>(taa_pipe_layout, taa_module, taa_spec);
}
Expand Down Expand Up @@ -244,10 +245,8 @@ SVGF::NodeStatusFlags SVGF::properties(Properties& config) {
config.config_float("depth accept", variance_estimate_pc.depth_accept, "More means more reuse");

config.st_separate("Filter");
const int old_svgf_iterations = svgf_iterations;
config.config_int("SVGF iterations", svgf_iterations, 0, 10,
"0 disables SVGF completely (TAA-only mode)");
needs_rebuild |= old_svgf_iterations != svgf_iterations;
needs_rebuild |= config.config_int("SVGF iterations", svgf_iterations, 0, 10,
"0 disables SVGF completely (TAA-only mode)");
config.config_float("filter depth", filter_pc.param_z, "more means more blur");
angle = glm::acos(filter_pc.param_n);
config.config_angle("filter normals", angle, "Reject with normals farther apart", 0, 180);
Expand All @@ -257,10 +256,9 @@ SVGF::NodeStatusFlags SVGF::properties(Properties& config) {
"z-dependent rejection: increase to reject more. Disable with <= 0.");
config.config_float("z-bias depth", filter_pc.z_bias_depth,
"z-dependent rejection: increase to reject more. Disable with <= 0.");
int old_filter_type = filter_type;
config.config_options("filter type", filter_type, {"atrous", "box", "subsampled"},
Properties::OptionsStyle::COMBO);
needs_rebuild |= old_filter_type != filter_type;
needs_rebuild |=
config.config_options("filter type", filter_type, {"atrous", "box", "subsampled"},
Properties::OptionsStyle::COMBO);
needs_rebuild |= config.config_bool("filter variance", filter_variance,
"Filter variance with a 3x3 gaussian");

Expand All @@ -269,34 +267,30 @@ SVGF::NodeStatusFlags SVGF::properties(Properties& config) {
"TAA alpha", taa_pc.blend_alpha, 0, 1,
"Blend factor for the final image and the previous image. More means more reuse.");

const int old_taa_debug = taa_debug;
const int old_taa_filter_prev = taa_filter_prev;
const int old_taa_clamping = taa_clamping;
const int old_taa_mv_sampling = taa_mv_sampling;
config.config_options("mv sampling", taa_mv_sampling, {"center", "magnitude dilation"},
Properties::OptionsStyle::COMBO);
config.config_options("filter", taa_filter_prev, {"none", "catmull rom"},
Properties::OptionsStyle::COMBO);
config.config_options("clamping", taa_clamping, {"min-max", "moments"},
Properties::OptionsStyle::COMBO);
needs_rebuild |=
config.config_bool("enable motion vectors", enable_mv, "uses motion vectors if connected.");
if (enable_mv) {
needs_rebuild |=
config.config_options("mv sampling", taa_mv_sampling, {"center", "magnitude dilation"},
Properties::OptionsStyle::COMBO);
}
needs_rebuild |= config.config_options("filter", taa_filter_prev, {"none", "catmull rom"},
Properties::OptionsStyle::COMBO);
needs_rebuild |= config.config_options("clamping", taa_clamping, {"min-max", "moments"},
Properties::OptionsStyle::COMBO);
if (taa_clamping == 1)
config.config_float(
"TAA rejection threshold", taa_pc.rejection_threshold,
"TAA rejection threshold for the previous frame, in units of standard deviation", 0.01);
config.config_options("debug", taa_debug,
{"none", "irradiance", "variance", "normal", "depth", "albedo", "grad z",
"irradiance nan/inf", "mv"});

needs_rebuild |= old_taa_debug != taa_debug;
needs_rebuild |= old_taa_filter_prev != taa_filter_prev;
needs_rebuild |= old_taa_clamping != taa_clamping;
needs_rebuild |= old_taa_mv_sampling != taa_mv_sampling;
needs_rebuild |= config.config_options("debug", taa_debug,
{"none", "irradiance", "variance", "normal", "depth",
"albedo", "grad z", "irradiance nan/inf", "mv"});

if (needs_rebuild) {
return NEEDS_RECONNECT;
} else {
return {};
}

return {};
}

} // namespace merian_nodes
15 changes: 10 additions & 5 deletions src/merian-nodes/nodes/svgf/svgf_taa.comp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ layout (constant_id = 2) const int DEBUG = 0;
layout (constant_id = 3) const int FILTER_PREV = 0;
layout (constant_id = 4) const int CLAMPING = 0;
layout (constant_id = 5) const int MV_SAMPLING = 0;
layout (constant_id = 6) const bool USE_MOTION_VECTORS = true;

layout(push_constant, std140) uniform params_t {
float blend_alpha;
Expand All @@ -25,11 +26,15 @@ main()
const ivec2 ipos = ivec2(gl_GlobalInvocationID);
if (any(greaterThanEqual(ipos, imageSize(img_out)))) return;

vec2 mv = vec2(0);
if (MV_SAMPLING == 0)
mv = texelFetch(img_mv, ipos, 0).rg;
else if (MV_SAMPLING == 1)
mv = sample_motion_vector(img_mv, ipos, 1);
vec2 mv;
if (USE_MOTION_VECTORS) {
if (MV_SAMPLING == 0)
mv = texelFetch(img_mv, ipos, 0).rg;
else if (MV_SAMPLING == 1)
mv = sample_motion_vector(img_mv, ipos, 1);
} else {
mv = vec2(0);
}

const vec4 filter_result = texelFetch(img_filter_result, ipos, 0);

Expand Down
9 changes: 8 additions & 1 deletion src/merian-nodes/nodes/taa/taa.comp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ layout(push_constant) uniform PushConstant {
// higher value means more temporal reuse
float temporal_alpha;
int clamp_method;
int enable_mv;
} params;

vec4 merge_frames(const vec4 current_color,
Expand Down Expand Up @@ -70,7 +71,13 @@ void main() {
if (any(greaterThanEqual(pixel, resolution)))
return;

vec2 motion_vector = sample_motion_vector(img_mv, pixel, 1);
vec2 motion_vector;
if (params.enable_mv != 0) {
motion_vector = sample_motion_vector(img_mv, pixel, 1);
} else {
motion_vector = vec2(0);
}

if (INVERSE_MOTION > 0) {
motion_vector *= -1;
}
Expand Down
5 changes: 3 additions & 2 deletions src/merian-nodes/nodes/taa/taa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ TAA::TAA(const ContextHandle context) : AbstractCompute(context, sizeof(PushCons
std::vector<InputConnectorHandle> TAA::describe_inputs() {
return {
con_src,
ManagedVkImageIn::compute_read("feedback", 1),
ManagedVkImageIn::compute_read("mv"),
ManagedVkImageIn::compute_read("prev_src", 1),
con_mv,
};
}

Expand All @@ -30,6 +30,7 @@ TAA::describe_outputs([[maybe_unused]] const NodeIOLayout& io_layout) {
width = io_layout[con_src]->create_info.extent.width;
height = io_layout[con_src]->create_info.extent.height;

pc.enable_mv = static_cast<VkBool32>(io_layout.is_connected(con_mv));
return {
ManagedVkImageOut::compute_write("out", io_layout[con_src]->create_info.format, width,
height),
Expand Down

0 comments on commit 080077f

Please sign in to comment.