"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "402c3d8a",
+ "metadata": {
+ "id": "402c3d8a"
+ },
+ "source": [
+ "# Building a chatbot with Gemma"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b686fd95",
+ "metadata": {
+ "id": "b686fd95"
+ },
+ "source": [
+ "Large Language Models (LLMs) such as Gemma excel at generating informative responses, making them ideal for building virtual assistants and chatbots.\n",
+ "\n",
+ "Conventionally, LLMs operate in a stateless manner, meaning they lack an inherent memory to store past conversations. Each prompt or question is processed independently, disregarding prior interactions. However, a crucial aspect of natural conversation is the ability to retain context from prior interactions. To overcome this limitation and enable LLMs to maintain conversation context, they must be explicitly provided with relevant information such as the conversation history (or pertinent parts) into each new prompt presented to the LLM.\n",
+ "\n",
+ "This tutorial shows you how to develop a chatbot using the instruction-tuned model variant of Gemma."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29732090",
+ "metadata": {
+ "id": "29732090"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "QQ6W7NzRe1VM",
+ "metadata": {
+ "id": "QQ6W7NzRe1VM"
+ },
+ "source": [
+ "### Gemma setup\n",
+ "\n",
+ "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
+ "\n",
+ "* Get access to Gemma on kaggle.com.\n",
+ "* Select a Colab runtime with sufficient resources to run\n",
+ " the Gemma 2B model.\n",
+ "* Generate and configure a Kaggle username and API key.\n",
+ "\n",
+ "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "_gN-IVRC3dQe",
+ "metadata": {
+ "id": "_gN-IVRC3dQe"
+ },
+ "source": [
+ "### Set environment variables\n",
+ "\n",
+ "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "DrBoa_Urw9Vx",
+ "metadata": {
+ "id": "DrBoa_Urw9Vx"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from google.colab import userdata\n",
+ "\n",
+ "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
+ "# vars as appropriate for your system.\n",
+ "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
+ "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "z9oy3QUmXtSd",
+ "metadata": {
+ "id": "z9oy3QUmXtSd"
+ },
+ "source": [
+ "### Install dependencies\n",
+ "\n",
+ "Install Keras and KerasNLP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a973dd7a",
+ "metadata": {
+ "id": "a973dd7a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m214.0/214.0 MB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m86.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m70.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m88.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m64.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m29.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.\n",
+ "tensorflow 2.15.0 requires ml-dtypes~=0.2.0, but you have ml-dtypes 0.3.2 which is incompatible.\n",
+ "tensorflow 2.15.0 requires tensorboard<2.16,>=2.15, but you have tensorboard 2.16.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m515.3/515.3 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m49.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m49.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m67.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow-cpu 2.16.1 requires keras>=3.0.0, but you have keras 2.15.0 which is incompatible.\n",
+ "tensorflow-cpu 2.16.1 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.2.0 which is incompatible.\n",
+ "tensorflow-cpu 2.16.1 requires tensorboard<2.17,>=2.16, but you have tensorboard 2.15.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.\n",
+ "tensorflow-cpu 2.16.1 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.2.0 which is incompatible.\n",
+ "tensorflow-cpu 2.16.1 requires tensorboard<2.17,>=2.16, but you have tensorboard 2.15.2 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m589.8/589.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.16.1 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
+ "!pip install -q tensorflow-cpu\n",
+ "!pip install -q -U keras-nlp tensorflow-hub\n",
+ "!pip install -q -U keras>=3\n",
+ "!pip install -q -U tensorflow-text"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "Wme8666dUPVR",
+ "metadata": {
+ "id": "Wme8666dUPVR"
+ },
+ "source": [
+ "### Select a backend\n",
+ "\n",
+ "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. [Keras 3](https://keras.io/keras_3){:.external} lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "67d12d2d",
+ "metadata": {
+ "id": "67d12d2d"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "# Select JAX as the backend\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
+ "\n",
+ "# Pre-allocate 100% of TPU memory to minimize memory fragmentation\n",
+ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "Ajm_SGWTUjVd",
+ "metadata": {
+ "id": "Ajm_SGWTUjVd"
+ },
+ "source": [
+ "### Import packages\n",
+ "\n",
+ "Import Keras and KerasNLP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3lyn9FxPUok8",
+ "metadata": {
+ "id": "3lyn9FxPUok8"
+ },
+ "outputs": [],
+ "source": [
+ "import keras\n",
+ "import keras_nlp\n",
+ "\n",
+ "# for reproducibility\n",
+ "keras.utils.set_random_seed(42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39dc9d5b",
+ "metadata": {
+ "id": "39dc9d5b"
+ },
+ "source": [
+ "### Instantiate the model\n",
+ "\n",
+ "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/){:.external}. In this tutorial, you'll instantiate the model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.\n",
+ "\n",
+ "Instantiate the model using the `from_preset` method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c86dc8fe",
+ "metadata": {
+ "id": "c86dc8fe"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n",
+ "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...\n"
+ ]
+ }
+ ],
+ "source": [
+ "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_1.1_instruct_2b_en\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "tcCv0BSdVFv9",
+ "metadata": {
+ "id": "tcCv0BSdVFv9"
+ },
+ "source": [
+ "`from_preset` instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma_1.1_instruct_2b_en\"` specifies the preset architecture: a Gemma instruction-tuned model with 2 billion parameters.\n",
+ "This model variant is fine-tuned for conversations and to answer questions in a more natural manner.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bLNx8AoeVe-a",
+ "metadata": {
+ "id": "bLNx8AoeVe-a"
+ },
+ "source": [
+ "Use the `summary` method to get more info about the model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3MorieIpVksu",
+ "metadata": {
+ "id": "3MorieIpVksu"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "gemma_lm.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ArZPOzFpVp6S",
+ "metadata": {
+ "id": "ArZPOzFpVp6S"
+ },
+ "source": [
+ "As you can see from the summary, the model has 2.5 billion trainable parameters.\n",
+ "\n",
+ "Note: For purposes of naming the model (\"2B\"), the embedding layer is not counted against the number of parameters."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1WpS39TBYql9",
+ "metadata": {
+ "id": "1WpS39TBYql9"
+ },
+ "source": [
+ "### Define formatting helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3-obTC1jZGpZ",
+ "metadata": {
+ "id": "3-obTC1jZGpZ"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import Markdown\n",
+ "import textwrap\n",
+ "\n",
+ "def display_chat(prompt, text):\n",
+ " formatted_prompt = \"🙋♂️
\" + prompt + \"
\"\n",
+ " text = text.replace('•', ' *')\n",
+ " text = textwrap.indent(text, '> ', predicate=lambda _: True)\n",
+ " formatted_text = \"🤖\\n\\n\" + text + \"\\n\"\n",
+ " return Markdown(formatted_prompt+formatted_text)\n",
+ "\n",
+ "def to_markdown(text):\n",
+ " text = text.replace('•', ' *')\n",
+ " return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5ca54e8c",
+ "metadata": {
+ "id": "5ca54e8c"
+ },
+ "source": [
+ "## Building the chatbot\n",
+ "\n",
+ "The Gemma instruction-tuned model `gemma_1.1_instruct_2b_en` is fine-tuned to understand the following turn tokens:\n",
+ "\n",
+ "```\n",
+ "user\\n ... \\n\n",
+ "model\\n ... \\n\n",
+ "```\n",
+ "\n",
+ "This tutorial uses these tokens to build the chatbot. Refer to [Formatting and system instructions](https://ai.google.dev/gemma/docs/formatting) for more information on Gemma control tokens.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9583dfd1",
+ "metadata": {
+ "id": "9583dfd1"
+ },
+ "source": [
+ "### Create a chat helper to manage the conversation state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e4e9a187",
+ "metadata": {
+ "id": "e4e9a187"
+ },
+ "outputs": [],
+ "source": [
+ "class ChatState():\n",
+ " \"\"\"\n",
+ " Manages the conversation history for a turn-based chatbot\n",
+ " Follows the turn-based conversation guidelines for the Gemma family of models\n",
+ " documented at https://ai.google.dev/gemma/docs/formatting\n",
+ " \"\"\"\n",
+ "\n",
+ " __START_TURN_USER__ = \"user\\n\"\n",
+ " __START_TURN_MODEL__ = \"model\\n\"\n",
+ " __END_TURN__ = \"\\n\"\n",
+ "\n",
+ " def __init__(self, model, system=\"\"):\n",
+ " \"\"\"\n",
+ " Initializes the chat state.\n",
+ "\n",
+ " Args:\n",
+ " model: The language model to use for generating responses.\n",
+ " system: (Optional) System instructions or bot description.\n",
+ " \"\"\"\n",
+ " self.model = model\n",
+ " self.system = system\n",
+ " self.history = []\n",
+ "\n",
+ " def add_to_history_as_user(self, message):\n",
+ " \"\"\"\n",
+ " Adds a user message to the history with start/end turn markers.\n",
+ " \"\"\"\n",
+ " self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)\n",
+ "\n",
+ " def add_to_history_as_model(self, message):\n",
+ " \"\"\"\n",
+ " Adds a model response to the history with start/end turn markers.\n",
+ " \"\"\"\n",
+ " self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)\n",
+ "\n",
+ " def get_history(self):\n",
+ " \"\"\"\n",
+ " Returns the entire chat history as a single string.\n",
+ " \"\"\"\n",
+ " return \"\".join([*self.history])\n",
+ "\n",
+ " def get_full_prompt(self):\n",
+ " \"\"\"\n",
+ " Builds the prompt for the language model, including history and system description.\n",
+ " \"\"\"\n",
+ " prompt = self.get_history() + self.__START_TURN_MODEL__\n",
+ " if len(self.system)>0:\n",
+ " prompt = self.system + \"\\n\" + prompt\n",
+ " return prompt\n",
+ "\n",
+ " def send_message(self, message):\n",
+ " \"\"\"\n",
+ " Handles sending a user message and getting a model response.\n",
+ "\n",
+ " Args:\n",
+ " message: The user's message.\n",
+ "\n",
+ " Returns:\n",
+ " The model's response.\n",
+ " \"\"\"\n",
+ " self.add_to_history_as_user(message)\n",
+ " prompt = self.get_full_prompt()\n",
+ " response = self.model.generate(prompt, max_length=1024)\n",
+ " result = response.replace(prompt, \"\") # Extract only the new response\n",
+ " self.add_to_history_as_model(result)\n",
+ " return result\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9hmJS4h4ZmiP",
+ "metadata": {
+ "id": "9hmJS4h4ZmiP"
+ },
+ "source": [
+ "### Chat with the model\n",
+ "\n",
+ "Start chatting with the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b1913181",
+ "metadata": {
+ "id": "b1913181"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": "🙋♂️
Tell me, in a few words, how to compute all prime numbers up to 1000?
🤖\n\n> The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "chat = ChatState(gemma_lm)\n",
+ "message = \"Tell me, in a few words, how to compute all prime numbers up to 1000?\"\n",
+ "display_chat(message, chat.send_message(message))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ODKxUPP2Zuqy",
+ "metadata": {
+ "id": "ODKxUPP2Zuqy"
+ },
+ "source": [
+ "Continue the conversation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7448005b",
+ "metadata": {
+ "id": "7448005b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": "🙋♂️
Now in Python! No numpy, please!
🤖\n\n> ```python\n> def prime(n):\n> if n <= 1:\n> return False\n> for i in range(2, int(n**0.5) + 1):\n> if n % i == 0:\n> return False\n> return True\n> ```\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "message = \"Now in Python! No numpy, please!\"\n",
+ "display_chat(message, chat.send_message(message))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0973ff54",
+ "metadata": {
+ "id": "0973ff54"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": "🙋♂️
Thank you, it works! Can you explain the code in plain English?
🤖\n\n> The provided code defines a function `prime` that checks whether a given number is prime or not.\n> \n> **How it works:**\n> \n> - The function takes a single argument, `n`, which is the number to check.\n> \n> \n> - It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.\n> \n> \n> - It then enters a loop that iterates through numbers from 2 to the square root of `n`.\n> \n> \n> - For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).\n> \n> \n> - If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.\n> \n> \n> - If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.\n> \n> \n> **Example Usage:**\n> \n> ```python\n> >>> prime(2)\n> True\n> >>> prime(3)\n> True\n> >>> prime(4)\n> False\n> >>> prime(5)\n> True\n> ```\n> \n> **Benefits of this Code:**\n> \n> - It is a simple and efficient algorithm for finding prime numbers.\n> - It is widely used in various computer science and mathematical applications.\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "message = \"Thank you, it works! Can you explain the code in plain English?\"\n",
+ "display_chat(message, chat.send_message(message))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a0c51f42",
+ "metadata": {
+ "id": "a0c51f42"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": "🙋♂️
Great! Now add those explanations as comments in the code.
🤖\n\n> ```python\n> def prime(n):\n> \"\"\"\n> Checks whether a given number is prime or not.\n> \n> Args:\n> n: The number to check.\n> \n> Returns:\n> True if n is prime, False otherwise.\n> \"\"\"\n> \n> # Check if n is less than or equal to 1.\n> if n <= 1:\n> return False\n> \n> # Iterate through numbers from 2 to the square root of n.\n> for i in range(2, int(n**0.5) + 1):\n> # Check if n is divisible by i.\n> if n % i == 0:\n> return False\n> \n> # If the loop completes without finding any divisors for n, then n is prime.\n> return True\n> ```\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "message = \"Great! Now add those explanations as comments in the code.\"\n",
+ "display_chat(message, chat.send_message(message))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "51a33627",
+ "metadata": {
+ "id": "51a33627"
+ },
+ "source": [
+ "Test the generated response by running the generated code:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "221c0817",
+ "metadata": {
+ "id": "221c0817"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def is_prime(n):\n",
+ " \"\"\"\n",
+ " Checks if a number is prime.\n",
+ "\n",
+ " Args:\n",
+ " n: The number to check.\n",
+ "\n",
+ " Returns:\n",
+ " True if n is prime, False otherwise.\n",
+ " \"\"\"\n",
+ "\n",
+ " # If n is less than or equal to 1, it is not prime.\n",
+ " if n <= 1:\n",
+ " return False\n",
+ "\n",
+ " # Iterate through all the numbers from 2 to the square root of n.\n",
+ " for i in range(2, int(n**0.5) + 1):\n",
+ " # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.\n",
+ " if n % i == 0:\n",
+ " return False\n",
+ "\n",
+ " # If no divisors are found, n is prime.\n",
+ " return True\n",
+ "\n",
+ "\n",
+ "# Initialize an empty list to store prime numbers.\n",
+ "primes = []\n",
+ "\n",
+ "# Iterate through all the numbers from 2 to 1000.\n",
+ "for i in range(2, 1001):\n",
+ " # If the number is prime, add it to the list.\n",
+ " if is_prime(i):\n",
+ " primes.append(i)\n",
+ "\n",
+ "# Print the prime numbers.\n",
+ "print(primes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1c8ece6c",
+ "metadata": {
+ "id": "1c8ece6c"
+ },
+ "source": [
+ "Use the `get_history` method to see how all the context was retained by the `Chat` class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e48f4ca1",
+ "metadata": {
+ "id": "e48f4ca1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "user\n",
+ "Tell me, in a few words, how to compute all prime numbers up to 1000?\n",
+ "model\n",
+ "The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.\n",
+ "user\n",
+ "Now in Python! No numpy, please!\n",
+ "model\n",
+ "```python\n",
+ "def prime(n):\n",
+ " if n <= 1:\n",
+ " return False\n",
+ " for i in range(2, int(n**0.5) + 1):\n",
+ " if n % i == 0:\n",
+ " return False\n",
+ " return True\n",
+ "```\n",
+ "user\n",
+ "Thank you, it works! Can you explain the code in plain English?\n",
+ "model\n",
+ "The provided code defines a function `prime` that checks whether a given number is prime or not.\n",
+ "\n",
+ "**How it works:**\n",
+ "\n",
+ "- The function takes a single argument, `n`, which is the number to check.\n",
+ "\n",
+ "\n",
+ "- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.\n",
+ "\n",
+ "\n",
+ "- It then enters a loop that iterates through numbers from 2 to the square root of `n`.\n",
+ "\n",
+ "\n",
+ "- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).\n",
+ "\n",
+ "\n",
+ "- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.\n",
+ "\n",
+ "\n",
+ "- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.\n",
+ "\n",
+ "\n",
+ "**Example Usage:**\n",
+ "\n",
+ "```python\n",
+ ">>> prime(2)\n",
+ "True\n",
+ ">>> prime(3)\n",
+ "True\n",
+ ">>> prime(4)\n",
+ "False\n",
+ ">>> prime(5)\n",
+ "True\n",
+ "```\n",
+ "\n",
+ "**Benefits of this Code:**\n",
+ "\n",
+ "- It is a simple and efficient algorithm for finding prime numbers.\n",
+ "- It is widely used in various computer science and mathematical applications.\n",
+ "user\n",
+ "Great! Now add those explanations as comments in the code.\n",
+ "model\n",
+ "```python\n",
+ "def prime(n):\n",
+ " \"\"\"\n",
+ " Checks whether a given number is prime or not.\n",
+ "\n",
+ " Args:\n",
+ " n: The number to check.\n",
+ "\n",
+ " Returns:\n",
+ " True if n is prime, False otherwise.\n",
+ " \"\"\"\n",
+ "\n",
+ " # Check if n is less than or equal to 1.\n",
+ " if n <= 1:\n",
+ " return False\n",
+ "\n",
+ " # Iterate through numbers from 2 to the square root of n.\n",
+ " for i in range(2, int(n**0.5) + 1):\n",
+ " # Check if n is divisible by i.\n",
+ " if n % i == 0:\n",
+ " return False\n",
+ "\n",
+ " # If the loop completes without finding any divisors for n, then n is prime.\n",
+ " return True\n",
+ "```\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(chat.get_history())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9693c66f",
+ "metadata": {
+ "id": "9693c66f"
+ },
+ "source": [
+ "## Summary and further reading\n",
+ "\n",
+ "In this tutorial, you learned how to chat with the Gemma 2B Instruction tuned model using Keras on JAX.\n",
+ "\n",
+ "Check out these guides and tutorials to learn more about Gemma:\n",
+ "\n",
+ "* [Get started with Keras Gemma](https://ai.google.dev/gemma/docs/get_started).\n",
+ "* [Finetune the Gemma model on GPU](https://ai.google.dev/gemma/docs/lora_tuning).\n",
+ "* Learn about [Gemma integration with Vertex AI](https://ai.google.dev/gemma/docs/integrations/vertex)\n",
+ "* Learn how to [use Gemma models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma){:.external}.\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "gemma_chat.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}