---
sidebar_position: 2
---

# TensorFlow Lite Backend

We will use [this example project](https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/tflite-birds_v1-image) to show how to perform AI inference with a TensorFlow Lite model in WasmEdge and Rust.

## Prerequisite

Besides the [regular WasmEdge and Rust requirements](../../rust/setup.md), please make sure that you have the [WASI-NN plugin with TensorFlow Lite installed](../../../start/install.md#wasi-nn-plug-in-with-tensorflow-lite-backend).
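
For example, you can install WasmEdge together with the TensorFlow Lite plugin using the installer script. The invocation below is a typical one; see the linked install guide for the authoritative command and supported versions:

```bash
curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- --plugins wasi_nn-tensorflowlite
```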

## Quick Start

Because the example already includes a compiled WASM file from the Rust code, we can use the WasmEdge CLI to execute the example directly. First, git clone the `WasmEdge-WASINN-examples` repo.

```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/tflite-birds_v1-image/
```

Run the inference application in WasmEdge. The `--dir .:.` option preopens the current host directory inside the WASM sandbox so that the program can read the model and image files.

```bash
wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg
```

If everything goes well, you should see terminal output like the following:

```bash
...
Executed graph inference
...
   5.) [819](1)Anas platyrhynchos
```

## Build and Run the Example from the Rust Source Code

Let's build the WASM file from the Rust source code. First, git clone the `WasmEdge-WASINN-examples` repo.

```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/tflite-birds_v1-image/rust/
```

Second, use `cargo` to build the example project.

```bash
cargo build --target wasm32-wasi --release
```
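
If the `wasm32-wasi` compilation target is not yet installed on your machine, you can add it with `rustup` first:

```bash
rustup target add wasm32-wasi
```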

The output WASM file is `target/wasm32-wasi/release/wasmedge-wasinn-example-tflite-bird-image.wasm`. Next, let's use WasmEdge to load the TensorFlow Lite model and then use it to classify objects in your image.

```bash
wasmedge --dir .:. wasmedge-wasinn-example-tflite-bird-image.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg
```

You can replace `bird.jpg` with your image file.

## Improve Performance

You can make the inference program run faster by AOT-compiling the `wasm` file first. The `wasmedge compile` tool compiles the WASM bytecode ahead of time into native machine code, which typically speeds up compute-intensive inference significantly.

```bash
wasmedge compile wasmedge-wasinn-example-tflite-bird-image.wasm out.wasm
wasmedge --dir .:. out.wasm lite-model_aiy_vision_classifier_birds_V1_3.tflite bird.jpg
```

## Understand the Code

The [main.rs](https://github.com/second-state/WasmEdge-WASINN-examples/blob/master/tflite-birds_v1-image/rust/tflite-bird/src/main.rs) is the complete example Rust source.

The [main.rs](https://github.com/second-state/WasmEdge-WASINN-examples/blob/master/tflite-birds_v1-image/rust/tflite-bird/src/main.rs) is the complete example Rust source. First, get the file names of the TensorFlow Lite (TFLite) model and the input image from the command-line arguments.

```rust
let args: Vec<String> = env::args().collect();
let model_bin_name: &str = &args[1]; // File name for the TFLite model
let image_name: &str = &args[2]; // File name for the input image
```
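
The snippets in this section assume imports roughly like the following; the exact list is in the example's `main.rs`, and `imagenet_classes` is a local module shipped with the example:

```rust
use std::{env, fs};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding, TensorType};

mod imagenet_classes; // local module containing the AIY_BIRDS_V1 label table
```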

We use a helper function called `image_to_tensor()` to convert the input image into tensor data (the tensor type is `U8`); a sketch of such a helper is shown below.
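
A minimal sketch of this helper based on the `image` crate (the exact implementation is in the example's `main.rs`) could look like this:

```rust
use image::{io::Reader, DynamicImage};

// Decode the image file, resize it to the model's expected width and height,
// and return the raw RGB8 pixel bytes as flat u8 tensor data.
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    let decoded = Reader::open(path).unwrap().decode().unwrap();
    let resized: DynamicImage =
        decoded.resize_exact(width, height, image::imageops::FilterType::Triangle);
    resized.to_rgb8().into_raw()
}
```

Now we can load the model, feed the tensor array from the image to the model, and get the inference output tensor array.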

```rust
// load model
let weights = fs::read(model_bin_name)?;
let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU)
    .build_from_bytes(&[&weights])?;
let mut ctx = graph.init_execution_context()?;

// Load a tensor that precisely matches the graph input tensor
let tensor_data = image_to_tensor(image_name.to_string(), 224, 224);
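// The model expects a [1, 224, 224, 3] (NHWC) u8 input tensor.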
ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?;

// Execute the inference.
ctx.compute()?;

// Retrieve the output.
let mut output_buffer = vec![0u8; imagenet_classes::AIY_BIRDS_V1.len()];
_ = ctx.get_output(0, &mut output_buffer)?;
```

In the above code, `GraphEncoding::TensorflowLite` means using the TensorFlow Lite backend, and `ExecutionTarget::CPU` means running the computation on the CPU. Each byte in `output_buffer` is the model's `U8` confidence score for the bird class at the same index. Finally, we sort the output and then print the top-5 classification results:

```rust
let results = sort_results(&output_buffer);
for i in 0..5 {
    println!(
        "   {}.) [{}]({:.4}){}",
        i + 1,
        results[i].0,
        results[i].1,
        imagenet_classes::AIY_BIRDS_V1[results[i].0]
    );
}
```
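
The `sort_results()` helper is also defined in the example source. A sketch that pairs each class index with its `u8` score and sorts in descending order (the example's actual implementation may differ in details) looks like:

```rust
// Pair each class index with its confidence score, then sort so that the
// highest-scoring classes come first.
fn sort_results(buffer: &[u8]) -> Vec<(usize, u8)> {
    let mut results: Vec<(usize, u8)> = buffer.iter().copied().enumerate().collect();
    results.sort_by(|a, b| b.1.cmp(&a.1));
    results
}
```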
