Skip to content

Commit

Permalink
Implemented swithing correlation direction.
Browse files Browse the repository at this point in the history
  • Loading branch information
zlogic committed Jan 28, 2024
1 parent cd81313 commit 77357b3
Showing 1 changed file with 24 additions and 30 deletions.
54 changes: 24 additions & 30 deletions src/correlation/vk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ struct Buffer {
buffer_memory: vk::DeviceMemory,
host_visible: bool,
host_coherent: bool,
size: u64,
}

enum BufferType {
Expand All @@ -108,7 +107,6 @@ struct DeviceBuffers {
}

struct DescriptorSets {
direction: CorrelationDirection,
descriptor_pool: vk::DescriptorPool,
regular_layout: vk::DescriptorSetLayout,
cross_check_layout: vk::DescriptorSetLayout,
Expand Down Expand Up @@ -163,6 +161,7 @@ impl GpuContext {

let start_time = SystemTime::now();
let device = Device::new(img1_pixels, img2_pixels)?;
device.set_buffer_direction(&CorrelationDirection::Forward);

if let Ok(t) = start_time.elapsed() {
println!("Initialized device in {:.3} seconds", t.as_secs_f32());
Expand Down Expand Up @@ -198,7 +197,7 @@ impl GpuContext {
scale: f32,
dir: CorrelationDirection,
) -> Result<(), Box<dyn error::Error>> {
// TODO: recreate descriptor sets if direction changes
self.device.set_buffer_direction(&dir);
let (out_dimensions, out_dimensions_reverse) = match dir {
CorrelationDirection::Forward => (self.img1_dimensions, self.img2_dimensions),
CorrelationDirection::Reverse => (self.img2_dimensions, self.img1_dimensions),
Expand Down Expand Up @@ -254,11 +253,7 @@ impl GpuContext {
progress_listener: Option<&PL>,
dir: CorrelationDirection,
) -> Result<(), Box<dyn error::Error>> {
if matches!(dir, CorrelationDirection::Reverse) {
// TODO: fix this
return Ok(());
}
// TODO: recreate descriptor sets if direction changes
self.device.set_buffer_direction(&dir);
let max_width = img1.width().max(img2.width());
let max_height = img1.height().max(img2.height());
let max_dimensions = (max_width, max_height);
Expand Down Expand Up @@ -300,7 +295,6 @@ impl GpuContext {
min_range: CORRIDOR_MIN_RANGE as f32,
};

println!("correlate images");
unsafe { self.device.transfer_in_images(img1, img2)? };

if first_pass {
Expand Down Expand Up @@ -481,10 +475,8 @@ impl Device {
err
};
// Init pipelines and shaders.
let descriptor_sets = unsafe {
Device::create_descriptor_sets(&device, &buffers, CorrelationDirection::Forward)
.map_err(cleanup_err)?
};
let descriptor_sets =
unsafe { Device::create_descriptor_sets(&device).map_err(cleanup_err)? };
let cleanup_err = |err| unsafe {
descriptor_sets.destroy(&device);
buffers.destroy(&device);
Expand Down Expand Up @@ -786,7 +778,7 @@ impl Device {
out_image: &mut Grid<Option<super::Match>>,
correlation_values: &Grid<Option<f32>>,
) -> Result<(), Box<dyn error::Error>> {
// TODO: combine this with save_result
// TODO: combine this with save_corr
let size = out_image.width() * out_image.height() * 2;
let size_bytes = size * std::mem::size_of::<u32>();
let width = out_image.width();
Expand Down Expand Up @@ -953,7 +945,8 @@ impl Device {
}
return a.0.cmp(&b.0);
});
let (device, name, queue_index) = if let Some((device, name, queue_index, score)) = device {
let (device, name, queue_index) = if let Some((device, name, queue_index, _score)) = device
{
(device, name, queue_index)
} else {
return Err(GpuError::new("Device not found").into());
Expand Down Expand Up @@ -1164,16 +1157,13 @@ impl Device {
buffer_memory,
host_visible,
host_coherent,
size,
};
device.bind_buffer_memory(buffer, buffer_memory, 0)?;
Ok(result)
}

unsafe fn create_descriptor_sets(
device: &ash::Device,
buffers: &DeviceBuffers,
direction: CorrelationDirection,
) -> Result<DescriptorSets, Box<dyn error::Error>> {
let create_layout_bindings = |count| {
let bindings = (0..count)
Expand Down Expand Up @@ -1243,7 +1233,18 @@ impl Device {
.allocate_descriptor_sets(&descriptor_set_allocate_info)
.map_err(cleanup_err)?;

// TODO: extract this to allow switching direction on the fly.
Ok(DescriptorSets {
descriptor_pool,
regular_layout,
cross_check_layout,
pipeline_layout,
descriptor_sets,
})
}

fn set_buffer_direction(&self, direction: &CorrelationDirection) {
let descriptor_sets = &self.descriptor_sets;
let buffers = &self.buffers;
let create_buffer_infos = |buffers: &[Buffer]| {
buffers
.iter()
Expand All @@ -1258,7 +1259,7 @@ impl Device {
};
let create_write_descriptor = |i: usize, buffer_infos: &[vk::DescriptorBufferInfo]| {
vk::WriteDescriptorSet::builder()
.dst_set(descriptor_sets[i])
.dst_set(descriptor_sets.descriptor_sets[i])
.dst_binding(0)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
.buffer_info(buffer_infos)
Expand All @@ -1281,16 +1282,9 @@ impl Device {
create_write_descriptor(0, regular_buffer_infos.as_slice()),
create_write_descriptor(1, cross_check_buffer_infos.as_slice()),
];
device.update_descriptor_sets(&write_descriptors, &[]);

Ok(DescriptorSets {
direction: direction.to_owned(),
descriptor_pool,
regular_layout,
cross_check_layout,
pipeline_layout,
descriptor_sets,
})
unsafe {
self.device.update_descriptor_sets(&write_descriptors, &[]);
}
}

unsafe fn load_shaders(
Expand Down

0 comments on commit 77357b3

Please sign in to comment.