diff --git a/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_finetune.ipynb b/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_finetune.ipynb
new file mode 100644
index 000000000..b39443d53
--- /dev/null
+++ b/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_finetune.ipynb
@@ -0,0 +1,1497 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tce3stUlHN0L"
+ },
+ "source": [
+ "##### Copyright 2024 Google LLC."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tuOe1ymfHZPu"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N_yUpPhqrRrK"
+ },
+ "source": [
+ "# Fine-tuning RecurrentGemma using JAX and Flax"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-yDXE-RX835U"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MUnQEMHBt3nc"
+ },
+ "source": [
+ "This tutorial demonstrates how to fine-tune the [RecurrentGemma](https://ai.google.dev/gemma/docs/recurrentgemma) 2B Instruct model for an English-French translation task using [Google DeepMind's `recurrentgemma` library](https://github.com/google-deepmind/recurrentgemma), [JAX](https://jax.readthedocs.io) (a high-performance numerical computing library), [Flax](https://flax.readthedocs.io) (the JAX-based neural network library), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), [Optax](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), and the [MTNT (Machine Translation of Noisy Text) dataset](https://arxiv.org/abs/1809.00388). Although Flax is not used directly in this notebook, Flax was used to create Gemma.\n",
+ "\n",
+ "The `recurrentgemma` library was written with JAX, Flax, [Orbax](https://orbax.readthedocs.io/) (a JAX-based library for training utilities like checkpointing), and [SentencePiece](https://github.com/google/sentencepiece) (a tokenizer/detokenizer library).\n",
+ "\n",
+ "This notebook can run on Google Colab with the T4 GPU (go to **Edit** > **Notebook settings** > Under **Hardware accelerator** select **T4 GPU**)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dbRLI7Q4-8Ve"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "The following sections explain the steps for preparing a notebook to use a RecurrentGemma model, including model access, getting an API key, and configuring the notebook runtime."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "n8Ku4iK6PnC0"
+ },
+ "source": [
+ "### Set up Kaggle access for Gemma\n",
+ "\n",
+ "To complete this tutorial, you first need to follow the setup instructions _similar_ to [Gemma setup](https://ai.google.dev/gemma/docs/setup) with a few exceptions:\n",
+ "\n",
+ "* Get access to RecurrentGemma (instead of Gemma) on [kaggle.com](https://www.kaggle.com/models/google/recurrentgemma).\n",
+ "* Select a Colab runtime with sufficient resources to run the RecurrentGemma model.\n",
+ "* Generate and configure a Kaggle username and API key.\n",
+ "\n",
+ "After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
+ "\n",
+ "### Set environment variables\n",
+ "\n",
+ "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AVH6Y4k2964n"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from google.colab import userdata # `userdata` is a Colab API.\n",
+ "\n",
+ "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
+ "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "m1UE1CEnE9ql"
+ },
+ "source": [
+ "### Install the `recurrentgemma` library\n",
+ "\n",
+ "Free Colab hardware acceleration is currently *insufficient* to run this notebook. If you are using [Colab Pay As You Go or Colab Pro](https://colab.research.google.com/signup), click on **Edit** > **Notebook settings** > Select **A100 GPU** > **Save** to enable hardware acceleration.\n",
+ "\n",
+ "Next, you need to install the Google DeepMind `recurrentgemma` library from [`github.com/google-deepmind/recurrentgemma`](https://github.com/google-deepmind/recurrentgemma). If you get an error about \"pip's dependency resolver\", you can usually ignore it.\n",
+ "\n",
+ "**Note:** By installing `recurrentgemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`optax`](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XpSw-_4EEcoY"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.7/40.7 kB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m41.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for recurrentgemma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q git+https://github.com/google-deepmind/recurrentgemma.git"
+ ]
+ },
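+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optionally, you can verify that JAX detects the Colab GPU before continuing:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "\n",
+ "# List the accelerator devices visible to JAX; a GPU device should appear.\n",
+ "jax.devices()"
+ ]
+ },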
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-mRkkT-iPYoq"
+ },
+ "source": [
+ "### 4. Import libraries\n",
+ "\n",
+ "This notebook uses [Flax](https://flax.readthedocs.io) (for neural networks), core [JAX](https://jax.readthedocs.io), [SentencePiece](https://github.com/google/sentencepiece) (for tokenization), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), [Optax](https://optax.readthedocs.io/en/latest/) (the gradient processing and optimization library), and TensorFlow Datasets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ChMf1H4mPVx_"
+ },
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "from typing import Any, Mapping, Iterator\n",
+ "import enum\n",
+ "import functools\n",
+ "\n",
+ "import chex\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import optax\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_datasets as tfds\n",
+ "\n",
+ "import sentencepiece as spm\n",
+ "\n",
+ "from recurrentgemma import jax as recurrentgemma"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oNgKIkxMOsit"
+ },
+ "source": [
+ "## Load the RecurrentGemma model\n",
+ "\n",
+ "1. Load the RecurrentGemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:\n",
+ "\n",
+ "- `handle`: The model handle from Kaggle\n",
+ "- `path`: (Optional string) The local path\n",
+ "- `force_download`: (Optional boolean) Forces to re-download the model\n",
+ "\n",
+ "**Note:** Be mindful that the RecurrentGemma 2B (IT) model is around 3.85Gb in size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X-i10429N-g2"
+ },
+ "outputs": [],
+ "source": [
+ "RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "j_QdPAGyO5zl"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...\n",
+ "100%|██████████| 3.85G/3.85G [00:50<00:00, 81.5MB/s]\n",
+ "Extracting model files...\n"
+ ]
+ }
+ ],
+ "source": [
+ "import kagglehub\n",
+ "\n",
+ "RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cjnXlLkWcHIy"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RECURRENTGEMMA_VARIANT: 2b-it\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('RECURRENTGEMMA_VARIANT:', RECURRENTGEMMA_VARIANT)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "E1HzOpDcM04q"
+ },
+ "source": [
+ "**Note:** The path from the output above is where the model weights and tokenizer are saved locally, you will need them for later."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6ytvcJ8FPEMm"
+ },
+ "source": [
+ "2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:\n",
+ "\n",
+ "- The `tokenizer.model` file will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1`).\n",
+ "- The model checkpoint will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JAwXvpzbuiB5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it\n",
+ "TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model\n"
+ ]
+ }
+ ],
+ "source": [
+ "CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)\n",
+ "TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')\n",
+ "print('CKPT_PATH:', CKPT_PATH)\n",
+ "print('TOKENIZER_PATH:', TOKENIZER_PATH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "U800JRcJVIlF"
+ },
+ "source": [
+ "## Load and prepare the MTNT dataset and the Gemma tokenizer\n",
+ "\n",
+ "You will use the [MTNT (Machine Translation of Noisy Text)](https://arxiv.org/abs/1809.00388) dataset, which is available from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/mtnt).\n",
+ "\n",
+ "Download the English-to-French dataset portion of the MTNT dataset, and then sample two examples. Each sample in the dataset contains two entries: `src`: the original English sentence; and `dst`: the corresponding French translation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pg8SfQH0EcoY"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading and preparing dataset 35.08 MiB (download: 35.08 MiB, generated: 11.33 MiB, total: 46.41 MiB) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d1e0e55e84e748398b261ad10f68326b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Dl Completed...: 0 url [00:00, ? url/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c0ff76b1edaf4d2a918d131182d43753",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Dl Size...: 0 MiB [00:00, ? MiB/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3815b4da77f245c19f44d7ece4713151",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Extraction completed...: 0 file [00:00, ? file/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6a8e49c6ae9e429ca25e445be503d7c9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating splits...: 0%| | 0/3 [00:00, ? splits/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b92e5dbd12c04d3d9747c00b75fca495",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train examples...: 0%| | 0/35692 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b1982e8344de403b87c4fc35a9700294",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-train.tfrecord*...: 0%| …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3b8c12832013487dbfafc92412bc0819",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating test examples...: 0%| | 0/1020 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "95ece202e83740078e7b88b2ef3c5b8d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-test.tfrecord*...: 0%| |…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1ba2bc3cec6b4fe884fdf022fbe4cc5d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating valid examples...: 0%| | 0/811 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c086472eb20844c7b790b5a233bca3fb",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/1.0.0.incompleteJLH33K/mtnt-valid.tfrecord*...: 0%| …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.\n",
+ "Example 0:\n",
+ "dst: b'Le groupe de \" toutes les \\xc3\\xa9toiles potentielles de la conf\\xc3\\xa9rence de l\\'Est mais qui ne s\\'en sortent pas dans le groupe de l\\'Ouest \".'\n",
+ "src: b'The group of \\xe2\\x80\\x9ceastern conference potential all stars but not making it in the West\\xe2\\x80\\x9d group.'\n",
+ "\n",
+ "Example 1:\n",
+ "dst: b\"Kameron est-elle un peu aigrie de son manque de temps \\xc3\\xa0 l'\\xc3\\xa9cran ?\"\n",
+ "src: b'Is Kameron a Little Salty About Her Lack of Air Time?'\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "ds = tfds.load(\"mtnt/en-fr\", split=\"train\")\n",
+ "\n",
+ "ds = ds.take(2)\n",
+ "ds = ds.as_numpy_iterator()\n",
+ "\n",
+ "for idx, example in enumerate(ds):\n",
+ " print(f'Example {idx}:')\n",
+ " for key, val in example.items():\n",
+ " print(f'{key}: {val}')\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XlY3EYV6jXWR"
+ },
+ "source": [
+ "Load the Gemma tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TpyG5YW1EcoY"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vocab = spm.SentencePieceProcessor()\n",
+ "vocab.Load(TOKENIZER_PATH)"
+ ]
+ },
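+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check, you can round-trip a sample sentence through the tokenizer using the standard `SentencePieceProcessor` methods (the exact token IDs depend on the tokenizer):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Encode a sample sentence into token IDs, then decode it back.\n",
+ "ids = vocab.EncodeAsIds('Hello, world!')\n",
+ "print(ids)\n",
+ "print(vocab.DecodeIds(ids))"
+ ]
+ },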
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Mk_CV0A1kR4K"
+ },
+ "source": [
+ "Customize the[`SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423) for the English-to-French translation task. Since you will be fine-tuning the English portion of the RecurrentGemma (Griffin) model, you need to make a few adjustments, such as:\n",
+ "\n",
+ "- *The input prefix*: Adding a common prefix to each input signals the translation task. For example, you could use a prompt with a prefix like `Translate this into French: [INPUT_SENTENCE]`.\n",
+ "\n",
+ "- *The translation start suffix*: Adding a suffix at the end of each prompt instructs the Gemma model exactly when to begin the translation process. A new line should do the job.\n",
+ "\n",
+ "- *Language model tokens*: RecurrentGemma (Griffin) models expect a \"beginning of sequence\" token at the beginning of each sequence. Similarly, you need to add an \"end of sequence\" token at the end of each training example.\n",
+ "\n",
+ "Build a custom wrapper around the `SentencePieceProcessor` as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "L9cjK0uxEcoY"
+ },
+ "outputs": [],
+ "source": [
+ "class GriffinTokenizer:\n",
+ " \"\"\"A custom wrapper around a SentencePieceProcessor.\"\"\"\n",
+ "\n",
+ " def __init__(self, spm_processor: spm.SentencePieceProcessor):\n",
+ " self._spm_processor = spm_processor\n",
+ "\n",
+ " @property\n",
+ " def pad_id(self) -> int:\n",
+ " \"\"\"Fast access to the pad ID.\"\"\"\n",
+ " return self._spm_processor.pad_id()\n",
+ "\n",
+ " def tokenize(\n",
+ " self,\n",
+ " example: str | bytes,\n",
+ " prefix: str = '',\n",
+ " suffix: str = '',\n",
+ " add_eos: bool = True,\n",
+ " ) -> jax.Array:\n",
+ " \"\"\"\n",
+ " A tokenization function.\n",
+ "\n",
+ " Args:\n",
+ " example: Input string to tokenize.\n",
+ " prefix: Prefix to add to the input string.\n",
+ " suffix: Suffix to add to the input string.\n",
+ " add_eos: If True, add an end of sentence token at the end of the output\n",
+ " sequence.\n",
+ " Returns:\n",
+ " Tokens corresponding to the input string.\n",
+ " \"\"\"\n",
+ " int_list = [self._spm_processor.bos_id()]\n",
+ " int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n",
+ " if add_eos:\n",
+ " int_list.append(self._spm_processor.eos_id())\n",
+ "\n",
+ " return jnp.array(int_list, dtype=jnp.int32)\n",
+ "\n",
+ " def tokenize_tf_op(\n",
+ " self,\n",
+ " str_tensor: tf.Tensor,\n",
+ " prefix: str = '',\n",
+ " suffix: str = '',\n",
+ " add_eos: bool = True,\n",
+ " ) -> tf.Tensor:\n",
+ " \"\"\"A TensforFlow operator for the `tokenize` function.\"\"\"\n",
+ " encoded = tf.numpy_function(\n",
+ " self.tokenize,\n",
+ " [str_tensor, prefix, suffix, add_eos],\n",
+ " tf.int32)\n",
+ " encoded.set_shape([None])\n",
+ " return encoded\n",
+ "\n",
+ " def to_string(self, tokens: jax.Array) -> str:\n",
+ " \"\"\"Convert an array of tokens to a string.\"\"\"\n",
+ " return self._spm_processor.EncodeIds(tokens.tolist())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h-oJ2ziwxG1L"
+ },
+ "source": [
+ "Try it out by instantiating your new custom `GriffinTokenizer`, and then applying it on a small sample of the MTNT dataset:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xEA-97ioEcoY"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Example 0:\n",
+ "src: [ 2 49688 736 1280 6987 235292 108 651 2778 576\n",
+ " 1080 104745 11982 5736 832 8995 901 780 3547 665\n",
+ " 575 573 4589 235369 2778 235265 108]\n",
+ "dst: [ 2 2025 29653 581 664 16298 1437 55563 41435 7840\n",
+ " 581 683 111452 581 533 235303 9776 4108 2459 679\n",
+ " 485 235303 479 6728 579 1806 2499 709 29653 581\n",
+ " 533 235303 101323 16054 1]\n",
+ "\n",
+ "Example 1:\n",
+ "src: [ 2 49688 736 1280 6987 235292 108 2437 87150 477\n",
+ " 476 11709 230461 8045 3636 40268 576 4252 4897 235336\n",
+ " 108]\n",
+ "dst: [ 2 213606 477 1455 235290 3510 748 8268 191017 2809\n",
+ " 581 2032 69972 581 11495 1305 533 235303 65978 1654\n",
+ " 1]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "def tokenize_source(tokenizer, example: tf.Tensor):\n",
+ " return tokenizer.tokenize_tf_op(\n",
+ " example,\n",
+ " prefix='Translate this into French:\\n',\n",
+ " suffix='\\n',\n",
+ " add_eos=False\n",
+ " )\n",
+ "def tokenize_destination(tokenizer, example: tf.Tensor):\n",
+ " return tokenizer.tokenize_tf_op(example, add_eos=True)\n",
+ "\n",
+ "tokenizer = GriffinTokenizer(vocab)\n",
+ "\n",
+ "ds = tfds.load(\"mtnt/en-fr\",split=\"train\")\n",
+ "ds = ds.take(2)\n",
+ "ds = ds.map(lambda x: {\n",
+ " 'src': tokenize_source(tokenizer, x['src']),\n",
+ " 'dst': tokenize_destination(tokenizer, x['dst'])\n",
+ " })\n",
+ "ds = ds.as_numpy_iterator()\n",
+ "\n",
+ "for idx, example in enumerate(ds):\n",
+ " print(f'Example {idx}:')\n",
+ " for key, val in example.items():\n",
+ " print(f'{key}: {val}')\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qkY_hThVkkqF"
+ },
+ "source": [
+ "Build a data loader for the entire MTNT dataset:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Zm30Q2lnknmG"
+ },
+ "outputs": [],
+ "source": [
+ "@chex.dataclass(frozen=True)\n",
+ "class TrainingInput:\n",
+ " # Input tokens provided to the model.\n",
+ " input_tokens: jax.Array\n",
+ "\n",
+ " # A mask that determines which tokens contribute to the target loss\n",
+ " # calculation.\n",
+ " target_mask: jax.Array\n",
+ "\n",
+ "class DatasetSplit(enum.Enum):\n",
+ " TRAIN = 'train'\n",
+ " VALIDATION = 'valid'\n",
+ "\n",
+ "\n",
+ "class MTNTDatasetBuilder:\n",
+ " \"\"\"A data loader for the MTNT dataset.\"\"\"\n",
+ "\n",
+ " N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}\n",
+ "\n",
+ " BUFFER_SIZE_SHUFFLE = 10_000\n",
+ " TRANSLATION_PREFIX = 'Translate this into French:\\n'\n",
+ " TRANSLATION_SUFFIX = '\\n'\n",
+ "\n",
+ " def __init__(self,\n",
+ " tokenizer : GriffinTokenizer,\n",
+ " max_seq_len: int):\n",
+ " \"\"\"A constructor.\n",
+ "\n",
+ " Args:\n",
+ " tokenizer: The tokenizer to use.\n",
+ " max_seq_len: The size of each sequence in a given batch.\n",
+ " \"\"\"\n",
+ " self._tokenizer = tokenizer\n",
+ " self._base_data = {\n",
+ " DatasetSplit.TRAIN: tfds.load(\"mtnt/en-fr\",split=\"train\"),\n",
+ " DatasetSplit.VALIDATION: tfds.load(\"mtnt/en-fr\",split=\"valid\"),\n",
+ " }\n",
+ " self._max_seq_len = max_seq_len\n",
+ "\n",
+ " def _tokenize_source(self, example: tf.Tensor):\n",
+ " \"\"\"A tokenization function for the source.\"\"\"\n",
+ " return self._tokenizer.tokenize_tf_op(\n",
+ " example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,\n",
+ " add_eos=False\n",
+ " )\n",
+ "\n",
+ " def _tokenize_destination(self, example: tf.Tensor):\n",
+ " \"\"\"A tokenization function for the French translation.\"\"\"\n",
+ " return self._tokenizer.tokenize_tf_op(example, add_eos=True)\n",
+ "\n",
+ " def _pad_up_to_max_len(self,\n",
+ " input_tensor: tf.Tensor,\n",
+ " pad_value: int | bool,\n",
+ " ) -> tf.Tensor:\n",
+ " \"\"\"Pad the given tensor up to sequence length of a batch.\"\"\"\n",
+ " seq_len = tf.shape(input_tensor)[0]\n",
+ " to_pad = tf.maximum(self._max_seq_len - seq_len, 0)\n",
+ " return tf.pad(\n",
+ " input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,\n",
+ " )\n",
+ "\n",
+ " def _to_training_input(\n",
+ " self,\n",
+ " src_tokens: jax.Array,\n",
+ " dst_tokens: jax.Array,\n",
+ " ) -> TrainingInput:\n",
+ " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n",
+ "\n",
+ " # The input sequence fed to the model is simply the concatenation of the\n",
+ " # source and the destination.\n",
+ " tokens = tf.concat([src_tokens, dst_tokens], axis=0)\n",
+ "\n",
+ " # You want to prevent the model from updating based on the source (input)\n",
+ " # tokens. To achieve this, add a target mask to each input.\n",
+ " q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)\n",
+ " a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)\n",
+ " mask = tf.concat([q_mask, a_mask], axis=0)\n",
+ "\n",
+ " # If the output tokens sequence is smaller than the target sequence size,\n",
+ " # then pad it with pad tokens.\n",
+ " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n",
+ "\n",
+ " # You don't want to perform the backward on the pad tokens.\n",
+ " mask = self._pad_up_to_max_len(mask, False)\n",
+ "\n",
+ " return TrainingInput(input_tokens=tokens, target_mask=mask)\n",
+ "\n",
+ "\n",
+ " def get_train_dataset(self, batch_size: int, num_epochs: int):\n",
+ " \"\"\"Build the training dataset.\"\"\"\n",
+ "\n",
+ " # Tokenize each sample.\n",
+ " ds = self._base_data[DatasetSplit.TRAIN].map(\n",
+ " lambda x : (self._tokenize_source(x['src']),\n",
+ " self._tokenize_destination(x['dst']))\n",
+ " )\n",
+ "\n",
+ " # Convert them to training inputs.\n",
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
+ "\n",
+ " # Remove the samples which are too long.\n",
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
+ "\n",
+ " # Shuffle the dataset.\n",
+ " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n",
+ "\n",
+ " # Repeat if necessary.\n",
+ " ds = ds.repeat(num_epochs)\n",
+ "\n",
+ " # Build batches.\n",
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
+ " return ds\n",
+ "\n",
+ " def get_validation_dataset(self, batch_size: int):\n",
+ " \"\"\"Build the validation dataset.\"\"\"\n",
+ "\n",
+ " # Same as the training dataset, but no shuffling and no repetition\n",
+ " ds = self._base_data[DatasetSplit.VALIDATION].map(\n",
+ " lambda x : (self._tokenize_source(x['src']),\n",
+ " self._tokenize_destination(x['dst']))\n",
+ " )\n",
+ " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n",
+ " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n",
+ " ds = ds.batch(batch_size, drop_remainder=True)\n",
+ " return ds"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A3jRNKosyLUK"
+ },
+ "source": [
+ "Try the `MTNTDatasetBuilder` out by instantiating the custom `GriffinTokenizer` again, then applying it on the MTNT dataset, and sampling two examples:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bYeduOaNEcoZ"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for \n",
+ "WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for \n",
+ "WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for \n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Example 0:\n",
+ "input_tokens: [[ 2 49688 736 1280 6987 235292 108 12583 665 235265\n",
+ " 108 2 6151 94975 1320 6238 235265 1 0 0]\n",
+ " [ 2 49688 736 1280 6987 235292 108 4899 29960 11270\n",
+ " 108282 235265 108 2 4899 79025 11270 108282 1 0]\n",
+ " [ 2 49688 736 1280 6987 235292 108 26620 235265 108\n",
+ " 2 26620 235265 1 0 0 0 0 0 0]]\n",
+ "target_mask: [[False False False False False False False False False False False True\n",
+ " True True True True True True False False]\n",
+ " [False False False False False False False False False False False False\n",
+ " False True True True True True True False]\n",
+ " [False False False False False False False False False False True True\n",
+ " True True False False False False False False]]\n",
+ "\n",
+ "Example 1:\n",
+ "input_tokens: [[ 2 49688 736 1280 6987 235292 108 527 5174 1683\n",
+ " 235336 108 2 206790 581 20726 482 2208 1654 1]\n",
+ " [ 2 49688 736 1280 6987 235292 108 28484 235256 235336\n",
+ " 108 2 120500 13832 1654 1 0 0 0 0]\n",
+ " [ 2 49688 736 1280 6987 235292 108 235324 235304 2705\n",
+ " 235265 108 2 235324 235304 19963 235265 1 0 0]]\n",
+ "target_mask: [[False False False False False False False False False False False False\n",
+ " True True True True True True True True]\n",
+ " [False False False False False False False False False False False True\n",
+ " True True True True False False False False]\n",
+ " [False False False False False False False False False False False False\n",
+ " True True True True True True False False]]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)\n",
+ "ds = dataset_builder.get_train_dataset(3, 1)\n",
+ "ds = ds.take(2)\n",
+ "ds = ds.as_numpy_iterator()\n",
+ "\n",
+ "for idx, example in enumerate(ds):\n",
+ " print(f'Example {idx}:')\n",
+ " for key, val in example.items():\n",
+ " print(f'{key}: {val}')\n",
+ " print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7IY8Muu1zRF4"
+ },
+ "source": [
+ "## Configure the model\n",
+ "\n",
+ "Before you begin fine-tuning the Gemma model, you need to configure it.\n",
+ "\n",
+ "Load the RecurrentGemma (Griffin) model checkpoint with the [`recurrentgemma.jax.utils.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31) method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "by6eWKtqzxRf"
+ },
+ "outputs": [],
+ "source": [
+ "params = recurrentgemma.load_parameters(CKPT_PATH, \"single_device\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wWBglOTTA34w"
+ },
+ "source": [
+ "To automatically load the correct configuration from the RecurrentGemma model checkpoint, use [`recurrentgemma.GriffinConfig.from_flax_params_or_variables`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/common.py#L128):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OWyrLqMMdsPq"
+ },
+ "outputs": [],
+ "source": [
+ "config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "82wJkg6CAtmz"
+ },
+ "source": [
+ "Instantiate the [Griffin](https://arxiv.org/abs/2402.19427) model with [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_h9Rycpg9gYy"
+ },
+ "outputs": [],
+ "source": [
+ "model = recurrentgemma.Griffin(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2jKaRvCbAp0J"
+ },
+ "source": [
+ "Create a `sampler` with [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74) on top of the RecurrentGemma model checkpoint/weights and the tokenizer to check if your model can perform translation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "a4tqSw26ANSi"
+ },
+ "outputs": [],
+ "source": [
+ "sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)"
+ ]
+ },
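+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, you can sample a translation from the pre-trained checkpoint to establish a baseline before fine-tuning (the exact output will vary between runs):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sample from the pre-trained checkpoint as a baseline before fine-tuning.\n",
+ "output = sampler(\n",
+ " [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n",
+ " total_generation_steps=30,\n",
+ ")\n",
+ "print(output.text[0])"
+ ]
+ },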
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t7UL2Af536x_"
+ },
+ "source": [
+ "## Fine-tune the model\n",
+ "\n",
+ "In this section, you will:\n",
+ "\n",
+ "- Use the `gemma.transformer.Transformer` class to create the forward pass and loss function.\n",
+ "- Build the position and attention mask vectors for tokens\n",
+ "- Build a training step function with Flax.\n",
+ "- Build the validation step without the backwards pass.\n",
+ "- Create the training loop.\n",
+ "- Fine-tune the Gemma model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aJhtJumH7H8_"
+ },
+ "source": [
+ "Define the forward pass and the loss function using the [`recurrentgemma.jax.griffin.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29)\n",
+ " class. The RecurrentGemma `Griffin` inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html), and offers two essential methods:\n",
+ "\n",
+ "- `init`: Initializes the model's parameters.\n",
+ "- `apply`: Executes the model's `__call__` function using a given set of parameters.\n",
+ "\n",
+ "Since you are working with pre-trained Gemma weights, you don't need to use the `init` function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iEcV0XEEEcoZ"
+ },
+ "outputs": [],
+ "source": [
+ "def forward_and_loss_fn(\n",
+ " params,\n",
+ " *,\n",
+ " model: recurrentgemma.Griffin,\n",
+ " input_tokens: jax.Array, # Shape [B, L]\n",
+ " input_mask: jax.Array, # Shape [B, L]\n",
+ " positions: jax.Array, # Shape [B, L]\n",
+ ") -> jax.Array:\n",
+ " \"\"\"Forward pass and loss function.\n",
+ "\n",
+ " Args:\n",
+ " params: model's input parameters.\n",
+ " model: Griffin model to call.\n",
+ " input_tokens: input tokens sequence, shape [B, L].\n",
+ " input_mask: tokens to ignore when computing the loss, shape [B, L].\n",
+ " positions: relative position of each token, shape [B, L].\n",
+ "\n",
+ " Returns:\n",
+ " Softmax cross-entropy loss for the next-token prediction task.\n",
+ " \"\"\"\n",
+ " batch_size = input_tokens.shape[0]\n",
+ " # Foward pass on the input data.\n",
+ " # No attention cache is needed here.\n",
+ " # Exclude the last step as it does not appear in the targets.\n",
+ " logits, _ = model.apply(\n",
+ " {\"params\": params},\n",
+ " tokens=input_tokens[:, :-1],\n",
+ " segment_pos=positions[:, :-1],\n",
+ " cache=None,\n",
+ " )\n",
+ "\n",
+ " # Similarly, the first token cannot be predicteds.\n",
+ " target_tokens = input_tokens[:, 1:]\n",
+ " target_mask = input_mask[:, 1:]\n",
+ "\n",
+ " # Convert the target labels into one-hot encoded vectors.\n",
+ " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n",
+ "\n",
+ " # Don't update on unwanted tokens.\n",
+ " one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]\n",
+ "\n",
+ " # Normalization factor.\n",
+ " norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)\n",
+ "\n",
+ " # Return the negative log-likelihood loss (NLL) function.\n",
+ " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uRkeF6ed8tOI"
+ },
+ "source": [
+ "Build the `train_step` function that performs the backward pass and updates the model's parameters accordingly, where:\n",
+ "\n",
+ "- [`jax.value_and_grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html) is for evaluating the loss function and gradients during the forward and backward passes.\n",
+ "- [`optax.apply_updates`](https://optax.readthedocs.io/en/latest/api/apply_updates.html#optax.apply_updates) is for updating the parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cPSfp7ZUEcoZ"
+ },
+ "outputs": [],
+ "source": [
+ "Params = Mapping[str, Any]\n",
+ "\n",
+ "def get_positions(example: jax.Array, pad_id : int) -> jax.Array:\n",
+ " \"\"\"Builds the position vector from the given tokens.\"\"\"\n",
+ " pad_mask = example != pad_id\n",
+ " positions = jnp.cumsum(pad_mask, axis=-1)\n",
+ " # Subtract one for all positions from the first valid one as they are\n",
+ " # 0-indexed\n",
+ " positions = positions - (positions >= 1)\n",
+ " return positions\n",
+ "\n",
+ "@functools.partial(\n",
+ " jax.jit,\n",
+ " static_argnames=['model', 'optimizer'],\n",
+ " donate_argnames=['params', 'opt_state'],\n",
+ ")\n",
+ "def train_step(\n",
+ " model: recurrentgemma.Griffin,\n",
+ " params: Params,\n",
+ " optimizer: optax.GradientTransformation,\n",
+ " opt_state: optax.OptState,\n",
+ " pad_id: int,\n",
+ " example: TrainingInput,\n",
+ ") -> tuple[jax.Array, Params, optax.OptState]:\n",
+ " \"\"\"The train step.\n",
+ "\n",
+ " Args:\n",
+ " model: The RecurrentGemma (Griffin) model.\n",
+ " params: The model's input parameters.\n",
+ " optimizer: The Optax optimizer to use.\n",
+ " opt_state: The input optimizer's state.\n",
+ " pad_id: The ID of the pad token.\n",
+ " example: The input batch.\n",
+ "\n",
+ " Returns:\n",
+ " Training loss, updated parameters, updated optimizer state.\n",
+ " \"\"\"\n",
+ "\n",
+ " positions = get_positions(example.input_tokens, pad_id)\n",
+ "\n",
+ " # Forward and backward passes.\n",
+ " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(\n",
+ " params,\n",
+ " model=model,\n",
+ " input_tokens=example.input_tokens,\n",
+ " input_mask=example.target_mask,\n",
+ " positions=positions,\n",
+ " )\n",
+ " # Update the parameters.\n",
+ " updates, opt_state = optimizer.update(grads, opt_state, params)\n",
+ " params = optax.apply_updates(params, updates)\n",
+ "\n",
+ " return train_loss, params, opt_state"
+ ]
+ },
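+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see what `get_positions` produces, here is a small worked example, assuming a hypothetical pad ID of 0 for illustration: valid tokens receive 0-indexed positions, and pad tokens repeat the last valid position."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Three valid tokens followed by two pad tokens (pad ID 0).\n",
+ "# Expected result: [[0 1 2 2 2]].\n",
+ "get_positions(jnp.array([[2, 5, 8, 0, 0]]), pad_id=0)"
+ ]
+ },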
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8ZKSa-jJ809n"
+ },
+ "source": [
+ "Build the `validation_step` function without the backward pass:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yU4oR92YEcoa"
+ },
+ "outputs": [],
+ "source": [
+ "@functools.partial(jax.jit, static_argnames=['model'])\n",
+ "def validation_step(\n",
+ " model: recurrentgemma.Griffin,\n",
+ " params: Params,\n",
+ " pad_id: int,\n",
+ " example: TrainingInput,\n",
+ ") -> jax.Array:\n",
+ " return forward_and_loss_fn(\n",
+ " params,\n",
+ " model=model,\n",
+ " input_tokens=example.input_tokens,\n",
+ " input_mask=example.target_mask,\n",
+ " positions=get_positions(example.input_tokens, pad_id),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bNqVhj7v87f4"
+ },
+ "source": [
+ "Define the training loop:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xT4bAqNLEcoa"
+ },
+ "outputs": [],
+ "source": [
+ "def train_loop(\n",
+ " model: recurrentgemma.Griffin,\n",
+ " params: Params,\n",
+ " optimizer: optax.GradientTransformation,\n",
+ " train_ds: Iterator[TrainingInput],\n",
+ " validation_ds: Iterator[TrainingInput],\n",
+ " num_steps: int | None = None,\n",
+ " eval_every_n: int = 20,\n",
+ "):\n",
+ " opt_state = jax.jit(optimizer.init)(params)\n",
+ "\n",
+ " step_counter = 0\n",
+ " avg_loss=0\n",
+ "\n",
+ " # The first round of the validation loss.\n",
+ " n_steps_eval = 0\n",
+ " eval_loss = 0\n",
+ " for val_example in validation_ds.as_numpy_iterator():\n",
+ " eval_loss += validation_step(\n",
+ " model, params, dataset_builder._tokenizer.pad_id, val_example\n",
+ " )\n",
+ " n_steps_eval += 1\n",
+ " print(f\"Start, validation loss: {eval_loss/n_steps_eval}\")\n",
+ "\n",
+ " for train_example in train_ds:\n",
+ " train_loss, params, opt_state = train_step(\n",
+ " model=model,\n",
+ " params=params,\n",
+ " optimizer=optimizer,\n",
+ " opt_state=opt_state,\n",
+ " pad_id=dataset_builder._tokenizer.pad_id,\n",
+ " example=train_example,\n",
+ " )\n",
+ "\n",
+ " step_counter += 1\n",
+ " avg_loss += train_loss\n",
+ " if step_counter % eval_every_n == 0:\n",
+ " eval_loss = 0\n",
+ "\n",
+ " n_steps_eval = 0\n",
+ " val_iterator = validation_ds.as_numpy_iterator()\n",
+ " for val_example in val_iterator:\n",
+ " eval_loss += validation_step(\n",
+ " model,\n",
+ " params,\n",
+ " dataset_builder._tokenizer.pad_id,\n",
+ " val_example,\n",
+ " )\n",
+ " n_steps_eval +=1\n",
+ " avg_loss /= eval_every_n\n",
+ " eval_loss /= n_steps_eval\n",
+ " print(f\"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}\")\n",
+ " avg_loss=0\n",
+ " if num_steps is not None and step_counter > num_steps:\n",
+ " break\n",
+ " return params"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YFooBD7W1Fk3"
+ },
+ "source": [
+ "Here you have to choose an (Optax) optimizer. For devices with smaller memory, you should use SGD, as it has a much lower memory footprint. To achieve best fine-tuning performance, try Adam-W. The optimal hyperparameters for each optimizer for the particular task in this notebook are provided in this example for the `2b-it` checkpoint.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "woFjh7U_eiev"
+ },
+ "outputs": [],
+ "source": [
+ "def griffin_weight_decay_mask(params_like: optax.Params) -> Any:\n",
+ " # Don't put weight decay on the RGLRU, the embeddings and any biases\n",
+ " def enable_weight_decay(path: list[Any], _: Any) -> bool:\n",
+ " # Parameters in the LRU and embedder\n",
+ " path = [dict_key.key for dict_key in path]\n",
+ " if 'rg_lru' in path or 'embedder' in path:\n",
+ " return False\n",
+ " # All biases and scales\n",
+ " if path[-1] in ('b', 'scale'):\n",
+ " return False\n",
+ " return True\n",
+ "\n",
+ " return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)\n",
+ "\n",
+ "optimizer_choice = \"sgd\" #@param [\"sgd\", \"adamw\"]\n",
+ "\n",
+ "if optimizer_choice == \"sgd\":\n",
+ " optimizer = optax.sgd(learning_rate=1e-3)\n",
+ " num_steps = 300\n",
+ "elif optimizer_choice == \"adamw\":\n",
+ " optimizer = optax.adamw(\n",
+ " learning_rate=1e-4,\n",
+ " b2=0.96,\n",
+ " eps=1e-8,\n",
+ " weight_decay=0.1,\n",
+ " mask=griffin_weight_decay_mask,\n",
+ " )\n",
+ " num_steps = 100\n",
+ "else:\n",
+ " raise ValueError(f\"Unknown optimizer: {optimizer_choice}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h-KYQziReyCn"
+ },
+ "source": [
+ "Prepare the training and validation datasets:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rGXTQ2uHeozO"
+ },
+ "outputs": [],
+ "source": [
+ "# Choose a small sequence length size, so that everything fits in memory.\n",
+ "num_epochs = 1 #@param {type: \"integer\"}\n",
+ "batch_size = 1 #@param {type: \"integer\"}\n",
+ "sequence_length = 32 #@param {type: \"integer\"}\n",
+ "\n",
+ "# Make the dataset builder.\n",
+ "tokenizer = GriffinTokenizer(vocab)\n",
+ "dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)\n",
+ "\n",
+ "# Build the training dataset.\n",
+ "train_ds = dataset_builder.get_train_dataset(\n",
+ " batch_size=batch_size,\n",
+ " num_epochs=num_epochs,\n",
+ ").as_numpy_iterator()\n",
+ "\n",
+ "# Build the validation dataset, with a limited number of samples for this demo.\n",
+ "validation_ds = dataset_builder.get_validation_dataset(\n",
+ " batch_size=batch_size,\n",
+ ").take(50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B3alnSJQ1xmd"
+ },
+ "source": [
+ "Begin fine-tuning the RecurrentGemma (Griffin) model on a limited number of steps (`num_steps`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7SL2VAmVEcoa"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Start, validation loss: 7.894117832183838\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,33]), ShapedArray(bool[1,33]), ShapedArray(int32[], weak_type=True).\n",
+ "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.\n",
+ " warnings.warn(\"Some donated buffers were not usable:\"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "STEP 20 training loss: 4.592616081237793 - eval loss: 2.847407102584839\n",
+ "STEP 40 training loss: 2.7537424564361572 - eval loss: 2.9258534908294678\n",
+ "STEP 60 training loss: 2.835618257522583 - eval loss: 2.4382340908050537\n",
+ "STEP 80 training loss: 2.6322107315063477 - eval loss: 2.3696839809417725\n",
+ "STEP 100 training loss: 1.8703256845474243 - eval loss: 2.355681896209717\n",
+ "STEP 120 training loss: 2.7280433177948 - eval loss: 2.4059958457946777\n",
+ "STEP 140 training loss: 2.3047447204589844 - eval loss: 2.083082914352417\n",
+ "STEP 160 training loss: 2.3432137966156006 - eval loss: 2.095074415206909\n",
+ "STEP 180 training loss: 2.1081202030181885 - eval loss: 2.006460189819336\n",
+ "STEP 200 training loss: 2.5359647274017334 - eval loss: 1.9667452573776245\n",
+ "STEP 220 training loss: 2.202195644378662 - eval loss: 1.9440618753433228\n",
+ "STEP 240 training loss: 2.756615400314331 - eval loss: 2.1073737144470215\n",
+ "STEP 260 training loss: 2.5128934383392334 - eval loss: 2.117241859436035\n",
+ "STEP 280 training loss: 2.73045015335083 - eval loss: 1.9159646034240723\n",
+ "STEP 300 training loss: 2.0918595790863037 - eval loss: 1.9742532968521118\n"
+ ]
+ }
+ ],
+ "source": [
+ "trained_params = train_loop(\n",
+ " model=model,\n",
+ " params=params,\n",
+ " optimizer=optimizer,\n",
+ " train_ds=train_ds,\n",
+ " validation_ds=validation_ds,\n",
+ " num_steps=num_steps,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EtfVo3pDDAZV"
+ },
+ "source": [
+ "Both the training loss and the validation loss should have gone down with each step count.\n",
+ "\n",
+ "To ensure your input matches the training format, remember to use the prefix `Translate this into French:\\n` and a newline character at the end. This signals the model to begin translation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "S5F3fk22Ecod"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,16]).\n",
+ "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.\n",
+ " warnings.warn(\"Some donated buffers were not usable:\"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mais je m'appelle Morgane.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sampler.params = trained_params\n",
+ "output = sampler(\n",
+ " [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n",
+ " total_generation_steps=100,\n",
+ ")\n",
+ "print(output.text[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Jao0Qk-ZIqyD"
+ },
+ "source": [
+ "## Learn more\n",
+ "\n",
+ "- You can learn more about the Google DeepMind [`recurrentgemma` library on GitHub](https://github.com/google-deepmind/recurrentgemma), which contains docstrings of methods and modules you used in this tutorial, such as [`recurrentgemma.jax.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31), [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29), and [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74).\n",
+ "- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Chex](https://chex.readthedocs.io/en/latest/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/).\n",
+ "- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).\n",
+ "- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).\n",
+ "- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
+ "- If you are using Google Cloud TPUs (v3-8 and newer), make sure to also update to the latest `jax[tpu]` package (`!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`), restart the runtime, and check that `jax` and `jaxlib` versions match (`!pip list | grep jax`). This can prevent the `RuntimeError` that can arise because of the `jaxlib` and `jax` version mismatch. For more JAX installation instructions, refer to the [JAX docs](https://jax.readthedocs.io/en/latest/tutorials/installation.html#install-google-tpu).\n",
+ "- Check out the [RecurrentGemma: Moving Past Transformers\n",
+ "for Efficient Open Language Models](https://arxiv.org/pdf/2404.07839) paper by Google DeepMind.\n",
+ "- Read the [Griffin: Mixing Gated Linear Recurrences with\n",
+ "Local Attention for Efficient Language Models](https://arxiv.org/pdf/2402.19427) paper by Google DeepMind to learn more about the model architecture used by RecurrentGemma."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "recurrentgemma_jax_finetune.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb b/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb
new file mode 100644
index 000000000..b1c36ea2b
--- /dev/null
+++ b/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb
@@ -0,0 +1,506 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tce3stUlHN0L"
+ },
+ "source": [
+ "##### Copyright 2024 Google LLC."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tuOe1ymfHZPu"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FUOiKRSF7jc1"
+ },
+ "source": [
+ "# Inference with RecurrentGemma"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "60KmTK7o6ppd"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tdlq6K0znh3O"
+ },
+ "source": [
+ "This tutorial demonstrates how to perform basic sampling/inference with the [RecurrentGemma](https://ai.google.dev/gemma/docs/recurrentgemma) 2B Instruct model using [Google DeepMind's `recurrentgemma` library](https://github.com/google-deepmind/recurrentgemma) that was written with [JAX](https://jax.readthedocs.io) (a high-performance numerical computing library), [Flax](https://flax.readthedocs.io) (the JAX-based neural network library), [Orbax](https://orbax.readthedocs.io/) (a JAX-based library for training utilities like checkpointing), and [SentencePiece](https://github.com/google/sentencepiece) (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma and RecurrentGemma (the Griffin model).\n",
+ "\n",
+ "This notebook can run on Google Colab with the T4 GPU (go to **Edit** > **Notebook settings** > Under **Hardware accelerator** select **T4 GPU**)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aKvTsIkL98BG"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "The following sections explain the steps for preparing a notebook to use a RecurrentGemma model, including model access, getting an API key, and configuring the notebook runtime"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WCgCkmQSPxkE"
+ },
+ "source": [
+ "### Set up Kaggle access for Gemma\n",
+ "\n",
+ "To complete this tutorial, you first need to follow the setup instructions _similar_ to [Gemma setup](https://ai.google.dev/gemma/docs/setup) with a few exceptions:\n",
+ "\n",
+ "* Get access to RecurrentGemma (instead of Gemma) on [kaggle.com](https://www.kaggle.com/models/google/recurrentgemma).\n",
+ "* Select a Colab runtime with sufficient resources to run the RecurrentGemma model.\n",
+ "* Generate and configure a Kaggle username and API key.\n",
+ "\n",
+ "After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
+ "\n",
+ "### Set environment variables\n",
+ "\n",
+ "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lKoW-nhE-gNO"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from google.colab import userdata # `userdata` is a Colab API.\n",
+ "\n",
+ "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
+ "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AO7a1Q4Yyc9Z"
+ },
+ "source": [
+ "### Install the `recurrentgemma` library\n",
+ "\n",
+ "This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on **Edit** > **Notebook settings** > Select **T4 GPU** > **Save**.\n",
+ "\n",
+ "Next, you need to install the Google DeepMind `recurrentgemma` library from [`github.com/google-deepmind/recurrentgemma`](https://github.com/google-deepmind/recurrentgemma). If you get an error about \"pip's dependency resolver\", you can usually ignore it.\n",
+ "\n",
+ "**Note:** By installing `recurrentgemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`optax`](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WWEzVJR4Fx9g"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting git+https://github.com/google-deepmind/recurrentgemma.git\n",
+ " Cloning https://github.com/google-deepmind/recurrentgemma.git to /tmp/pip-req-build-zz9xp6s4\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/recurrentgemma.git /tmp/pip-req-build-zz9xp6s4\n",
+ " Resolved https://github.com/google-deepmind/recurrentgemma.git to commit e4939f9b7edf8baa1d512fb86bfc2e206044d66b\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: absl-py<1.5.0,>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from recurrentgemma==0.1.0) (1.4.0)\n",
+ "Collecting einops<0.8.0,>=0.7.0 (from recurrentgemma==0.1.0)\n",
+ " Downloading einops-0.7.0-py3-none-any.whl (44 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jaxtyping<0.3.0,>=0.2.28 (from recurrentgemma==0.1.0)\n",
+ " Downloading jaxtyping-0.2.28-py3-none-any.whl (40 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.7/40.7 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.10/dist-packages (from recurrentgemma==0.1.0) (1.25.2)\n",
+ "Collecting sentencepiece<0.3.0,>=0.2.0 (from recurrentgemma==0.1.0)\n",
+ " Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting typeguard==2.13.3 (from jaxtyping<0.3.0,>=0.2.28->recurrentgemma==0.1.0)\n",
+ " Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)\n",
+ "Building wheels for collected packages: recurrentgemma\n",
+ " Building wheel for recurrentgemma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for recurrentgemma: filename=recurrentgemma-0.1.0-py3-none-any.whl size=73547 sha256=e3d3e85d59877ec33d2e4dff1a1666eaed1342c68199255cdd806d74472d4524\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-42qdygtw/wheels/31/37/18/c57f1df6091b661385ab728b959bdfbf2078d9fc7c856899e4\n",
+ "Successfully built recurrentgemma\n",
+ "Installing collected packages: sentencepiece, typeguard, einops, jaxtyping, recurrentgemma\n",
+ " Attempting uninstall: sentencepiece\n",
+ " Found existing installation: sentencepiece 0.1.99\n",
+ " Uninstalling sentencepiece-0.1.99:\n",
+ " Successfully uninstalled sentencepiece-0.1.99\n",
+ "Successfully installed einops-0.7.0 jaxtyping-0.2.28 recurrentgemma-0.1.0 sentencepiece-0.2.0 typeguard-2.13.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install git+https://github.com/google-deepmind/recurrentgemma.git"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VKLjBAe1m3Ck"
+ },
+ "source": [
+ "## Load and prepare the RecurrentGemma model\n",
+ "\n",
+ "1. Load the RecurrentGemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:\n",
+ "\n",
+ "- `handle`: The model handle from Kaggle\n",
+ "- `path`: (Optional string) The local path\n",
+ "- `force_download`: (Optional boolean) Forces to re-download the model\n",
+ "\n",
+ "**Note:** Be mindful that the `recurrentgemma-2b-it` model is around 3.85Gb in size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_W3FUd9lt8VT"
+ },
+ "outputs": [],
+ "source": [
+ "RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kFCmWEKdMA_Y"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...\n",
+ "100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]\n",
+ "Extracting model files...\n"
+ ]
+ }
+ ],
+ "source": [
+ "import kagglehub\n",
+ "\n",
+ "RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nYmYTMk8aELi"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ytNi47xSlw71"
+ },
+ "source": [
+ "**Note:** The path from the output above is where the model weights and tokenizer are saved locally, you will need them for later."
+ ]
+ },
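+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If the download was interrupted, or you want to fetch a fresh copy of the weights, the optional `force_download` argument described above re-downloads the model. A minimal sketch, left commented out so that **Run all** doesn't re-fetch roughly 3.85 GB:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: force a fresh download of the model files.\n",
+ "# RECURRENTGEMMA_PATH = kagglehub.model_download(\n",
+ "#     f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}',\n",
+ "#     force_download=True)"
+ ]
+ },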
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "92BcvYdemXbd"
+ },
+ "source": [
+ "2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer directory will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:\n",
+ "\n",
+ "- The `tokenizer.model` file will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1`).\n",
+ "- The model checkpoint will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QY6OnASOpZbW"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it\n",
+ "TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model\n"
+ ]
+ }
+ ],
+ "source": [
+ "CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)\n",
+ "TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')\n",
+ "print('CKPT_PATH:', CKPT_PATH)\n",
+ "print('TOKENIZER_PATH:', TOKENIZER_PATH)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jc0ZzYIW0TSN"
+ },
+ "source": [
+ "## Perform sampling/inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aEe3p8geqekV"
+ },
+ "source": [
+ "1. Load the RecurrentGemma model checkpoint with the [`recurrentgemma.jax.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31) method. The `sharding` argument set to `\"single_device\"` loads all model parameters on a single device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Mnr52JQVqKRw"
+ },
+ "outputs": [],
+ "source": [
+ "import recurrentgemma\n",
+ "from recurrentgemma import jax as recurrentgemma\n",
+ "\n",
+ "params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding=\"single_device\")"
+ ]
+ },
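+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (not part of the original steps), you can count the parameters in the loaded checkpoint with `jax.tree_util`; for the 2B variant the total should be in the low billions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "\n",
+ "# Sum the sizes of all parameter arrays in the checkpoint.\n",
+ "num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))\n",
+ "print(f'Total parameters: {num_params:,}')"
+ ]
+ },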
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-Xpnb2igrGjk"
+ },
+ "source": [
+ "2. Load the RecurrentGemma model tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-T0ZHff5rNSy"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sentencepiece as spm\n",
+ "\n",
+ "vocab = spm.SentencePieceProcessor()\n",
+ "vocab.Load(TOKENIZER_PATH)"
+ ]
+ },
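+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To confirm the tokenizer works, you can run a quick round trip (an optional check; the sample string is illustrative): encode a string to token IDs, then decode the IDs back."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Encode a sample string to token IDs, then decode the IDs back.\n",
+ "ids = vocab.EncodeAsIds('Hello, RecurrentGemma!')\n",
+ "print('Token IDs:', ids)\n",
+ "print('Decoded:', vocab.DecodeIds(ids))\n",
+ "print('Vocabulary size:', vocab.GetPieceSize())"
+ ]
+ },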
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IkAf4fkNrY-3"
+ },
+ "source": [
+ "3. To automatically load the correct configuration from the RecurrentGemma model checkpoint, use [`recurrentgemma.GriffinConfig.from_flax_params_or_variables`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/common.py#L128). Then, instantiate the [Griffin](https://arxiv.org/abs/2402.19427) model with [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4PNWxDhvrRXJ"
+ },
+ "outputs": [],
+ "source": [
+ "model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(\n",
+ " flax_params_or_variables=params)\n",
+ "\n",
+ "model = recurrentgemma.Griffin(model_config)"
+ ]
+ },
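+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optionally, print the configuration that was inferred from the checkpoint to confirm it matches the variant you downloaded (a quick inspection, not required for the tutorial):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect the configuration inferred from the checkpoint.\n",
+ "print(model_config)"
+ ]
+ },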
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vs0vgmXVroBq"
+ },
+ "source": [
+ "3. Create a `sampler` with [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74) on top of the RecurrentGemma model checkpoint/weights and the tokenizer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4GX4pFP6rtyN"
+ },
+ "outputs": [],
+ "source": [
+ "sampler = recurrentgemma.Sampler(\n",
+ " model=model,\n",
+ " vocab=vocab,\n",
+ " params=params,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "V9yU99Xxr59w"
+ },
+ "source": [
+ "4. Write a prompt in `prompt` and perform inference. You can tweak `total_generation_steps` (the number of steps performed when generating a response — this example uses `50` to preserve host memory).\n",
+ "\n",
+ "**Note:** If you run out of memory, click on **Runtime** > **Disconnect and delete runtime**, and then **Runtime** > **Run all**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Gj9jRFI5Hrv2"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).\n",
+ "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.\n",
+ " warnings.warn(\"Some donated buffers were not usable:\"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prompt:\n",
+ "\n",
+ "# 5+9=?\n",
+ "Output:\n",
+ "\n",
+ "\n",
+ "# Answer: 14\n",
+ "\n",
+ "# Explanation: 5 + 9 = 14.\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompt = [\n",
+ " \"\\n# 5+9=?\",\n",
+ "]\n",
+ "\n",
+ "reply = sampler(input_strings=prompt,\n",
+ " total_generation_steps=50,\n",
+ " )\n",
+ "\n",
+ "for input_string, out_string in zip(prompt, reply.text):\n",
+ " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")"
+ ]
+ },
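+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Because `input_strings` accepts a list, you can also batch several prompts in a single call and get one completion per prompt. A minimal sketch (the prompts are illustrative):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ "    \"\\n# 7+2=?\",\n",
+ "    \"\\n# 12-4=?\",\n",
+ "]\n",
+ "\n",
+ "batched_reply = sampler(input_strings=prompts,\n",
+ "                        total_generation_steps=50,\n",
+ "                        )\n",
+ "\n",
+ "# One output string is returned per input prompt.\n",
+ "for input_string, out_string in zip(prompts, batched_reply.text):\n",
+ "    print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\\n\")"
+ ]
+ },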
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bzKsCGIN0yX5"
+ },
+ "source": [
+ "## Learn more\n",
+ "\n",
+ "- You can learn more about the Google DeepMind [`recurrentgemma` library on GitHub](https://github.com/google-deepmind/recurrentgemma), which contains docstrings of methods and modules you used in this tutorial, such as [`recurrentgemma.jax.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31), [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29), and [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74).\n",
+ "- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).\n",
+ "- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).\n",
+ "- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).\n",
+ "- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
+ "- Check out the [RecurrentGemma: Moving Past Transformers\n",
+ "for Efficient Open Language Models](https://arxiv.org/pdf/2404.07839) paper by Google DeepMind.\n",
+ "- Read the [Griffin: Mixing Gated Linear Recurrences with\n",
+ "Local Attention for Efficient Language Models](https://arxiv.org/pdf/2402.19427) paper by GoogleDeepMind to learn more about the model architecture used by RecurrentGemma."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "Tce3stUlHN0L"
+ ],
+ "name": "recurrentgemma_jax_inference.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}