Skip to content

Commit

Permalink
Run avg/stdev correlation in parallel.
Browse files Browse the repository at this point in the history
Inlined some of the parallel iterator functions.
  • Loading branch information
zlogic committed Jan 21, 2024
1 parent dc5c349 commit fa33e32
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 32 deletions.
62 changes: 34 additions & 28 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,11 +616,6 @@ impl PointCorrelations {
}
}

struct PointDataCompact {
avg: f32,
stdev: f32,
}

struct ImagePointData {
avg: Grid<f32>,
stdev: Grid<f32>,
Expand All @@ -631,52 +626,63 @@ fn compute_image_point_data(img: &Grid<u8>) -> ImagePointData {
avg: Grid::new(img.width(), img.height(), f32::NAN),
stdev: Grid::new(img.width(), img.height(), f32::NAN),
};
data.avg
.iter_mut()
.zip(data.stdev.iter_mut())
.for_each(|((x, y, avg), (_x, _y, stdev))| {
let point = Point2D::new(x, y);
let p = match compute_compact_point_data(img, &point) {
Some(p) => p,
None => return,
};
*avg = p.avg;
*stdev = p.stdev;
});
data.avg.par_iter_mut().for_each(|(x, y, avg)| {
let point = Point2D::new(x, y);
let point_avg = match compute_point_avg(img, &point) {
Some(p) => p,
None => return,
};
*avg = point_avg;
});
data.stdev.par_iter_mut().for_each(|(x, y, stdev)| {
let point = Point2D::new(x, y);
let avg = data.avg.val(x, y);
let point_stdev = match compute_point_stdev(img, &point, *avg) {
Some(p) => p,
None => return,
};
*stdev = point_stdev;
});
data
}

#[inline]
fn compute_compact_point_data(img: &Grid<u8>, point: &Point2D<usize>) -> Option<PointDataCompact> {
fn compute_point_avg(img: &Grid<u8>, point: &Point2D<usize>) -> Option<f32> {
if !point_inside_bounds::<KERNEL_SIZE>(img, point) {
return None;
};
let mut result = PointDataCompact {
avg: 0.0,
stdev: 0.0,
};
let mut avg = 0.0f32;
for y in 0..KERNEL_WIDTH {
let s_y = (point.y + y).saturating_sub(KERNEL_SIZE);
for x in 0..KERNEL_WIDTH {
let s_x = (point.x + x).saturating_sub(KERNEL_SIZE);
let value = img.val(s_x, s_y);
result.avg += *value as f32;
avg += *value as f32;
}
}
result.avg /= KERNEL_POINT_COUNT as f32;
avg /= KERNEL_POINT_COUNT as f32;
Some(avg)
}

#[inline]
fn compute_point_stdev(img: &Grid<u8>, point: &Point2D<usize>, avg: f32) -> Option<f32> {
if !point_inside_bounds::<KERNEL_SIZE>(img, point) {
return None;
};
let mut stdev = 0.0f32;

for y in 0..KERNEL_WIDTH {
let s_y = (point.y + y).saturating_sub(KERNEL_SIZE);
for x in 0..KERNEL_WIDTH {
let s_x = (point.x + x).saturating_sub(KERNEL_SIZE);
let value = img.val(s_x, s_y);
let delta = *value as f32 - result.avg;
result.stdev += delta * delta;
let delta = *value as f32 - avg;
stdev += delta * delta;
}
}
result.stdev = (result.stdev / KERNEL_POINT_COUNT as f32).sqrt();
stdev = (stdev / KERNEL_POINT_COUNT as f32).sqrt();

Some(result)
Some(stdev)
}

#[inline]
Expand Down
44 changes: 40 additions & 4 deletions src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,25 @@ where
{
type Item = (usize, usize, &'a T);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let i = self.range.next()?;
let x = i % self.width;
let y = i / self.width;
let val = unsafe { &(*self.grid).data[i] };
Some((x, y, val))
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let hint = self.range.len();
(hint, Some(hint))
}

#[inline]
fn count(self) -> usize {
self.range.len()
}
}

unsafe impl<'a, T> Send for GridIter<'a, T> where T: Send + Sync {}
Expand All @@ -125,6 +137,7 @@ impl<'a, T> DoubleEndedIterator for GridIter<'a, T>
where
T: Send + Sync,
{
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
let i = self.range.next_back()?;
let x = i % self.width;
Expand All @@ -149,13 +162,16 @@ where

type IntoIter = GridIter<'a, T>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
self.it
}

#[inline]
fn split_at(self, index: usize) -> (Self, Self) {
let left_range = self.it.range.start..self.it.range.start + index;
let right_range = left_range.end..self.it.range.end;
let mid = (self.it.range.start + index).min(self.it.range.end);
let left_range = self.it.range.start..mid;
let right_range = mid..self.it.range.end;
let left = ParGridIter {
it: GridIter {
grid: self.it.grid,
Expand All @@ -180,6 +196,7 @@ impl<'a, T> IndexedParallelIterator for ParGridIter<'a, T>
where
T: Send + Sync,
{
#[inline]
fn len(&self) -> usize {
self.it.range.len()
}
Expand All @@ -206,6 +223,7 @@ where
bridge(self, consumer)
}

#[inline]
fn opt_len(&self) -> Option<usize> {
Some(self.it.range.len())
}
Expand All @@ -227,13 +245,25 @@ where
{
type Item = (usize, usize, &'a mut T);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
let i = self.range.next()?;
let x = i % self.width;
let y = i / self.width;
let val = unsafe { &mut (*self.grid).data[i] };
Some((x, y, val))
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let hint = self.range.len();
(hint, Some(hint))
}

#[inline]
fn count(self) -> usize {
self.range.len()
}
}

unsafe impl<'a, T> Send for GridIterMut<'a, T> where T: Send + Sync {}
Expand All @@ -244,6 +274,7 @@ impl<'a, T> DoubleEndedIterator for GridIterMut<'a, T>
where
T: Send + Sync,
{
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
let i = self.range.next_back()?;
let x = i % self.width;
Expand All @@ -268,13 +299,16 @@ where

type IntoIter = GridIterMut<'a, T>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
self.it
}

#[inline]
fn split_at(self, index: usize) -> (Self, Self) {
let left_range = self.it.range.start..self.it.range.start + index;
let right_range = left_range.end..self.it.range.end;
let mid = (self.it.range.start + index).min(self.it.range.end);
let left_range = self.it.range.start..mid;
let right_range = mid..self.it.range.end;
let left = ParGridIterMut {
it: GridIterMut {
grid: self.it.grid,
Expand All @@ -299,6 +333,7 @@ impl<'a, T> IndexedParallelIterator for ParGridIterMut<'a, T>
where
T: Sync + Send,
{
#[inline]
fn len(&self) -> usize {
self.it.range.len()
}
Expand All @@ -325,6 +360,7 @@ where
bridge(self, consumer)
}

#[inline]
fn opt_len(&self) -> Option<usize> {
Some(self.it.range.len())
}
Expand Down

0 comments on commit fa33e32

Please sign in to comment.