diff --git a/src/correlation/gpu/shaders/cross_check_filter.comp.glsl b/src/correlation/gpu/shaders/cross_check_filter.comp.glsl index 0020f43..d41f716 100644 --- a/src/correlation/gpu/shaders/cross_check_filter.comp.glsl +++ b/src/correlation/gpu/shaders/cross_check_filter.comp.glsl @@ -22,11 +22,11 @@ layout(std430, push_constant) uniform readonly Parameters float extend_range; float min_range; }; -layout(std430, set = 1, binding = 0) buffer Img1 +layout(std430, set = 0, binding = 0) buffer Img1 { ivec2 img1[]; }; -layout(std430, set = 1, binding = 1) buffer readonly Img2 +layout(std430, set = 0, binding = 1) buffer readonly Img2 { ivec2 img2[]; }; diff --git a/src/correlation/gpu/shaders/cross_check_filter.spv b/src/correlation/gpu/shaders/cross_check_filter.spv index 7f3670c..a61edb1 100644 Binary files a/src/correlation/gpu/shaders/cross_check_filter.spv and b/src/correlation/gpu/shaders/cross_check_filter.spv differ diff --git a/src/correlation/gpu/vulkan.rs b/src/correlation/gpu/vulkan.rs index a699087..3d74877 100644 --- a/src/correlation/gpu/vulkan.rs +++ b/src/correlation/gpu/vulkan.rs @@ -30,6 +30,7 @@ pub struct Device { device: ash::Device, memory_properties: vk::PhysicalDeviceMemoryProperties, buffers: Option, + direction: CorrelationDirection, max_buffer_size: usize, descriptor_sets: DescriptorSets, pipelines: HashMap, @@ -64,8 +65,7 @@ struct DeviceBuffers { struct DescriptorSets { descriptor_pool: vk::DescriptorPool, - regular_layout: vk::DescriptorSetLayout, - cross_check_layout: vk::DescriptorSetLayout, + layout: vk::DescriptorSetLayout, pipeline_layout: vk::PipelineLayout, descriptor_sets: Vec, } @@ -229,6 +229,7 @@ impl Device { instance.destroy_instance(None); err }; + let direction = CorrelationDirection::Forward; // Init control struct - queues, fences, command buffer. let control = unsafe { Device::create_control(&device, compute_queue_index).map_err(cleanup_err)? }; @@ -238,6 +239,7 @@ impl Device { device, memory_properties, buffers: None, + direction, max_buffer_size, descriptor_sets, pipelines, @@ -709,27 +711,20 @@ impl Device { .ty(vk::DescriptorType::STORAGE_BUFFER) .descriptor_count(6)]; let descriptor_pool_info = vk::DescriptorPoolCreateInfo::default() - .max_sets(2) + .max_sets(1) .pool_sizes(&descriptor_pool_size); let descriptor_pool = device.create_descriptor_pool(&descriptor_pool_info, None)?; let cleanup_err = |err| { device.destroy_descriptor_pool(descriptor_pool, None); err }; - let regular_layout = create_layout_bindings(6).map_err(cleanup_err)?; + let layout = create_layout_bindings(6).map_err(cleanup_err)?; let cleanup_err = |err| { - device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_set_layout(layout, None); device.destroy_descriptor_pool(descriptor_pool, None); err }; - let cross_check_layout = create_layout_bindings(2).map_err(cleanup_err)?; - let cleanup_err = |err| { - device.destroy_descriptor_set_layout(cross_check_layout, None); - device.destroy_descriptor_set_layout(regular_layout, None); - device.destroy_descriptor_pool(descriptor_pool, None); - err - }; - let layouts = [regular_layout, cross_check_layout]; + let layouts = [layout]; let push_constant_ranges = vk::PushConstantRange::default() .offset(0) .size(std::mem::size_of::() as u32) @@ -744,8 +739,7 @@ impl Device { .map_err(cleanup_err)?; let cleanup_err = |err| { device.destroy_pipeline_layout(pipeline_layout, None); - device.destroy_descriptor_set_layout(cross_check_layout, None); - device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_set_layout(layout, None); device.destroy_descriptor_pool(descriptor_pool, None); err }; @@ -758,8 +752,7 @@ impl Device { Ok(DescriptorSets { descriptor_pool, - regular_layout, - cross_check_layout, + layout, pipeline_layout, descriptor_sets, }) @@ -836,6 +829,57 @@ impl Device { Ok(result) } + fn set_buffer_layout(&mut self, shader: &ShaderModuleType) -> Result<(), GpuError> { + let direction = self.direction; + let descriptor_sets = &self.descriptor_sets; + let buffers = &self.buffers()?; + let (buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) = + match direction { + CorrelationDirection::Forward => ( + buffers.buffer_internal_img1, + buffers.buffer_internal_img2, + buffers.buffer_out, + buffers.buffer_out_reverse, + ), + CorrelationDirection::Reverse => ( + buffers.buffer_internal_img2, + buffers.buffer_internal_img1, + buffers.buffer_out_reverse, + buffers.buffer_out, + ), + }; + let buffer_list = if matches!(shader, ShaderModuleType::CrossCheckFilter) { + vec![buffer_out, buffer_out_reverse] + } else { + vec![ + buffers.buffer_img, + buffer_internal_img1, + buffer_internal_img2, + buffers.buffer_internal_int, + buffer_out, + buffers.buffer_out_corr, + ] + }; + let buffer_infos = buffer_list + .iter() + .map(|buf| { + vk::DescriptorBufferInfo::default() + .buffer(buf.buffer) + .offset(0) + .range(vk::WHOLE_SIZE) + }) + .collect::>(); + let write_descriptor = vk::WriteDescriptorSet::default() + .dst_set(descriptor_sets.descriptor_sets[0]) + .dst_binding(0) + .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) + .buffer_info(buffer_infos.as_slice()); + unsafe { + self.device.update_descriptor_sets(&[write_descriptor], &[]); + } + Ok(()) + } + unsafe fn create_control( device: &ash::Device, queue_family_index: u32, @@ -876,57 +920,7 @@ impl Device { impl super::Device for Device { fn set_buffer_direction(&mut self, direction: &CorrelationDirection) -> Result<(), GpuError> { - let descriptor_sets = &self.descriptor_sets; - let buffers = &self.buffers()?; - let create_buffer_infos = |buffers: &[Buffer]| { - buffers - .iter() - .map(|buf| { - vk::DescriptorBufferInfo::default() - .buffer(buf.buffer) - .offset(0) - .range(vk::WHOLE_SIZE) - }) - .collect::>() - }; - let (buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) = - match direction { - CorrelationDirection::Forward => ( - buffers.buffer_internal_img1, - buffers.buffer_internal_img2, - buffers.buffer_out, - buffers.buffer_out_reverse, - ), - CorrelationDirection::Reverse => ( - buffers.buffer_internal_img2, - buffers.buffer_internal_img1, - buffers.buffer_out_reverse, - buffers.buffer_out, - ), - }; - let regular_buffer_infos = create_buffer_infos(&[ - buffers.buffer_img, - buffer_internal_img1, - buffer_internal_img2, - buffers.buffer_internal_int, - buffer_out, - buffers.buffer_out_corr, - ]); - let regular_write_descriptor = vk::WriteDescriptorSet::default() - .dst_set(descriptor_sets.descriptor_sets[0]) - .dst_binding(0) - .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) - .buffer_info(regular_buffer_infos.as_slice()); - let cross_check_buffer_infos = create_buffer_infos(&[buffer_out, buffer_out_reverse]); - let cross_check_write_descriptor = vk::WriteDescriptorSet::default() - .dst_set(descriptor_sets.descriptor_sets[1]) - .dst_binding(0) - .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) - .buffer_info(cross_check_buffer_infos.as_slice()); - let write_descriptors = [regular_write_descriptor, cross_check_write_descriptor]; - unsafe { - self.device.update_descriptor_sets(&write_descriptors, &[]); - } + self.direction = direction.to_owned(); Ok(()) } @@ -952,6 +946,7 @@ impl super::Device for Device { vk::PipelineBindPoint::COMPUTE, pipeline_config.pipeline, ); + self.set_buffer_layout(&shader_type)?; // It's way easier to map all descriptor sets identically, instead of ensuring that every // kernel gets to use set = 0. // The cross correlation kernel will need to switch to descriptor set = 1. @@ -1107,8 +1102,7 @@ impl DescriptorSets { unsafe fn destroy(&self, device: &ash::Device) { let _ = device.free_descriptor_sets(self.descriptor_pool, self.descriptor_sets.as_slice()); device.destroy_pipeline_layout(self.pipeline_layout, None); - device.destroy_descriptor_set_layout(self.cross_check_layout, None); - device.destroy_descriptor_set_layout(self.regular_layout, None); + device.destroy_descriptor_set_layout(self.layout, None); device.destroy_descriptor_pool(self.descriptor_pool, None); } }