Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move normalize_image to rten-imageproc and make it customizable #343

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rten-examples/src/deeplab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::{HashSet, VecDeque};
use std::error::Error;

use rten::{Dimension, FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

Expand Down Expand Up @@ -109,7 +110,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let model = Model::load_file(args.model)?;

let mut image: Tensor = read_image(&args.image)?.into();
normalize_image(image.nd_view_mut());
normalize_image(image.nd_view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);
image.insert_axis(0); // Add batch dim

// Resize image according to metadata in the model.
Expand Down
5 changes: 3 additions & 2 deletions rten-examples/src/depth_anything.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::VecDeque;
use std::error::Error;

use rten::{FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

Expand Down Expand Up @@ -78,7 +79,7 @@ fn main() -> Result<(), Box<dyn Error>> {

let mut image: Tensor = read_image(&args.image)?.into();
let [_, orig_height, orig_width] = image.shape().try_into()?;
normalize_image(image.nd_view_mut());
normalize_image(image.nd_view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);
image.insert_axis(0); // Add batch dim

// Input size taken from README in https://github.com/fabio-sim/Depth-Anything-ONNX.
Expand Down
6 changes: 3 additions & 3 deletions rten-examples/src/detr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::collections::VecDeque;
use std::error::Error;

use rten::{FloatOperators, Model, Operators};
use rten_imageio::{normalize_image, read_image, write_image};
use rten_imageproc::{Painter, Rect};
use rten_imageio::{read_image, write_image};
use rten_imageproc::{normalize_image, Painter, Rect, IMAGENET_MEAN, IMAGENET_STD_DEV};
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

Expand Down Expand Up @@ -285,7 +285,7 @@ fn main() -> Result<(), Box<dyn Error>> {
// Save a copy of the input before normalization and scaling
let mut annotated_image = args.annotated_image.as_ref().map(|_| image.clone());

normalize_image(image.view_mut());
normalize_image(image.view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV);

let [_, image_height, image_width] = image.shape();

Expand Down
26 changes: 7 additions & 19 deletions rten-examples/src/trocr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use std::io::prelude::*;
use rten::{FloatOperators, Model};
use rten_generate::{Generator, GeneratorUtils};
use rten_imageio::read_image;
use rten_imageproc::normalize_image;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorViewMut};
use rten_tensor::NdTensor;
use rten_text::tokenizers::Tokenizer;

struct Args {
Expand Down Expand Up @@ -62,23 +63,6 @@ Args:
Ok(args)
}

fn normalize_pixel(value: f32, channel: usize) -> f32 {
assert!(channel < 3, "channel index is invalid");

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
let std_dev = [0.5, 0.5, 0.5];

(value - mean[channel]) / std_dev[channel]
}

fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
for chan in 0..img.size(0) {
img.slice_mut::<2, _>(chan)
.apply(|x| normalize_pixel(*x, chan));
}
}

/// Recognize text line images using TrOCR [^1].
///
/// First use Hugging Face's Optimum tool to download and export the models to
Expand Down Expand Up @@ -114,7 +98,11 @@ fn main() -> Result<(), Box<dyn Error>> {

// From `image_size` in config.json.
let mut image = image.resize_image([384, 384])?;
normalize_image(image.slice_mut(0));

// Values taken from `preprocessor_config.json`.
let mean = [0.5, 0.5, 0.5];
let std_dev = [0.5, 0.5, 0.5];
normalize_image(image.slice_mut(0), mean, std_dev);

let encoded_image: NdTensor<f32, 3> = encoder_model
.run_one(image.view().into(), None)?
Expand Down
19 changes: 1 addition & 18 deletions rten-imageio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,11 @@
//! implementation.

use std::error::Error;
use std::iter::zip;
use std::path::Path;

use rten_tensor::errors::FromDataError;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};

/// Apply standard ImageNet normalization to a pixel value.
/// See <https://huggingface.co/facebook/detr-resnet-50#preprocessing>.
pub fn normalize_pixel(value: f32, channel: usize) -> f32 {
assert!(channel < 3, "channel index is invalid");
let imagenet_mean = [0.485, 0.456, 0.406];
let imagenet_std_dev = [0.229, 0.224, 0.225];
(value - imagenet_mean[channel]) / imagenet_std_dev[channel]
}

/// Apply standard ImageNet normalization to all pixel values in an image.
pub fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
for ([chan, _y, _x], pixel) in zip(img.indices(), img.iter_mut()) {
*pixel = normalize_pixel(*pixel, chan);
}
}
use rten_tensor::{NdTensor, NdTensorView};

/// Errors reported when creating a tensor from an image.
#[derive(Debug)]
Expand Down
2 changes: 2 additions & 0 deletions rten-imageproc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
mod contours;
mod drawing;
mod math;
mod normalize;
mod poly_algos;
mod shapes;

pub use contours::{find_contours, RetrievalMode};
pub use drawing::{draw_line, draw_polygon, fill_rect, stroke_rect, FillIter, Painter, Rgb};
pub use math::Vec2;
pub use normalize::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV};
pub use poly_algos::{convex_hull, min_area_rect, simplify_polygon, simplify_polyline};
pub use shapes::{
bounding_rect, BoundingRect, Coord, Line, LineF, Point, PointF, Polygon, PolygonF, Polygons,
Expand Down
38 changes: 38 additions & 0 deletions rten-imageproc/src/normalize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use rten_tensor::prelude::*;
use rten_tensor::NdTensorViewMut;

/// Standard ImageNet normalization mean values, for use with
/// [`normalize_image`].
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// Standard ImageNet normalization standard deviation values, for use with
/// [`normalize_image`].
pub const IMAGENET_STD_DEV: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalize the mean and standard deviation of all pixels in an image.
///
/// `img` should be a CHW tensor with `C` channels. For each channel `c`, the
/// output pixel values are computed as `y = (x - mean[c]) / std_dev[c]`.
///
/// This is a common preprocessing step for inputs to machine learning models.
/// Many models use standard "ImageNet" constants ([`IMAGENET_MEAN`],
/// [`IMAGENET_STD_DEV`]), but check the expected values for the model you are
/// using.
pub fn normalize_image<const C: usize>(
mut img: NdTensorViewMut<f32, 3>,
mean: [f32; C],
std_dev: [f32; C],
) {
let n_chans = img.size(0);
assert_eq!(
n_chans, C,
"expected image to have {} channels but found {}",
C, n_chans
);

for chan in 0..n_chans {
let inv_std_dev = 1. / std_dev[chan];
img.slice_mut::<2, _>(chan)
.apply(|x| (x - mean[chan]) * inv_std_dev);
}
}
Loading