Skip to content

Commit

Permalink
Move normalize_image to rten-imageproc and make it customizable
Browse files Browse the repository at this point in the history
Several examples need to normalize an image with a mean and standard deviation
different from the imagenet values. Add a function in rten-imageproc with this
flexibility. The function has been added to the rten-imageproc crate because
it doesn't have any dependencies on specific image formats.

In the process the implementation was also revised to make it more efficient.
  • Loading branch information
robertknight committed Sep 5, 2024
1 parent 87f0a38 commit 33b845d
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 44 deletions.
5 changes: 3 additions & 2 deletions rten-examples/src/deeplab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::{HashSet, VecDeque};
use std::error::Error;

use rten::{Dimension, FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

Expand Down Expand Up @@ -109,7 +110,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let model = Model::load_file(args.model)?;

let mut image: Tensor = read_image(&args.image)?.into();
normalize_image(image.nd_view_mut());
normalize_image(image.nd_view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);
image.insert_axis(0); // Add batch dim

// Resize image according to metadata in the model.
Expand Down
5 changes: 3 additions & 2 deletions rten-examples/src/depth_anything.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::VecDeque;
use std::error::Error;

use rten::{FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

Expand Down Expand Up @@ -78,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {

let mut image: Tensor = read_image(&args.image)?.into();
let [_, orig_height, orig_width] = image.shape().try_into()?;
normalize_image(image.nd_view_mut());
normalize_image(image.nd_view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);
image.insert_axis(0); // Add batch dim

// Input size taken from README in https://github.com/fabio-sim/Depth-Anything-ONNX.
Expand Down
6 changes: 3 additions & 3 deletions rten-examples/src/detr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::VecDeque;
use std::error::Error;

use rten::{FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageproc::{Painter, Rect};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, Painter, Rect, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

Expand Down Expand Up @@ -285,7 +285,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// Save a copy of the input before normalization and scaling
let mut annotated_image = args.annotated_image.as_ref().map(|_| image.clone());

normalize_image(image.view_mut());
normalize_image(image.view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);

let [_, image_height, image_width] = image.shape();

Expand Down
26 changes: 7 additions & 19 deletions rten-examples/src/trocr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use std::io::prelude::*;
use rten::{FloatOperators, Model};
use rten_generate::{Generator, GeneratorUtils};
use rten_imageio::read_image;
use rten_imageproc::normalize_image;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorViewMut};
use rten_tensor::NdTensor;
use rten_text::tokenizers::Tokenizer;

struct Args {
Expand Down Expand Up @@ -62,23 +63,6 @@ Args:
Ok(args)
}

fn normalize_pixel(value: f32, channel: usize) -> f32 {
assert!(channel < 3, "channel index is invalid");

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
let std_dev = [0.5, 0.5, 0.5];

(value - mean[channel]) / std_dev[channel]
}

fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
for chan in 0..img.size(0) {
img.slice_mut::<2, _>(chan)
.apply(|x| normalize_pixel(*x, chan));
}
}

/// Recognize text line images using TrOCR [^1].
///
/// First use Hugging Face's Optimum tool to download and export the models to
Expand Down Expand Up @@ -114,7 +98,11 @@ fn main() -> Result<(), Box<dyn Error>> {

// From `image_size` in config.json.
let mut image = image.resize_image([384, 384])?;
normalize_image(image.slice_mut(0));

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
let std_dev = [0.5, 0.5, 0.5];
normalize_image(image.slice_mut(0), mean, std_dev);

let encoded_image: NdTensor<f32, 3> = encoder_model
.run_one(image.view().into(), None)?
Expand Down
19 changes: 1 addition & 18 deletions rten-imageio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,11 @@
//! implementation.

use std::error::Error;
use std::iter::zip;
use std::path::Path;

use rten_tensor::errors::FromDataError;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};

/// Apply standard ImageNet normalization to a pixel value.
/// See <https://huggingface.co/facebook/detr-resnet-50#preprocessing>.
pub fn normalize_pixel(value: f32, channel: usize) -> f32 {
assert!(channel < 3, "channel index is invalid");
let imagenet_mean = [0.485, 0.456, 0.406];
let imagenet_std_dev = [0.229, 0.224, 0.225];
(value - imagenet_mean[channel]) / imagenet_std_dev[channel]
}

/// Apply standard ImageNet normalization to all pixel values in an image.
pub fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
for ([chan, _y, _x], pixel) in zip(img.indices(), img.iter_mut()) {
*pixel = normalize_pixel(*pixel, chan);
}
}
use rten_tensor::{NdTensor, NdTensorView};

/// Errors reported when creating a tensor from an image.
#[derive(Debug)]
Expand Down
2 changes: 2 additions & 0 deletions rten-imageproc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
mod contours;
mod drawing;
mod math;
mod normalize;
mod poly_algos;
mod shapes;

pub use contours::{find_contours, RetrievalMode};
pub use drawing::{draw_line, draw_polygon, fill_rect, stroke_rect, FillIter, Painter, Rgb};
pub use math::Vec2;
pub use normalize::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
pub use poly_algos::{convex_hull, min_area_rect, simplify_polygon, simplify_polyline};
pub use shapes::{
bounding_rect, BoundingRect, Coord, Line, LineF, Point, PointF, Polygon, PolygonF, Polygons,
Expand Down
38 changes: 38 additions & 0 deletions rten-imageproc/src/normalize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use rten_tensor::prelude::*;
use rten_tensor::NdTensorViewMut;

/// Standard ImageNet normalization mean values, for use with
/// [`normalize_image`].
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// Standard ImageNet normalization standard deviation values, for use with
/// [`normalize_image`].
pub const IMAGENET_STD_DEV: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalize the mean and standard deviation of all pixels in an image.
///
/// `img` should be a CHW tensor with `C` channels. For each channel `c`, the
/// output pixel values are computed as `y = (x - mean[c]) / std_dev[c]`.
///
/// This is a common preprocessing step for inputs to machine learning models.
/// Many models use standard "ImageNet" constants ([`IMAGENET_MEAN`],
/// [`IMAGENET_STD_DEV`]), but check the expected values for the model you are
/// using.
pub fn normalize_image<const C: usize>(
mut img: NdTensorViewMut<f32, 3>,
mean: [f32; C],
std_dev: [f32; C],
) {
let n_chans = img.size(0);
assert_eq!(
n_chans, C,
"expected image to have {} channels but found {}",
C, n_chans
);

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
img.slice_mut::<2, _>(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}

0 comments on commit 33b845d

Please sign in to comment.