diff --git a/src/correlation/gpu/shaders/cross_check_filter.comp.glsl b/src/correlation/gpu/shaders/cross_check_filter.comp.glsl index 0020f43..d41f716 100644 --- a/src/correlation/gpu/shaders/cross_check_filter.comp.glsl +++ b/src/correlation/gpu/shaders/cross_check_filter.comp.glsl @@ -22,11 +22,11 @@ layout(std430, push_constant) uniform readonly Parameters float extend_range; float min_range; }; -layout(std430, set = 1, binding = 0) buffer Img1 +layout(std430, set = 0, binding = 0) buffer Img1 { ivec2 img1[]; }; -layout(std430, set = 1, binding = 1) buffer readonly Img2 +layout(std430, set = 0, binding = 1) buffer readonly Img2 { ivec2 img2[]; }; diff --git a/src/correlation/gpu/shaders/cross_check_filter.spv b/src/correlation/gpu/shaders/cross_check_filter.spv index 7f3670c..a61edb1 100644 Binary files a/src/correlation/gpu/shaders/cross_check_filter.spv and b/src/correlation/gpu/shaders/cross_check_filter.spv differ diff --git a/src/correlation/gpu/vulkan.rs b/src/correlation/gpu/vulkan.rs index a699087..a738542 100644 --- a/src/correlation/gpu/vulkan.rs +++ b/src/correlation/gpu/vulkan.rs @@ -30,6 +30,7 @@ pub struct Device { device: ash::Device, memory_properties: vk::PhysicalDeviceMemoryProperties, buffers: Option, + direction: CorrelationDirection, max_buffer_size: usize, descriptor_sets: DescriptorSets, pipelines: HashMap, @@ -229,6 +230,7 @@ impl Device { instance.destroy_instance(None); err }; + let direction = CorrelationDirection::Forward; // Init control struct - queues, fences, command buffer. let control = unsafe { Device::create_control(&device, compute_queue_index).map_err(cleanup_err)? }; @@ -238,6 +240,7 @@ impl Device { device, memory_properties, buffers: None, + direction, max_buffer_size, descriptor_sets, pipelines, @@ -709,7 +712,7 @@ impl Device { .ty(vk::DescriptorType::STORAGE_BUFFER) .descriptor_count(6)]; let descriptor_pool_info = vk::DescriptorPoolCreateInfo::default() - .max_sets(2) + .max_sets(1) .pool_sizes(&descriptor_pool_size); let descriptor_pool = device.create_descriptor_pool(&descriptor_pool_info, None)?; let cleanup_err = |err| { @@ -836,6 +839,67 @@ impl Device { Ok(result) } + fn set_buffer_layout(&mut self, shader: &ShaderModuleType) -> Result<(), GpuError> { + let direction = self.direction; + let descriptor_sets = &self.descriptor_sets; + let buffers = &self.buffers()?; + let create_buffer_infos = |buffers: &[Buffer]| { + buffers + .iter() + .map(|buf| { + vk::DescriptorBufferInfo::default() + .buffer(buf.buffer) + .offset(0) + .range(vk::WHOLE_SIZE) + }) + .collect::>() + }; + let (buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) = + match direction { + CorrelationDirection::Forward => ( + buffers.buffer_internal_img1, + buffers.buffer_internal_img2, + buffers.buffer_out, + buffers.buffer_out_reverse, + ), + CorrelationDirection::Reverse => ( + buffers.buffer_internal_img2, + buffers.buffer_internal_img1, + buffers.buffer_out_reverse, + buffers.buffer_out, + ), + }; + let (buffer_list, descriptor_set) = if matches!(shader, ShaderModuleType::CrossCheckFilter) + { + ( + vec![buffer_out, buffer_out_reverse], + descriptor_sets.descriptor_sets[0], + ) + } else { + ( + vec![ + buffers.buffer_img, + buffer_internal_img1, + buffer_internal_img2, + buffers.buffer_internal_int, + buffer_out, + buffers.buffer_out_corr, + ], + descriptor_sets.descriptor_sets[1], + ) + }; + let buffer_infos = create_buffer_infos(buffer_list.as_slice()); + let write_descriptor = vk::WriteDescriptorSet::default() + .dst_set(descriptor_set) + .dst_binding(0) + .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) + .buffer_info(buffer_infos.as_slice()); + unsafe { + self.device.update_descriptor_sets(&[write_descriptor], &[]); + } + Ok(()) + } + unsafe fn create_control( device: &ash::Device, queue_family_index: u32, @@ -876,57 +940,7 @@ impl Device { impl super::Device for Device { fn set_buffer_direction(&mut self, direction: &CorrelationDirection) -> Result<(), GpuError> { - let descriptor_sets = &self.descriptor_sets; - let buffers = &self.buffers()?; - let create_buffer_infos = |buffers: &[Buffer]| { - buffers - .iter() - .map(|buf| { - vk::DescriptorBufferInfo::default() - .buffer(buf.buffer) - .offset(0) - .range(vk::WHOLE_SIZE) - }) - .collect::>() - }; - let (buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) = - match direction { - CorrelationDirection::Forward => ( - buffers.buffer_internal_img1, - buffers.buffer_internal_img2, - buffers.buffer_out, - buffers.buffer_out_reverse, - ), - CorrelationDirection::Reverse => ( - buffers.buffer_internal_img2, - buffers.buffer_internal_img1, - buffers.buffer_out_reverse, - buffers.buffer_out, - ), - }; - let regular_buffer_infos = create_buffer_infos(&[ - buffers.buffer_img, - buffer_internal_img1, - buffer_internal_img2, - buffers.buffer_internal_int, - buffer_out, - buffers.buffer_out_corr, - ]); - let regular_write_descriptor = vk::WriteDescriptorSet::default() - .dst_set(descriptor_sets.descriptor_sets[0]) - .dst_binding(0) - .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) - .buffer_info(regular_buffer_infos.as_slice()); - let cross_check_buffer_infos = create_buffer_infos(&[buffer_out, buffer_out_reverse]); - let cross_check_write_descriptor = vk::WriteDescriptorSet::default() - .dst_set(descriptor_sets.descriptor_sets[1]) - .dst_binding(0) - .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) - .buffer_info(cross_check_buffer_infos.as_slice()); - let write_descriptors = [regular_write_descriptor, cross_check_write_descriptor]; - unsafe { - self.device.update_descriptor_sets(&write_descriptors, &[]); - } + self.direction = direction.to_owned(); Ok(()) } @@ -952,6 +966,7 @@ impl super::Device for Device { vk::PipelineBindPoint::COMPUTE, pipeline_config.pipeline, ); + self.set_buffer_layout(&shader_type)?; // It's way easier to map all descriptor sets identically, instead of ensuring that every // kernel gets to use set = 0. // The cross correlation kernel will need to switch to descriptor set = 1.