Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate vivit tutorial to Keras3 [all backends] #1739

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions examples/vision/ipynb/vivit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)<br>\n",
"**Date created:** 2022/01/12<br>\n",
"**Last modified:** 2024/01/13<br>\n",
"**Last modified:** 2024/01/15<br>\n",
"**Description:** A Transformer-based architecture for video classification."
]
},
Expand Down Expand Up @@ -43,8 +43,8 @@
"the embedding scheme and one of the variants of the Transformer\n",
"architecture, for simplicity.\n",
"\n",
"This example requires the `medmnist`\n",
"package, which can be installed by running the code cell below."
"This example requires `medmnist` package, which can be installed\n",
"by running the code cell below."
]
},
{
Expand Down Expand Up @@ -81,11 +81,9 @@
"import medmnist\n",
"import ipywidgets\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf # for data preprocessing only\n",
"import keras\n",
"import tensorflow as tf\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras import layers, ops\n",
"\n",
"# Setting seed for reproducibility\n",
"SEED = 42\n",
Expand Down Expand Up @@ -214,15 +212,17 @@
"outputs": [],
"source": [
"\n",
"def preprocess(frames, label):\n",
"def preprocess(frames: tf.Tensor, label: tf.Tensor):\n",
" \"\"\"Preprocess the frames tensors and parse the labels.\"\"\"\n",
" # Preprocess images\n",
" frames = ops.cast(frames, \"float32\")\n",
" frames = ops.expand_dims(\n",
" frames, axis=-1\n",
" ) # The new axis is to help for further processing with Conv3D layers\n",
" frames = tf.image.convert_image_dtype(\n",
" frames[\n",
" ..., tf.newaxis\n",
" ], # The new axis is to help for further processing with Conv3D layers\n",
" tf.float32,\n",
" )\n",
" # Parse label\n",
" label = ops.cast(label, \"float32\")\n",
" label = tf.cast(label, tf.float32)\n",
" return frames, label\n",
"\n",
"\n",
Expand Down Expand Up @@ -337,7 +337,7 @@
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_tokens, output_dim=self.embed_dim\n",
" )\n",
" self.positions = ops.arange(start=0, stop=num_tokens, step=1)\n",
" self.positions = ops.arange(0, num_tokens, 1)\n",
"\n",
" def call(self, encoded_tokens):\n",
" # Encode the positions and add it to the encoded tokens\n",
Expand Down Expand Up @@ -411,8 +411,8 @@
" x3 = layers.LayerNormalization(epsilon=1e-6)(x2)\n",
" x3 = keras.Sequential(\n",
" [\n",
" layers.Dense(units=embed_dim * 4, activation=\"gelu\"),\n",
" layers.Dense(units=embed_dim, activation=\"gelu\"),\n",
" layers.Dense(units=embed_dim * 4, activation=ops.gelu),\n",
" layers.Dense(units=embed_dim, activation=ops.gelu),\n",
" ]\n",
" )(x3)\n",
"\n",
Expand Down Expand Up @@ -511,9 +511,9 @@
"\n",
"for i, (testsample, label) in enumerate(zip(testsamples, labels)):\n",
" # Generate gif\n",
" testsample = ops.reshape(testsample, (-1, 28, 28))\n",
" testsample = np.reshape(testsample.numpy(), (-1, 28, 28))\n",
" with io.BytesIO() as gif:\n",
" imageio.mimsave(gif, (testsample.numpy() * 255).astype(\"uint8\"), \"GIF\", fps=5)\n",
" imageio.mimsave(gif, (testsample * 255).astype(\"uint8\"), \"GIF\", fps=5)\n",
" videos.append(gif.getvalue())\n",
"\n",
" # Get model prediction\n",
Expand Down
Loading
Loading