Updated MobileViT example to Keras v3 #1758

Merged · 3 commits · Feb 18, 2024
86 changes: 56 additions & 30 deletions examples/vision/ipynb/mobilevit.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Date created:** 2021/10/20<br>\n",
"**Last modified:** 2021/10/20<br>\n",
"**Last modified:** 2024/02/11<br>\n",
"**Description:** MobileViT for image classification with combined benefits of convolutions and Transformers."
]
},
@@ -50,22 +50,22 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"import tensorflow as tf\n",
"\n",
"from keras.src.applications import imagenet_utils\n",
"# For versions <TF2.13 change the above import to:\n",
"# from keras.applications import imagenet_utils\n",
"from tensorflow.keras import layers\n",
"from tensorflow import keras\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import keras\n",
"from keras import layers\n",
"from keras import backend\n",
"\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow_addons as tfa\n",
"\n",
"tfds.disable_progress_bar()"
]
@@ -81,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -116,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -125,31 +125,54 @@
"\n",
"def conv_block(x, filters=16, kernel_size=3, strides=2):\n",
" conv_layer = layers.Conv2D(\n",
" filters, kernel_size, strides=strides, activation=tf.nn.swish, padding=\"same\"\n",
" filters,\n",
" kernel_size,\n",
" strides=strides,\n",
" activation=keras.activations.swish,\n",
" padding=\"same\",\n",
" )\n",
" return conv_layer(x)\n",
"\n",
"\n",
"# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435\n",
"\n",
"\n",
"def correct_pad(inputs, kernel_size):\n",
" img_dim = 2 if backend.image_data_format() == \"channels_first\" else 1\n",
" input_size = inputs.shape[img_dim : (img_dim + 2)]\n",
" if isinstance(kernel_size, int):\n",
" kernel_size = (kernel_size, kernel_size)\n",
" if input_size[0] is None:\n",
" adjust = (1, 1)\n",
" else:\n",
" adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)\n",
" correct = (kernel_size[0] // 2, kernel_size[1] // 2)\n",
" return (\n",
" (correct[0] - adjust[0], correct[0]),\n",
" (correct[1] - adjust[1], correct[1]),\n",
" )\n",
"\n",
"\n",
"# Reference: https://git.io/JKgtC\n",
"\n",
"\n",
"def inverted_residual_block(x, expanded_channels, output_channels, strides=1):\n",
" m = layers.Conv2D(expanded_channels, 1, padding=\"same\", use_bias=False)(x)\n",
" m = layers.BatchNormalization()(m)\n",
" m = tf.nn.swish(m)\n",
" m = keras.activations.swish(m)\n",
"\n",
" if strides == 2:\n",
" m = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)\n",
" m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)\n",
" m = layers.DepthwiseConv2D(\n",
" 3, strides=strides, padding=\"same\" if strides == 1 else \"valid\", use_bias=False\n",
" )(m)\n",
" m = layers.BatchNormalization()(m)\n",
" m = tf.nn.swish(m)\n",
" m = keras.activations.swish(m)\n",
"\n",
" m = layers.Conv2D(output_channels, 1, padding=\"same\", use_bias=False)(m)\n",
" m = layers.BatchNormalization()(m)\n",
"\n",
" if tf.math.equal(x.shape[-1], output_channels) and strides == 1:\n",
" if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:\n",
" return layers.Add()([m, x])\n",
" return m\n",
"\n",
@@ -160,7 +183,7 @@
"\n",
"def mlp(x, hidden_units, dropout_rate):\n",
" for units in hidden_units:\n",
" x = layers.Dense(units, activation=tf.nn.swish)(x)\n",
" x = layers.Dense(units, activation=keras.activations.swish)(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
"\n",
@@ -178,7 +201,7 @@
" # Layer normalization 2.\n",
" x3 = layers.LayerNormalization(epsilon=1e-6)(x2)\n",
" # MLP.\n",
" x3 = mlp(x3, hidden_units=[x.shape[-1] * 2, x.shape[-1]], dropout_rate=0.1,)\n",
" x3 = mlp(\n",
" x3,\n",
" hidden_units=[x.shape[-1] * 2, x.shape[-1]],\n",
" dropout_rate=0.1,\n",
" )\n",
" # Skip connection 2.\n",
" x = layers.Add()([x3, x2])\n",
"\n",
@@ -217,8 +244,7 @@
" local_global_features, filters=projection_dim, strides=strides\n",
" )\n",
"\n",
" return local_global_features\n",
""
" return local_global_features\n"
]
},
{
@@ -261,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -336,7 +362,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -368,8 +394,7 @@
" if is_training:\n",
" dataset = dataset.shuffle(batch_size * 10)\n",
" dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)\n",
" return dataset.batch(batch_size).prefetch(auto)\n",
""
" return dataset.batch(batch_size).prefetch(auto)\n"
]
},
{
Expand All @@ -393,7 +418,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -423,7 +448,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -441,7 +466,8 @@
" mobilevit_xxs = create_mobilevit(num_classes=num_classes)\n",
" mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=[\"accuracy\"])\n",
"\n",
" checkpoint_filepath = \"/tmp/checkpoint\"\n",
" # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`\n",
" checkpoint_filepath = \"/tmp/checkpoint.weights.h5\"\n",
" checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
" checkpoint_filepath,\n",
" monitor=\"val_accuracy\",\n",
@@ -479,14 +505,14 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Serialize the model as a SavedModel.\n",
"mobilevit_xxs.save(\"mobilevit_xxs\")\n",
"tf.saved_model.save(mobilevit_xxs, \"mobilevit_xxs\")\n",
"\n",
"# Convert to TFLite. This form of quantization is called\n",
"# post-training dynamic-range quantization in TFLite.\n",
@@ -510,7 +536,7 @@
"inference with TFLite models, check out\n",
"[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs) ",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT)."
]
}
@@ -539,7 +565,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.9.18"
}
},
"nbformat": 4,
65 changes: 46 additions & 19 deletions examples/vision/md/mobilevit.md
@@ -2,7 +2,7 @@

**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>
**Date created:** 2021/10/20<br>
**Last modified:** 2021/10/20<br>
**Last modified:** 2024/02/11<br>
**Description:** MobileViT for image classification with combined benefits of convolutions and Transformers.


@@ -34,17 +34,16 @@ Note: This example should be run with Tensorflow 2.13 and higher.


```python
import os
import tensorflow as tf

from keras.src.applications import imagenet_utils
# For versions <TF2.13 change the above import to:
# from keras.applications import imagenet_utils
os.environ["KERAS_BACKEND"] = "tensorflow"

from tensorflow.keras import layers
from tensorflow import keras
import keras
from keras import layers
from keras import backend

import tensorflow_datasets as tfds
import tensorflow_addons as tfa

tfds.disable_progress_bar()
```
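In Keras 3, the backend is chosen through the `KERAS_BACKEND` environment variable, which is read when `keras` is first imported. A minimal sketch of verifying the selection (illustrative, assuming a TensorFlow install):

```python
import os

# Must be set before the first `import keras`; otherwise the default backend is used.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

print(keras.backend.backend())  # "tensorflow"
```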
@@ -80,31 +79,54 @@ presented in the figure below (taken from the

def conv_block(x, filters=16, kernel_size=3, strides=2):
conv_layer = layers.Conv2D(
filters, kernel_size, strides=strides, activation=tf.nn.swish, padding="same"
filters,
kernel_size,
strides=strides,
activation=keras.activations.swish,
padding="same",
)
return conv_layer(x)


# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435


def correct_pad(inputs, kernel_size):
img_dim = 2 if backend.image_data_format() == "channels_first" else 1
input_size = inputs.shape[img_dim : (img_dim + 2)]
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if input_size[0] is None:
adjust = (1, 1)
else:
adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
correct = (kernel_size[0] // 2, kernel_size[1] // 2)
return (
(correct[0] - adjust[0], correct[0]),
(correct[1] - adjust[1], correct[1]),
)


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
m = layers.BatchNormalization()(m)
m = tf.nn.swish(m)
m = keras.activations.swish(m)

if strides == 2:
m = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)
m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
m = layers.DepthwiseConv2D(
3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
)(m)
m = layers.BatchNormalization()(m)
m = tf.nn.swish(m)
m = keras.activations.swish(m)

m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
m = layers.BatchNormalization()(m)

if tf.math.equal(x.shape[-1], output_channels) and strides == 1:
if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
return layers.Add()([m, x])
return m

@@ -115,7 +137,7 @@ def inverted_residual_block(x, expanded_channels, output_channels, strides=1):

def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.swish)(x)
x = layers.Dense(units, activation=keras.activations.swish)(x)
x = layers.Dropout(dropout_rate)(x)
return x

@@ -133,7 +155,7 @@ def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=[x.shape[-1] * 2, x.shape[-1]], dropout_rate=0.1,)
x3 = mlp(
x3,
hidden_units=[x.shape[-1] * 2, x.shape[-1]],
dropout_rate=0.1,
)
# Skip connection 2.
x = layers.Add()([x3, x2])

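The block above is shape-preserving: attention, the MLP, and both skip connections all return tensors of the same `(num_patches, projection_dim)` shape. A quick symbolic check (the sizes below are arbitrary, not taken from the example):

```python
import keras

# (num_patches, projection_dim) = (64, 96), chosen only for illustration.
patches = keras.Input(shape=(64, 96))
out = transformer_block(patches, transformer_layers=2, projection_dim=96)
print(out.shape)  # (None, 64, 96) -- same shape as the input
```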
@@ -640,8 +666,6 @@ Trainable params: 1,305,077
Non-trainable params: 2,544
__________________________________________________________________________________________________

```
</div>
---
## Dataset preparation

@@ -728,7 +752,8 @@ def run_experiment(epochs=epochs):
mobilevit_xxs = create_mobilevit(num_classes=num_classes)
mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

checkpoint_filepath = "/tmp/checkpoint"
# When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
@@ -828,7 +853,7 @@ and can be converted with the following code:

```python
# Serialize the model as a SavedModel.
mobilevit_xxs.save("mobilevit_xxs")
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
@@ -845,4 +870,6 @@ open("mobilevit_xxs.tflite", "wb").write(tflite_model)
To learn more about different quantization recipes available in TFLite and running
inference with TFLite models, check out
[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).
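For completeness, a minimal sketch of running inference with the converted model using the TFLite interpreter (the file name comes from the conversion step above; the 256×256 input resolution is assumed to match the example's image size):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="mobilevit_xxs.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# A random image stands in for real data here.
dummy_image = np.random.rand(1, 256, 256, 3).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy_image)
interpreter.invoke()

probabilities = interpreter.get_tensor(output_details[0]["index"])
print(probabilities.shape)  # (1, num_classes)
```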