diff --git a/open-models/serving/Dockerfile b/open-models/serving/Dockerfile
new file mode 100644
index 00000000000..04708e7f07a
--- /dev/null
+++ b/open-models/serving/Dockerfile
@@ -0,0 +1,18 @@
+
+FROM ollama/ollama
+# Set the host and port to listen on
+ENV OLLAMA_HOST 0.0.0.0:8080
+# Set the directory to store model weight files
+ENV OLLAMA_MODELS /models
+# Reduce the verbosity of the logs
+ENV OLLAMA_DEBUG false
+# Do not unload model weights from the GPU
+ENV OLLAMA_KEEP_ALIVE -1
+# Choose the model to load. Ollama defaults to 4-bit quantized weights
+ENV MODEL gemma2:9b
+# Start the ollama server and download the model weights
+RUN ollama serve & sleep 5 && ollama pull $MODEL
+# At startup time we start the server and run a dummy request
+# to request the model to be loaded in the GPU memory
+ENTRYPOINT ["/bin/sh"]
+CMD ["-c", "ollama serve & (ollama run $MODEL 'Say one word' &) && wait"]
diff --git a/open-models/serving/cloud_run_ollama_gemma2_rag_qa.ipynb b/open-models/serving/cloud_run_ollama_gemma2_rag_qa.ipynb
index f77a7b3745a..a411200c47d 100644
--- a/open-models/serving/cloud_run_ollama_gemma2_rag_qa.ipynb
+++ b/open-models/serving/cloud_run_ollama_gemma2_rag_qa.ipynb
@@ -1,1120 +1,1137 @@
{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ur8xi4C7S06n"
- },
- "outputs": [],
- "source": [
- "# Copyright 2024 Google LLC\n",
- "#\n",
- "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
- "# you may not use this file except in compliance with the License.\n",
- "# You may obtain a copy of the License at\n",
- "#\n",
- "# https://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing, software\n",
- "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
- "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
- "# See the License for the specific language governing permissions and\n",
- "# limitations under the License."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JAPoU8Sm5E6e"
- },
- "source": [
- "# Cloud Run GPU Inference: Gemma 2 RAG Q&A with Ollama and LangChain\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " ![\"Google](\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\") Open in Colab\n",
- " \n",
- " | \n",
- " \n",
- " \n",
- " ![\"Google](\"https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png\") Open in Colab Enterprise\n",
- " \n",
- " | \n",
- " \n",
- " \n",
- " ![\"Vertex](\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\") Open in Workbench\n",
- " \n",
- " | \n",
- " \n",
- " \n",
- " ![\"GitHub](\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\") View on GitHub\n",
- " \n",
- " | \n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "84f0f73a0f76"
- },
- "source": [
- "| | |\n",
- "|-|-|\n",
- "| Author(s) | [Elia Secchi](https://github.com/eliasecchig/) |"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tvgnzT1CKxrO"
- },
- "source": [
- "## Overview\n",
- "\n",
- "\n",
- "\n",
- "> **[Cloud Run](https://cloud.google.com/run)**:\n",
- "It's a serverless platform by Google Cloud for running containerized applications. It automatically scales and manages infrastructure, supporting various programming languages. Cloud Run now offers GPU acceleration for AI/ML workloads.\n",
- "\n",
- "> **Note:** GPU support in Cloud Run is a guarded feature. Before running this notebook, make sure your Google Cloud project is enabled. You can do that by visiting this page [g.co/cloudrun/gpu](https://g.co/cloudrun/gpu).\n",
- "\n",
- "\n",
- "> **[Ollama](ollama.com)**: is an open-source tool for easily running and deploying large language models locally. It offers simple management and usage of LLMs on personal computers or servers.\n",
- "\n",
- "This notebook showcase how to deploy [Google Gemma 2](https://blog.google/technology/developers/google-gemma-2/) in Cloud Run, with the objective to build a simple RAG Q&A application.\n",
- "\n",
- "By the end of this notebook, you will learn how to:\n",
- "\n",
- "1. Deploy Google Gemma 2 on Cloud Run using Ollama\n",
- "2. Implement a Retrieval-Augmented Generation (RAG) application with Gemma 2 and Ollama\n",
- "3. Build a custom container with Ollama to deploy any Large Language Model (LLM) of your choice\n",
- "\n",
- "\n",
- "\n",
- "### Required roles\n",
- "\n",
- "To get the permissions that you need to complete the tutorial, ask your administrator to grant you the following IAM roles on your project:\n",
- "\n",
- "1. Artifact Registry Administrator (`roles/artifactregistry.admin`)\n",
- "2. Cloud Build Editor (`roles/cloudbuild.builds.editor`)\n",
- "3. Cloud Run Admin (`roles/run.developer`)\n",
- "4. Service Account User (`roles/iam.serviceAccountUser`)\n",
- "5. Service Usage Consumer (`roles/serviceusage.serviceUsageConsumer`)\n",
- "6. Storage Admin (`roles/storage.admin`)\n",
- "\n",
- "\n",
- "\n",
- "For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "FYbo7iEPluZQ"
- },
- "source": [
- "![cloud_run_gemma_ollama.png](https://storage.googleapis.com/github-repo/generative-ai/open-models/serving/cloud_run_gemma_ollama.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "61RBz8LLbxCR"
- },
- "source": [
- "## Get started"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "No17Cw5hgx12"
- },
- "source": [
- "### Install Vertex AI SDK and other required packages\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "tFy3H3aPgx12"
- },
- "outputs": [],
- "source": [
- "%pip install --upgrade --user --quiet google-cloud-aiplatform langchain-community langchainhub langchain_google_vertexai"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "R5Xep4W9lq-Z"
- },
- "source": [
- "### Restart runtime\n",
- "\n",
- "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
- "\n",
- "The restart might take a minute or longer. After it's restarted, continue to the next step."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "XRvKdaPDTznN"
- },
- "outputs": [],
- "source": [
- "import IPython\n",
- "\n",
- "app = IPython.Application.instance()\n",
- "app.kernel.do_shutdown(True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "SbmM4z7FOBpM"
- },
- "source": [
- "\n",
- "⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️\n",
- "
\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dmWOrTJ3gx13"
- },
- "source": [
- "### Authenticate your notebook environment (Colab only)\n",
- "\n",
- "If you're running this notebook on Google Colab, run the cell below to authenticate your environment.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NyKGtVQjgx13"
- },
- "outputs": [],
- "source": [
- "!gcloud auth login --update-adc --quiet"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "DF4l8DTdWgPY"
- },
- "source": [
- "### Set Google Cloud project information and initialize Vertex AI SDK\n",
- "\n",
- "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
- "\n",
- "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Nqwi-5ufWp_B"
- },
- "outputs": [],
- "source": [
- "PROJECT_ID = \"genai-blackbelt-fishfooding\" # @param {type:\"string\"}\n",
- "LOCATION = \"us-central1\" # @param {type:\"string\"}"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8pB4NiQAMzgt"
- },
- "source": [
- "### Fetch your Google Cloud project number"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Y54slycDMjHK"
- },
- "outputs": [],
- "source": [
- "PROJECT_NUMBER = get_ipython().getoutput('gcloud projects describe $PROJECT_ID --format=\"value(projectNumber)\"')[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "EdvJRUWRNGHE"
- },
- "source": [
- "## Deploy Ollama with Cloud Run"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5J5rY6YhxTRl"
- },
- "source": [
- "## Build your container\n",
- "\n",
- "For deploying Gemma 2 in Cloud Run, create a container that packages the Ollama server and the Gemma 2 model.\n",
- "\n",
- "To build the container, you can use [Cloud Build](https://cloud.google.com/build), a serverless CI/CD platform which allows developers to easily build software.\n",
- "\n",
- "> For optimal startup time and improved scalability, it's recommended to store model weights for Gemma 2 (9B) and similarly sized models directly in the container image.\n",
- "However, consider the storage requirements of larger models as they might be impractical to store in the container image. Refer to [Best practices: AI inference on Cloud Run with GPUs](https://cloud.google.com/run/docs/configuring/services/gpu-best-practices#loading-storing-models-tradeoff) for an overview of the trade-offs.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "IprOEAAN1sBQ"
- },
- "source": [
- "### Create Artifact Registry repository\n",
- "\n",
- "To build a container you will need to first create a repository in Google Cloud Artifact Registry:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "p5hXDtoYsCEB"
- },
- "outputs": [],
- "source": [
- "AR_REPOSITORY_NAME = \"cr-gpu-repo\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "z1ZBM9PDrbdM"
- },
- "outputs": [],
- "source": [
- "!gcloud artifacts repositories create $AR_REPOSITORY_NAME \\\n",
- " --repository-format=docker \\\n",
- " --location=$LOCATION \\\n",
- " --project=$PROJECT_ID"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "IDMpuXEu2thu"
- },
- "source": [
- "### Create a Dockerfile\n",
- "\n",
- "You will then need to create a Dockerfile which defines the build steps of the container.\n",
- "\n",
- "You can customize the model used by modifying the `MODEL` environment variable. \n",
- "Explore the [Ollama library](https://ollama.com/library) for a comprehensive list of available models."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Vi9T53CScWdn"
- },
- "outputs": [],
- "source": [
- "%%writefile Dockerfile\n",
- "FROM ollama/ollama\n",
- "\n",
- "# Set the host and port to listen on\n",
- "ENV OLLAMA_HOST 0.0.0.0:8080\n",
- "\n",
- "# Set the directory to store model weight files\n",
- "ENV OLLAMA_MODELS /models\n",
- "\n",
- "# Reduce the verbosity of the logs\n",
- "ENV OLLAMA_DEBUG false\n",
- "\n",
- "# Do not unload model weights from the GPU\n",
- "ENV OLLAMA_KEEP_ALIVE -1\n",
- "\n",
- "# Choose the model to load. Ollama defaults to 4-bit quantized weights\n",
- "ENV MODEL gemma2:9b\n",
- "\n",
- "# Start the ollama server and download the model weights\n",
- "RUN ollama serve & sleep 5 && ollama pull $MODEL\n",
- "\n",
- "# At startup time we start the server and run a dummy request\n",
- "# to request the model to be loaded in the GPU memory\n",
- "ENTRYPOINT [\"/bin/sh\"]\n",
- "CMD [\"-c\", \"ollama serve & (ollama run $MODEL 'Say one word' &) && wait\"]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5RnaYx2p235W"
- },
- "source": [
- "### Trigger Cloud Build\n",
- "\n",
- "You are now ready to trigger the container build process!\n",
- "We will use the `gcloud builds submit` command, using a `e2-highcpu-32` machine to optimize build time. We use e2-highcpu-32 machines because multiple cores allow for parallel downloads, significantly speeding up the build process.\n",
- "\n",
- "Cloud Build pricing is based on build minutes consumed. See [the pricing page](https://cloud.google.com/build/pricing) for details\n",
- "\n",
- "The operation will take ~10 minutes for completion."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "k2aooaREsT-F"
- },
- "outputs": [],
- "source": [
- "CONTAINER_URI = (\n",
- " f\"{LOCATION}-docker.pkg.dev/{PROJECT_ID}/{AR_REPOSITORY_NAME}/ollama-gemma-2\"\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "CU8n7kk5OeP8"
- },
- "outputs": [],
- "source": [
- "!gcloud builds submit --tag $CONTAINER_URI --project $PROJECT_ID --machine-type e2-highcpu-32"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xd_Zfz9c3cZy"
- },
- "source": [
- "You can now use the container you just built to deploy a new Cloud Run service!"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3YGGFLB-JElj"
- },
- "source": [
- "### Deploy container in Cloud Run\n",
- "\n",
- "You are now ready for deployment! Cloud Run offers multiple deployment methods, including Console, gcloud CLI, Cloud Code, Terraform, YAML, and Client Libraries. Explore all the options in the [official documentation](https://cloud.google.com/run/docs/deploying#service).\n",
- "\n",
- "For quick prototyping, you can start with the gcloud CLI `gcloud run deploy` command. This convenient command-line tool provides a straightforward way to get your container running on Cloud Run. Learn more about its features and usage in the [gcloud CLI reference](https://cloud.google.com/sdk/gcloud/reference/run/deploy).\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "8e6kybbhp3Na"
- },
- "outputs": [],
- "source": [
- "SERVICE_NAME = \"ollama-gemma-2\" # @param {type:\"string\"}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "kDkLl8AFKKD0"
- },
- "outputs": [],
- "source": [
- "!gcloud beta run deploy $SERVICE_NAME \\\n",
- " --project $PROJECT_ID \\\n",
- " --region $LOCATION \\\n",
- " --image $CONTAINER_URI \\\n",
- " --concurrency 4 \\\n",
- " --cpu 8 \\\n",
- " --gpu 1 \\\n",
- " --gpu-type nvidia-l4 \\\n",
- " --max-instances 7 \\\n",
- " --memory 32Gi \\\n",
- " --no-allow-unauthenticated \\\n",
- " --no-cpu-throttling \\\n",
- " --timeout=600"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "e1afbaee64a4"
- },
- "source": [
- "*Expect a slower initial deployment as the container image is being pulled for the first time.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8IRTamcobASG"
- },
- "source": [
- "### Setting concurrency for optimal performance\n",
- "\n",
- "In Cloud Run, [concurrency](https://cloud.google.com/run/docs/about-concurrency) defines the maximum number of requests that can be processed simultaneously by a given instance.\n",
- "\n",
- "For this sample we set a `concurrency` value equal to 4.\n",
- "\n",
- "As part of your use case you might need to experiment with different concurrency settings to find the best latency vs throughput tradeoff.\n",
- "\n",
- "Refer to the following documentation pages to know more about performance optimizations:\n",
- "- [Setting concurrency for optimal performance in Cloud Run](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama#set-concurrency-for-performance)\n",
- "- [GPU performance best practices](https://cloud.google.com/run/docs/configuring/services/gpu-best-practices)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XSrXJkabGdjw"
- },
- "source": [
- "## Invoking Gemma 2 in Cloud Run\n",
- "\n",
- "We are now ready to send some requests to Gemma!\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Vrx30A8jKwrY"
- },
- "source": [
- "### Fetch identity token\n",
- "\n",
- "Once deployed to Cloud Run, to invoke Gemma 2, we will need to fetch an Identity token to perform authentication. See the relative documentation to discover more about [authentication in Cloud Run](https://cloud.google.com/run/docs/authenticating/overview).\n",
- "\n",
- "In the appendix of this sample, you'll find a helper function that supports the automatic refresh of the [Identity Token](https://cloud.google.com/docs/authentication/token-types#id), which expires every hour by default."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "qSa5aZCPuLlU"
- },
- "outputs": [],
- "source": [
- "ID_TOKEN = get_ipython().getoutput('gcloud auth print-identity-token -q')[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "UmA-haVjOA6U"
- },
- "source": [
- "### Setup the Service URL & model name"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "LOVfy893tvcl"
- },
- "outputs": [],
- "source": [
- "SERVICE_URL = f\"https://{SERVICE_NAME}-{PROJECT_NUMBER}.{LOCATION}.run.app\" # type: ignore"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "IcPKeFPNQZzI"
- },
- "outputs": [],
- "source": [
- "MODEL_NAME = \"gemma2:9b\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XbOtGicVLNgD"
- },
- "source": [
- "## Invoking Gemma\n",
- "\n",
- "You are ready to test the model you just deployed! The [Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md) are a great resource to learn more about the different endpoints and how to interact with your model.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "iMI0nlXVT20t"
- },
- "source": [
- "#### Invoke through CURL request\n",
- "You can invoke Gemma and Cloud Run in many ways. For example, you can send an HTTP CURL request to Cloud Run:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4b1c47642f7e"
- },
- "outputs": [],
- "source": [
- "ENDPOINT_URL = f\"{SERVICE_URL}/api/generate\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NsixJcaBP2q4"
- },
- "outputs": [],
- "source": [
- "%%bash -s \"$ENDPOINT_URL\" \"$ID_TOKEN\" \"$MODEL_NAME\" \n",
- "ENDPOINT_URL=$1\n",
- "ID_TOKEN=$2\n",
- "MODEL_NAME=$3\n",
- "\n",
- "curl -s -X POST \"${ENDPOINT_URL}\" \\\n",
- "-H \"Authorization: Bearer ${ID_TOKEN}\" \\\n",
- "-H \"Content-Type: application/json\" \\\n",
- "-d '{ \"model\": \"'${MODEL_NAME}'\", \"prompt\": \"Hi\", \"max_tokens\": 100, \"stream\": false}'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "41657205b738"
- },
- "source": [
- "#### Invoke with a Python POST Request\n",
- "\n",
- "You can also invoke the model using a POST request with Python's popular `requests` library. [Learn more about the `requests` library here.](https://requests.readthedocs.io/en/latest/) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "2e8c87dfd38b"
- },
- "outputs": [],
- "source": [
- "import requests\n",
- "\n",
- "headers = {\"Authorization\": f\"Bearer {ID_TOKEN}\", \"Content-Type\": \"application/json\"} # type: ignore\n",
- "\n",
- "data = {\n",
- " \"model\": MODEL_NAME,\n",
- " \"prompt\": \"Hi, I am using python!\",\n",
- " \"max_tokens\": 100,\n",
- " \"stream\": False,\n",
- "}\n",
- "\n",
- "response = requests.post(ENDPOINT_URL, headers=headers, json=data)\n",
- "\n",
- "print(response.text)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bFoe0NVOT6DD"
- },
- "source": [
- "#### Invoke Ollama with Python integrations\n",
- "\n",
- "Popular Generative AI orchestration frameworks like [LangChain](https://www.langchain.com) and [LlamaIndex](https://www.llamaindex.ai/) offer direct integration with Ollama:\n",
- "- [LangChain integration](https://python.langchain.com/v0.2/docs/integrations/llms/ollama/)\n",
- "- [LlamaIndex integration](https://docs.llamaindex.ai/en/stable/api_reference/llms/ollama/)\n",
- "\n",
- "As part of this sample, we will be using the LangChain integration to perform different calls and build a sample RAG chain."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vZyZqnnNaeWw"
- },
- "source": [
- "### Import libraries"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gQDWB66Vadlx"
- },
- "outputs": [],
- "source": [
- "import google.auth\n",
- "from langchain_community.chat_models import ChatOllama\n",
- "from langchain_community.document_loaders import WebBaseLoader\n",
- "from langchain_community.vectorstores import SKLearnVectorStore\n",
- "from langchain_core.output_parsers import StrOutputParser\n",
- "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
- "from langchain_core.runnables import RunnablePassthrough\n",
- "from langchain_google_vertexai import VertexAIEmbeddings\n",
- "from langchain_text_splitters import CharacterTextSplitter"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_hnaKrZftjbT"
- },
- "outputs": [],
- "source": [
- "llm = ChatOllama(\n",
- " model=MODEL_NAME,\n",
- " base_url=SERVICE_URL,\n",
- " num_predict=300,\n",
- " headers={\"Authorization\": f\"Bearer {ID_TOKEN}\"}, # type: ignore\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "9GYGr76T7aYF"
- },
- "outputs": [],
- "source": [
- "# You can perform a synchronous invocation through the `.invoke` method\n",
- "\n",
- "llm.invoke(\"Hi!\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yVqxaDylWjck"
- },
- "source": [
- "Or invoke through the generation of a stream through the `.stream` **method**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "2twtoa6a4_ui"
- },
- "outputs": [],
- "source": [
- "# You can also generate a stream through the `.stream` method\n",
- "\n",
- "for m in llm.stream(\"Hi!\"):\n",
- " print(m)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bqr_QPso7shY"
- },
- "source": [
- "## RAG Q&A Chain with Gemma 2 and Cloud Run\n",
- "\n",
- "We can leverage the LangChain integration to create a sample RAG application with Gemma, Cloud Run, [Vertex AI Embedding](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings) for generating embeddings and [FAISS vector store](https://python.langchain.com/v0.2/docs/integrations/vectorstores/faiss/) for document retrieval.\n",
- "\n",
- "Through RAG, we will ask Gemma 2 to answer questions about the [Cloud Run documentation page](https://cloud.google.com/run/docs/overview/what-is-cloud-run)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wHq8zpG5a4u9"
- },
- "source": [
- "### Setup embedding model and retriever\n",
- "\n",
- "We are ready to setup our embedding model and retriever."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "tI66kos7a8B4"
- },
- "outputs": [],
- "source": [
- "credentials, _ = google.auth.default(quota_project_id=PROJECT_ID)\n",
- "embeddings = VertexAIEmbeddings(\n",
- " project=PROJECT_ID, model_name=\"text-embedding-004\", credentials=credentials\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "xxCu8MST6oWK"
- },
- "outputs": [],
- "source": [
- "loader = WebBaseLoader(\"https://cloud.google.com/run/docs/overview/what-is-cloud-run\")\n",
- "docs = loader.load()\n",
- "documents = CharacterTextSplitter(chunk_size=800, chunk_overlap=100).split_documents(\n",
- " docs\n",
- ")\n",
- "\n",
- "vector = SKLearnVectorStore.from_documents(documents, embeddings)\n",
- "retriever = vector.as_retriever()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rCH9n2UEzwMv"
- },
- "source": [
- "### RAG Chain Definition\n",
- "\n",
- "We will define now our RAG Chain.\n",
- "\n",
- "The RAG chain works as follows:\n",
- "\n",
- "1. The user's query and conversation history are passed to the `query_rewrite_chain` to generate a rewritten query optimized for semantic search.\n",
- "2. The rewritten query is used by the `retriever` to fetch relevant documents.\n",
- "3. The retrieved documents are formatted into a single string.\n",
- "4. The formatted documents, along with the original user messages, are passed to the LLM with instructions to generate an answer based on the provided context.\n",
- "5. The LLM's response is parsed and returned as the final answer."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x1nPO00s4r4o"
- },
- "outputs": [],
- "source": [
- "answer_generation_template = ChatPromptTemplate.from_messages(\n",
- " [\n",
- " (\n",
- " \"system\",\n",
- " \"You are an assistant for question answering-tasks. \"\n",
- " \"Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. \"\n",
- " \"{context}\",\n",
- " ),\n",
- " MessagesPlaceholder(variable_name=\"messages\"),\n",
- " ]\n",
- ")\n",
- "query_rewrite_template = ChatPromptTemplate.from_messages(\n",
- " [\n",
- " (\n",
- " \"system\",\n",
- " \"Rewrite a query to a semantic search engine using the current conversation. \"\n",
- " \"Provide only the rewritten query as output.\",\n",
- " ),\n",
- " MessagesPlaceholder(variable_name=\"messages\"),\n",
- " ]\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jm309i0r0Lwv"
- },
- "outputs": [],
- "source": [
- "query_rewrite_chain = query_rewrite_template | llm\n",
- "\n",
- "\n",
- "def extract_query(messages):\n",
- " return query_rewrite_chain.invoke(messages).content\n",
- "\n",
- "\n",
- "def format_docs(docs):\n",
- " return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
- "\n",
- "\n",
- "rag_chain = (\n",
- " {\n",
- " \"context\": extract_query | retriever | format_docs,\n",
- " \"messages\": RunnablePassthrough(),\n",
- " }\n",
- " | answer_generation_template\n",
- " | llm\n",
- " | StrOutputParser()\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rRFtc7eo0Tn7"
- },
- "source": [
- "### Testing the RAG Chain"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "MIfd_Wwa7RUe"
- },
- "outputs": [],
- "source": [
- "rag_chain.invoke([(\"human\", \"What features does Cloud Run offer?\")])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "f0994074042d"
- },
- "source": [
- "Now, let's use a specific question from the documentation to explore how RAG addresses potential gaps in the model's knowledge."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ec5895d100f5"
- },
- "outputs": [],
- "source": [
- "QUESTION = \"What are the three ways you can use Cloud Run jobs?\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "e0abfae19b32"
- },
- "source": [
- "First, we'll ask the LLM directly:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4fb7c3e2426e"
- },
- "outputs": [],
- "source": [
- "llm.invoke(QUESTION)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5575573f8174"
- },
- "source": [
- "Then, we'll ask the same question using the RAG chain:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "AszY1ke5bWC9"
- },
- "outputs": [],
- "source": [
- "rag_chain.invoke([(\"human\", QUESTION)])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "88c5f72f1981"
- },
- "source": [
- "## Conclusion\n",
- "Congratulations. Now you know how to deploy an open model to Cloud Run powered by a GPU! Specifically, you deployed a Gemma 2 model to Cloud Run with a GPU, as part of a RAG application powered by LangChain. You were able to ask answers from Gemma 2 about a documentation page.\n",
- "\n",
- "For more information about your identity tokens expiring and how to refresh your tokens, see the next section below \"Appendix: Handling Identity Token Expiration\".\n",
- "\n",
- "To clean up the resources you created in this section, see the section at the bottom \"Cleaning up\"."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A5VPH4gP3igf"
- },
- "source": [
- "## Appendix: Handling Identity Token Expiration\n",
- "\n",
- "When deploying a Generative AI application Google Cloud Run, you'll often need to authenticate your requests using Identity Tokens.\n",
- "\n",
- "These tokens will expire hourly, requiring a mechanism for automatic refresh to ensure uninterrupted operation.\n",
- "\n",
- "The following helper classes provide an example of how to deal with token refresh. It leverages the `google.auth` library to handle the authentication process and automatically refresh the token when necessary.\n",
- "\n",
- "\n",
- "See the following resources for more information on authentication:\n",
- "* [Identity Token Overview](https://cloud.google.com/docs/authentication/token-types#id)\n",
- "* [Google Cloud Run Authentication Documentation](https://cloud.google.com/run/docs/authenticating/overview)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "TJlxAZKz-tjY"
- },
- "outputs": [],
- "source": [
- "import time\n",
- "\n",
- "import google.auth\n",
- "from google.auth.exceptions import DefaultCredentialsError\n",
- "import google.auth.transport.requests\n",
- "import google.oauth2.id_token\n",
- "from pydantic.v1 import Extra\n",
- "\n",
- "\n",
- "class TokenManager:\n",
- " def __init__(\n",
- " self, url=None, token_lifetime=3600\n",
- " ): # Default token lifetime of 1 hour\n",
- " self.token = None\n",
- " self.expiry_time = 0\n",
- " self.token_lifetime = token_lifetime\n",
- " self.url = url\n",
- " self.creds, _ = google.auth.default()\n",
- "\n",
- " def get_token(self):\n",
- " if time.time() >= self.expiry_time:\n",
- " self.refresh_token(url=self.url)\n",
- " return self.token\n",
- "\n",
- " def refresh_token(self, url):\n",
- " \"\"\"\n",
- " Retrieves an ID token, attempting to use default credentials first,\n",
- " and falling back to fetching a service-to-service new token if necessary.\n",
- " See more on Cloud Run authentication at this link:\n",
- " https://cloud.google.com/run/docs/authenticating/service-to-service\n",
- " Args:\n",
- " url: The URL to use for the token request.\n",
- " \"\"\"\n",
- "\n",
- " auth_req = google.auth.transport.requests.Request()\n",
- " try:\n",
- " self.token = google.oauth2.id_token.fetch_id_token(auth_req, url)\n",
- " except DefaultCredentialsError:\n",
- " self.creds.refresh(auth_req)\n",
- " self.token = self.creds.id_token\n",
- "\n",
- " self.expiry_time = time.time() + self.token_lifetime\n",
- "\n",
- "\n",
- "class ChatOllamaWithAuth(ChatOllama):\n",
- " class Config:\n",
- " extra = Extra.allow\n",
- "\n",
- " def __init__(self, *args, **kwargs):\n",
- " super().__init__(*args, **kwargs)\n",
- " self._token_manager = TokenManager()\n",
- " self.headers = {} if self.headers is None else self.headers\n",
- "\n",
- " def _generate(self, *args, **kwargs) -> str:\n",
- " self.headers[\"Authorization\"] = f\"Bearer {self._token_manager.get_token()}\"\n",
- " return super()._generate(*args, **kwargs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "GveBb6tLPUEc"
- },
- "outputs": [],
- "source": [
- "llm = ChatOllamaWithAuth(model=MODEL_NAME, base_url=SERVICE_URL, num_predict=300)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LjgBf0M4Mokn"
- },
- "source": [
- "You can now use the `invoke` function as usual, with the token being refreshed automatically every hour.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NH0kWCloMlUe"
- },
- "outputs": [],
- "source": [
- "llm.invoke(\"Hi, testing a request\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "270LFAC0P3i4"
- },
- "source": [
- "## Cleaning up\n",
- "To clean up all Google Cloud resources, you can run the following cell to delete the Cloud Run service you created.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "XJRe7I7KP2HM"
- },
- "outputs": [],
- "source": [
- "# Delete the Cloud Run service deployed above\n",
- "\n",
- "!gcloud run services delete $SERVICE_NAME --project $PROJECT_ID --region $LOCATION --quiet"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "name": "cloud_run_ollama_gemma2_rag_qa.ipynb",
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ur8xi4C7S06n"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright 2024 Google LLC\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JAPoU8Sm5E6e"
+ },
+ "source": [
+ "# Cloud Run GPU Inference: Gemma 2 RAG Q&A with Ollama and LangChain\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " ![\"Google](\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\") Open in Colab\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " ![\"Google](\"https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png\") Open in Colab Enterprise\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " ![\"Vertex](\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\") Open in Workbench\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " ![\"GitHub](\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\") View on GitHub\n",
+ " \n",
+ " | \n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "84f0f73a0f76"
+ },
+ "source": [
+ "| | |\n",
+ "|-|-|\n",
+ "| Author(s) | [Elia Secchi](https://github.com/eliasecchig/) |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tvgnzT1CKxrO"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "\n",
+ "\n",
+ "> **[Cloud Run](https://cloud.google.com/run)**:\n",
+ "It's a serverless platform by Google Cloud for running containerized applications. It automatically scales and manages infrastructure, supporting various programming languages. Cloud Run now offers GPU acceleration for AI/ML workloads.\n",
+ "\n",
+ "> **Note:** GPU support in Cloud Run is a guarded feature. Before running this notebook, make sure your Google Cloud project is enabled. You can do that by visiting this page [g.co/cloudrun/gpu](https://g.co/cloudrun/gpu).\n",
+ "\n",
+ "\n",
+ "> **[Ollama](ollama.com)**: is an open-source tool for easily running and deploying large language models locally. It offers simple management and usage of LLMs on personal computers or servers.\n",
+ "\n",
+ "This notebook showcase how to deploy [Google Gemma 2](https://blog.google/technology/developers/google-gemma-2/) in Cloud Run, with the objective to build a simple RAG Q&A application.\n",
+ "\n",
+ "By the end of this notebook, you will learn how to:\n",
+ "\n",
+ "1. Deploy Google Gemma 2 on Cloud Run using Ollama\n",
+ "2. Implement a Retrieval-Augmented Generation (RAG) application with Gemma 2 and Ollama\n",
+ "3. Build a custom container with Ollama to deploy any Large Language Model (LLM) of your choice\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Required roles\n",
+ "\n",
+ "To get the permissions that you need to complete the tutorial, ask your administrator to grant you the following IAM roles on your project:\n",
+ "\n",
+ "1. Artifact Registry Administrator (`roles/artifactregistry.admin`)\n",
+ "2. Cloud Build Editor (`roles/cloudbuild.builds.editor`)\n",
+ "3. Cloud Run Admin (`roles/run.developer`)\n",
+ "4. Service Account User (`roles/iam.serviceAccountUser`)\n",
+ "5. Service Usage Consumer (`roles/serviceusage.serviceUsageConsumer`)\n",
+ "6. Storage Admin (`roles/storage.admin`)\n",
+ "\n",
+ "\n",
+ "\n",
+ "For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access)."
+ ]
+ },
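+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3f2a9c41d7b8"
+ },
+ "source": [
+ "As a minimal sketch of how an administrator might grant one of these roles, here is an example `gcloud` command. The principal `user:your-email@example.com` is a placeholder; it is left commented out so it isn't run accidentally:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8c1d4e2f5a6b"
+ },
+ "outputs": [],
+ "source": [
+ "# Example only: grant the Cloud Run Developer role to a user.\n",
+ "# The principal below is a placeholder; substitute your own account.\n",
+ "# !gcloud projects add-iam-policy-binding $PROJECT_ID \\\n",
+ "#     --member=\"user:your-email@example.com\" \\\n",
+ "#     --role=\"roles/run.developer\""
+ ]
+ },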
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FYbo7iEPluZQ"
+ },
+ "source": [
+ "![cloud_run_gemma_ollama.png](https://storage.googleapis.com/github-repo/generative-ai/open-models/serving/cloud_run_gemma_ollama.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "61RBz8LLbxCR"
+ },
+ "source": [
+ "## Get started"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "No17Cw5hgx12"
+ },
+ "source": [
+ "### Install Vertex AI SDK and other required packages\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tFy3H3aPgx12"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install --upgrade --user --quiet google-cloud-aiplatform langchain-community langchainhub langchain_google_vertexai"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "R5Xep4W9lq-Z"
+ },
+ "source": [
+ "### Restart runtime\n",
+ "\n",
+ "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
+ "\n",
+ "The restart might take a minute or longer. After it's restarted, continue to the next step."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XRvKdaPDTznN"
+ },
+ "outputs": [],
+ "source": [
+ "import IPython\n",
+ "\n",
+ "app = IPython.Application.instance()\n",
+ "app.kernel.do_shutdown(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SbmM4z7FOBpM"
+ },
+ "source": [
+ "\n",
+ "⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️\n",
+ "
\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dmWOrTJ3gx13"
+ },
+ "source": [
+ "### Authenticate your notebook environment (Colab only)\n",
+ "\n",
+ "If you're running this notebook on Google Colab, run the cell below to authenticate your environment.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NyKGtVQjgx13"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud auth login --update-adc --quiet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DF4l8DTdWgPY"
+ },
+ "source": [
+ "### Set Google Cloud project information and initialize Vertex AI SDK\n",
+ "\n",
+ "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
+ "\n",
+ "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Nqwi-5ufWp_B"
+ },
+ "outputs": [],
+ "source": [
+ "# Use the environment variable if the user doesn't provide Project ID.\n",
+ "import os\n",
+ "\n",
+ "import vertexai\n",
+ "\n",
+ "PROJECT_ID = \"[your-project-id]\" # @param {type:\"string\", isTemplate: true}\n",
+ "if PROJECT_ID == \"[your-project-id]\":\n",
+ " PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
+ "\n",
+ "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
+ "\n",
+ "vertexai.init(project=PROJECT_ID, location=LOCATION)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8pB4NiQAMzgt"
+ },
+ "source": [
+ "### Fetch your Google Cloud project number"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Y54slycDMjHK"
+ },
+ "outputs": [],
+ "source": [
+ "PROJECT_NUMBER = get_ipython().getoutput('gcloud projects describe $PROJECT_ID --format=\"value(projectNumber)\"')[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EdvJRUWRNGHE"
+ },
+ "source": [
+ "## Deploy Ollama with Cloud Run"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5J5rY6YhxTRl"
+ },
+ "source": [
+ "## Build your container\n",
+ "\n",
+ "For deploying Gemma 2 in Cloud Run, create a container that packages the Ollama server and the Gemma 2 model.\n",
+ "\n",
+ "To build the container, you can use [Cloud Build](https://cloud.google.com/build), a serverless CI/CD platform which allows developers to easily build software.\n",
+ "\n",
+ "> For optimal startup time and improved scalability, it's recommended to store model weights for Gemma 2 (9B) and similarly sized models directly in the container image.\n",
+ "However, consider the storage requirements of larger models as they might be impractical to store in the container image. Refer to [Best practices: AI inference on Cloud Run with GPUs](https://cloud.google.com/run/docs/configuring/services/gpu-best-practices#loading-storing-models-tradeoff) for an overview of the trade-offs.\n"
+ ]
+ },
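+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4a7d1c9e3b5f"
+ },
+ "source": [
+ "For larger models, one alternative covered by the best-practices page is loading weights from a Cloud Storage volume mount instead of baking them into the image. The commented sketch below illustrates the idea with a hypothetical bucket name (`MODELS_BUCKET`); check the [volume mounts documentation](https://cloud.google.com/run/docs/configuring/services/cloud-storage-volume-mounts) for the exact flags:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1e8b5f2d7c3a"
+ },
+ "outputs": [],
+ "source": [
+ "# Example only (hypothetical bucket name): mount a Cloud Storage bucket at /models\n",
+ "# so OLLAMA_MODELS can read weights without storing them in the container image.\n",
+ "# !gcloud beta run deploy $SERVICE_NAME \\\n",
+ "#     --image $CONTAINER_URI \\\n",
+ "#     --add-volume name=models,type=cloud-storage,bucket=$MODELS_BUCKET \\\n",
+ "#     --add-volume-mount volume=models,mount-path=/models"
+ ]
+ },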
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IprOEAAN1sBQ"
+ },
+ "source": [
+ "### Create Artifact Registry repository\n",
+ "\n",
+ "To build a container you will need to first create a repository in Google Cloud Artifact Registry:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "p5hXDtoYsCEB"
+ },
+ "outputs": [],
+ "source": [
+ "AR_REPOSITORY_NAME = \"cr-gpu-repo\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "z1ZBM9PDrbdM"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud artifacts repositories create $AR_REPOSITORY_NAME \\\n",
+ " --repository-format=docker \\\n",
+ " --location=$LOCATION \\\n",
+ " --project=$PROJECT_ID"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IDMpuXEu2thu"
+ },
+ "source": [
+ "### Create a Dockerfile\n",
+ "\n",
+ "You will then need to create a Dockerfile which defines the build steps of the container.\n",
+ "\n",
+ "You can customize the model used by modifying the `MODEL` variable. \n",
+ "Explore the [Ollama library](https://ollama.com/library) for a comprehensive list of available models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "IcPKeFPNQZzI"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = \"gemma2:9b\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Vi9T53CScWdn"
+ },
+ "outputs": [],
+ "source": [
+ "dockerfile_content = f\"\"\"\n",
+ "FROM ollama/ollama\n",
+ "# Set the host and port to listen on\n",
+ "ENV OLLAMA_HOST 0.0.0.0:8080\n",
+ "# Set the directory to store model weight files\n",
+ "ENV OLLAMA_MODELS /models\n",
+ "# Reduce the verbosity of the logs\n",
+ "ENV OLLAMA_DEBUG false\n",
+ "# Do not unload model weights from the GPU\n",
+ "ENV OLLAMA_KEEP_ALIVE -1\n",
+ "# Choose the model to load. Ollama defaults to 4-bit quantized weights\n",
+ "ENV MODEL {MODEL_NAME}\n",
+ "# Start the ollama server and download the model weights\n",
+ "RUN ollama serve & sleep 5 && ollama pull $MODEL\n",
+ "# At startup time we start the server and run a dummy request\n",
+ "# to request the model to be loaded in the GPU memory\n",
+ "ENTRYPOINT [\"/bin/sh\"]\n",
+ "CMD [\"-c\", \"ollama serve & (ollama run $MODEL 'Say one word' &) && wait\"]\n",
+ "\"\"\"\n",
+ "\n",
+ "# Write the Dockerfile\n",
+ "with open(\"Dockerfile\", \"w\") as f:\n",
+ " f.write(dockerfile_content)"
+ ]
+ },
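+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7d3b9f1e2c4a"
+ },
+ "source": [
+ "You can print the generated Dockerfile to verify that the model name was substituted correctly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5e8a2c6d9b1f"
+ },
+ "outputs": [],
+ "source": [
+ "# Inspect the Dockerfile that will be sent to Cloud Build\n",
+ "!cat Dockerfile"
+ ]
+ },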
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5RnaYx2p235W"
+ },
+ "source": [
+ "### Trigger Cloud Build\n",
+ "\n",
+ "You are now ready to trigger the container build process!\n",
+ "We will use the `gcloud builds submit` command, using a `e2-highcpu-32` machine to optimize build time. We use e2-highcpu-32 machines because multiple cores allow for parallel downloads, significantly speeding up the build process.\n",
+ "\n",
+ "Cloud Build pricing is based on build minutes consumed. See [the pricing page](https://cloud.google.com/build/pricing) for details\n",
+ "\n",
+ "The operation will take ~10 minutes for completion."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k2aooaREsT-F"
+ },
+ "outputs": [],
+ "source": [
+ "CONTAINER_URI = (\n",
+ " f\"{LOCATION}-docker.pkg.dev/{PROJECT_ID}/{AR_REPOSITORY_NAME}/ollama-gemma-2\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CU8n7kk5OeP8"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud builds submit --tag $CONTAINER_URI --project $PROJECT_ID --machine-type e2-highcpu-32"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xd_Zfz9c3cZy"
+ },
+ "source": [
+ "You can now use the container you just built to deploy a new Cloud Run service!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3YGGFLB-JElj"
+ },
+ "source": [
+ "### Deploy container in Cloud Run\n",
+ "\n",
+ "You are now ready for deployment! Cloud Run offers multiple deployment methods, including Console, gcloud CLI, Cloud Code, Terraform, YAML, and Client Libraries. Explore all the options in the [official documentation](https://cloud.google.com/run/docs/deploying#service).\n",
+ "\n",
+ "For quick prototyping, you can start with the gcloud CLI `gcloud run deploy` command. This convenient command-line tool provides a straightforward way to get your container running on Cloud Run. Learn more about its features and usage in the [gcloud CLI reference](https://cloud.google.com/sdk/gcloud/reference/run/deploy).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8e6kybbhp3Na"
+ },
+ "outputs": [],
+ "source": [
+ "SERVICE_NAME = \"ollama-gemma-2\" # @param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kDkLl8AFKKD0"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud beta run deploy $SERVICE_NAME \\\n",
+ " --project $PROJECT_ID \\\n",
+ " --region $LOCATION \\\n",
+ " --image $CONTAINER_URI \\\n",
+ " --concurrency 4 \\\n",
+ " --cpu 8 \\\n",
+ " --gpu 1 \\\n",
+ " --gpu-type nvidia-l4 \\\n",
+ " --max-instances 7 \\\n",
+ " --memory 32Gi \\\n",
+ " --no-allow-unauthenticated \\\n",
+ " --no-cpu-throttling \\\n",
+ " --timeout=600"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e1afbaee64a4"
+ },
+ "source": [
+ "*Expect a slower initial deployment as the container image is being pulled for the first time.*"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8IRTamcobASG"
+ },
+ "source": [
+ "### Setting concurrency for optimal performance\n",
+ "\n",
+ "In Cloud Run, [concurrency](https://cloud.google.com/run/docs/about-concurrency) defines the maximum number of requests that can be processed simultaneously by a given instance.\n",
+ "\n",
+ "For this sample we set a `concurrency` value equal to 4.\n",
+ "\n",
+ "As part of your use case you might need to experiment with different concurrency settings to find the best latency vs throughput tradeoff.\n",
+ "\n",
+ "Refer to the following documentation pages to know more about performance optimizations:\n",
+ "- [Setting concurrency for optimal performance in Cloud Run](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama#set-concurrency-for-performance)\n",
+ "- [GPU performance best practices](https://cloud.google.com/run/docs/configuring/services/gpu-best-practices)"
+ ]
+ },
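+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9b4c7e2a1d3f"
+ },
+ "source": [
+ "As a sketch of such an experiment, you could redeploy the service with a different concurrency value and compare latency and throughput. The command below uses `gcloud run services update` (which deploys a new revision) and is left commented out so it isn't run accidentally:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2f6a8d4b9c0e"
+ },
+ "outputs": [],
+ "source": [
+ "# Example only: deploy a new revision with concurrency 8, then re-run your benchmark.\n",
+ "# !gcloud run services update $SERVICE_NAME --project $PROJECT_ID --region $LOCATION --concurrency 8"
+ ]
+ },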
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XSrXJkabGdjw"
+ },
+ "source": [
+ "## Invoking Gemma 2 in Cloud Run\n",
+ "\n",
+ "We are now ready to send some requests to Gemma!\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Vrx30A8jKwrY"
+ },
+ "source": [
+ "### Fetch identity token\n",
+ "\n",
+ "Once deployed to Cloud Run, to invoke Gemma 2, we will need to fetch an Identity token to perform authentication. See the relative documentation to discover more about [authentication in Cloud Run](https://cloud.google.com/run/docs/authenticating/overview).\n",
+ "\n",
+ "In the appendix of this sample, you'll find a helper function that supports the automatic refresh of the [Identity Token](https://cloud.google.com/docs/authentication/token-types#id), which expires every hour by default."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qSa5aZCPuLlU"
+ },
+ "outputs": [],
+ "source": [
+ "ID_TOKEN = get_ipython().getoutput('gcloud auth print-identity-token -q')[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UmA-haVjOA6U"
+ },
+ "source": [
+ "### Setup the Service URL"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LOVfy893tvcl"
+ },
+ "outputs": [],
+ "source": [
+ "SERVICE_URL = f\"https://{SERVICE_NAME}-{PROJECT_NUMBER}.{LOCATION}.run.app\" # type: ignore"
+ ]
+ },
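+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6c2e9a7f4b1d"
+ },
+ "source": [
+ "Alternatively, instead of constructing the URL by hand, you can read it back from the deployed service with `gcloud run services describe`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0d5f3b8e6a2c"
+ },
+ "outputs": [],
+ "source": [
+ "# Alternative: fetch the service URL directly from Cloud Run\n",
+ "SERVICE_URL = get_ipython().getoutput(\n",
+ "    'gcloud run services describe $SERVICE_NAME --project $PROJECT_ID --region $LOCATION --format=\"value(status.url)\"'\n",
+ ")[0]\n",
+ "SERVICE_URL"
+ ]
+ },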
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XbOtGicVLNgD"
+ },
+ "source": [
+ "## Invoking Gemma\n",
+ "\n",
+ "You are ready to test the model you just deployed! The [Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md) are a great resource to learn more about the different endpoints and how to interact with your model.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iMI0nlXVT20t"
+ },
+ "source": [
+ "#### Invoke through CURL request\n",
+ "You can invoke Gemma and Cloud Run in many ways. For example, you can send an HTTP CURL request to Cloud Run:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4b1c47642f7e"
+ },
+ "outputs": [],
+ "source": [
+ "ENDPOINT_URL = f\"{SERVICE_URL}/api/generate\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NsixJcaBP2q4"
+ },
+ "outputs": [],
+ "source": [
+ "%%bash -s \"$ENDPOINT_URL\" \"$ID_TOKEN\" \"$MODEL_NAME\" \n",
+ "ENDPOINT_URL=$1\n",
+ "ID_TOKEN=$2\n",
+ "MODEL_NAME=$3\n",
+ "\n",
+ "curl -s -X POST \"${ENDPOINT_URL}\" \\\n",
+ "-H \"Authorization: Bearer ${ID_TOKEN}\" \\\n",
+ "-H \"Content-Type: application/json\" \\\n",
+ "-d '{ \"model\": \"'${MODEL_NAME}'\", \"prompt\": \"Hi\", \"max_tokens\": 100, \"stream\": false}'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "41657205b738"
+ },
+ "source": [
+ "#### Invoke with a Python POST Request\n",
+ "\n",
+ "You can also invoke the model using a POST request with Python's popular `requests` library. [Learn more about the `requests` library here.](https://requests.readthedocs.io/en/latest/) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2e8c87dfd38b"
+ },
+ "outputs": [],
+ "source": [
+ "import requests\n",
+ "\n",
+ "headers = {\"Authorization\": f\"Bearer {ID_TOKEN}\", \"Content-Type\": \"application/json\"} # type: ignore\n",
+ "\n",
+ "data = {\n",
+ " \"model\": MODEL_NAME,\n",
+ " \"prompt\": \"Hi, I am using python!\",\n",
+ " \"max_tokens\": 100,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "\n",
+ "response = requests.post(ENDPOINT_URL, headers=headers, json=data)\n",
+ "\n",
+ "print(response.text)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bFoe0NVOT6DD"
+ },
+ "source": [
+ "#### Invoke Ollama with Python integrations\n",
+ "\n",
+ "Popular Generative AI orchestration frameworks like [LangChain](https://www.langchain.com) and [LlamaIndex](https://www.llamaindex.ai/) offer direct integration with Ollama:\n",
+ "- [LangChain integration](https://python.langchain.com/v0.2/docs/integrations/llms/ollama/)\n",
+ "- [LlamaIndex integration](https://docs.llamaindex.ai/en/stable/api_reference/llms/ollama/)\n",
+ "\n",
+ "As part of this sample, we will be using the LangChain integration to perform different calls and build a sample RAG chain."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vZyZqnnNaeWw"
+ },
+ "source": [
+ "### Import libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gQDWB66Vadlx"
+ },
+ "outputs": [],
+ "source": [
+ "import google.auth\n",
+ "from langchain_community.chat_models import ChatOllama\n",
+ "from langchain_community.document_loaders import WebBaseLoader\n",
+ "from langchain_community.vectorstores import SKLearnVectorStore\n",
+ "from langchain_core.output_parsers import StrOutputParser\n",
+ "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
+ "from langchain_core.runnables import RunnablePassthrough\n",
+ "from langchain_google_vertexai import VertexAIEmbeddings\n",
+ "from langchain_text_splitters import CharacterTextSplitter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_hnaKrZftjbT"
+ },
+ "outputs": [],
+ "source": [
+ "llm = ChatOllama(\n",
+ " model=MODEL_NAME,\n",
+ " base_url=SERVICE_URL,\n",
+ " num_predict=300,\n",
+ " headers={\"Authorization\": f\"Bearer {ID_TOKEN}\"}, # type: ignore\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9GYGr76T7aYF"
+ },
+ "outputs": [],
+ "source": [
+ "# You can perform a synchronous invocation through the `.invoke` method\n",
+ "\n",
+ "llm.invoke(\"Hi!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yVqxaDylWjck"
+ },
+ "source": [
+ "Or invoke through the generation of a stream through the `.stream` **method**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2twtoa6a4_ui"
+ },
+ "outputs": [],
+ "source": [
+ "# You can also generate a stream through the `.stream` method\n",
+ "\n",
+ "for m in llm.stream(\"Hi!\"):\n",
+ " print(m)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bqr_QPso7shY"
+ },
+ "source": [
+ "## RAG Q&A Chain with Gemma 2 and Cloud Run\n",
+ "\n",
+ "We can leverage the LangChain integration to create a sample RAG application with Gemma, Cloud Run, [Vertex AI Embedding](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings) for generating embeddings and [FAISS vector store](https://python.langchain.com/v0.2/docs/integrations/vectorstores/faiss/) for document retrieval.\n",
+ "\n",
+ "Through RAG, we will ask Gemma 2 to answer questions about the [Cloud Run documentation page](https://cloud.google.com/run/docs/overview/what-is-cloud-run)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wHq8zpG5a4u9"
+ },
+ "source": [
+ "### Setup embedding model and retriever\n",
+ "\n",
+ "We are ready to setup our embedding model and retriever."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tI66kos7a8B4"
+ },
+ "outputs": [],
+ "source": [
+ "credentials, _ = google.auth.default(quota_project_id=PROJECT_ID)\n",
+ "embeddings = VertexAIEmbeddings(\n",
+ " project=PROJECT_ID, model_name=\"text-embedding-004\", credentials=credentials\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xxCu8MST6oWK"
+ },
+ "outputs": [],
+ "source": [
+ "loader = WebBaseLoader(\"https://cloud.google.com/run/docs/overview/what-is-cloud-run\")\n",
+ "docs = loader.load()\n",
+ "documents = CharacterTextSplitter(chunk_size=800, chunk_overlap=100).split_documents(\n",
+ " docs\n",
+ ")\n",
+ "\n",
+ "vector = SKLearnVectorStore.from_documents(documents, embeddings)\n",
+ "retriever = vector.as_retriever()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rCH9n2UEzwMv"
+ },
+ "source": [
+ "### RAG Chain Definition\n",
+ "\n",
+ "We will define now our RAG Chain.\n",
+ "\n",
+ "The RAG chain works as follows:\n",
+ "\n",
+ "1. The user's query and conversation history are passed to the `query_rewrite_chain` to generate a rewritten query optimized for semantic search.\n",
+ "2. The rewritten query is used by the `retriever` to fetch relevant documents.\n",
+ "3. The retrieved documents are formatted into a single string.\n",
+ "4. The formatted documents, along with the original user messages, are passed to the LLM with instructions to generate an answer based on the provided context.\n",
+ "5. The LLM's response is parsed and returned as the final answer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x1nPO00s4r4o"
+ },
+ "outputs": [],
+ "source": [
+ "answer_generation_template = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " (\n",
+ " \"system\",\n",
+ " \"You are an assistant for question answering-tasks. \"\n",
+ " \"Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. \"\n",
+ " \"{context}\",\n",
+ " ),\n",
+ " MessagesPlaceholder(variable_name=\"messages\"),\n",
+ " ]\n",
+ ")\n",
+ "query_rewrite_template = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " (\n",
+ " \"system\",\n",
+ " \"Rewrite a query to a semantic search engine using the current conversation. \"\n",
+ " \"Provide only the rewritten query as output.\",\n",
+ " ),\n",
+ " MessagesPlaceholder(variable_name=\"messages\"),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jm309i0r0Lwv"
+ },
+ "outputs": [],
+ "source": [
+ "query_rewrite_chain = query_rewrite_template | llm\n",
+ "\n",
+ "\n",
+ "def extract_query(messages):\n",
+ " return query_rewrite_chain.invoke(messages).content\n",
+ "\n",
+ "\n",
+ "def format_docs(docs):\n",
+ " return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+ "\n",
+ "\n",
+ "rag_chain = (\n",
+ " {\n",
+ " \"context\": extract_query | retriever | format_docs,\n",
+ " \"messages\": RunnablePassthrough(),\n",
+ " }\n",
+ " | answer_generation_template\n",
+ " | llm\n",
+ " | StrOutputParser()\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rRFtc7eo0Tn7"
+ },
+ "source": [
+ "### Testing the RAG Chain"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MIfd_Wwa7RUe"
+ },
+ "outputs": [],
+ "source": [
+ "rag_chain.invoke([(\"human\", \"What features does Cloud Run offer?\")])"
+ ]
+ },
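+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Because the chain ends with `StrOutputParser`, streaming it yields plain string chunks. A minimal sketch of streaming the final answer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The chain is a Runnable too, so the answer can be streamed as it is generated\n",
+ "for chunk in rag_chain.stream([(\"human\", \"How does Cloud Run scale?\")]):\n",
+ "    print(chunk, end=\"\", flush=True)"
+ ]
+ },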
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f0994074042d"
+ },
+ "source": [
+ "Now, let's use a specific question from the documentation to explore how RAG addresses potential gaps in the model's knowledge."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ec5895d100f5"
+ },
+ "outputs": [],
+ "source": [
+ "QUESTION = \"List all the different Cloud Run integrations\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e0abfae19b32"
+ },
+ "source": [
+ "First, we'll ask the LLM directly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4fb7c3e2426e"
+ },
+ "outputs": [],
+ "source": [
+ "print(llm.invoke(QUESTION).content)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5575573f8174"
+ },
+ "source": [
+ "Then, we'll ask the same question using the RAG chain:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AszY1ke5bWC9"
+ },
+ "outputs": [],
+ "source": [
+ "print(rag_chain.invoke([(\"human\", QUESTION)]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c98450b91f72"
+ },
+ "source": [
+ "We can notice how RAG chain provides a more accurate and comprehensive answer than the LLM by leveraging [source documentation](https://cloud.google.com/run/docs/overview/what-is-cloud-run). \n"
+ ]
+ },
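+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since the chain rewrites the search query from the full conversation, you can also pass multi-turn history. In the illustrative example below (the assistant turn is a made-up placeholder), the follow-up question can only be resolved using the earlier turns:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Multi-turn example: the query rewrite step resolves \"it\" using the history\n",
+ "chat_history = [\n",
+ "    (\"human\", \"What is Cloud Run?\"),\n",
+ "    (\"ai\", \"Cloud Run is a managed compute platform on Google Cloud.\"),\n",
+ "    (\"human\", \"Does it support GPUs?\"),\n",
+ "]\n",
+ "print(rag_chain.invoke(chat_history))"
+ ]
+ },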
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88c5f72f1981"
+ },
+ "source": [
+ "## Conclusion\n",
+ "Congratulations. Now you know how to deploy an open model to Cloud Run powered by a GPU! Specifically, you deployed a Gemma 2 model to Cloud Run with a GPU, as part of a RAG application powered by LangChain. You were able to ask answers from Gemma 2 about a documentation page.\n",
+ "\n",
+ "For more information about your identity tokens expiring and how to refresh your tokens, see the next section below \"Appendix: Handling Identity Token Expiration\".\n",
+ "\n",
+ "To clean up the resources you created in this section, see the section at the bottom \"Cleaning up\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A5VPH4gP3igf"
+ },
+ "source": [
+ "## Appendix: Handling Identity Token Expiration\n",
+ "\n",
+ "When deploying a Generative AI application Google Cloud Run, you'll often need to authenticate your requests using Identity Tokens.\n",
+ "\n",
+ "These tokens will expire hourly, requiring a mechanism for automatic refresh to ensure uninterrupted operation.\n",
+ "\n",
+ "The following helper classes provide an example of how to deal with token refresh. It leverages the `google.auth` library to handle the authentication process and automatically refresh the token when necessary.\n",
+ "\n",
+ "\n",
+ "See the following resources for more information on authentication:\n",
+ "* [Identity Token Overview](https://cloud.google.com/docs/authentication/token-types#id)\n",
+ "* [Google Cloud Run Authentication Documentation](https://cloud.google.com/run/docs/authenticating/overview)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TJlxAZKz-tjY"
+ },
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "import google.auth\n",
+ "from google.auth.exceptions import DefaultCredentialsError\n",
+ "import google.auth.transport.requests\n",
+ "import google.oauth2.id_token\n",
+ "from pydantic.v1 import Extra\n",
+ "\n",
+ "\n",
+ "class TokenManager:\n",
+ " def __init__(\n",
+ " self, url=None, token_lifetime=3600\n",
+ " ): # Default token lifetime of 1 hour\n",
+ " self.token = None\n",
+ " self.expiry_time = 0\n",
+ " self.token_lifetime = token_lifetime\n",
+ " self.url = url\n",
+ " self.creds, _ = google.auth.default()\n",
+ "\n",
+ " def get_token(self):\n",
+ " if time.time() >= self.expiry_time:\n",
+ " self.refresh_token(url=self.url)\n",
+ " return self.token\n",
+ "\n",
+ " def refresh_token(self, url):\n",
+ " \"\"\"\n",
+ " Retrieves an ID token, attempting to use default credentials first,\n",
+ " and falling back to fetching a service-to-service new token if necessary.\n",
+ " See more on Cloud Run authentication at this link:\n",
+ " https://cloud.google.com/run/docs/authenticating/service-to-service\n",
+ " Args:\n",
+ " url: The URL to use for the token request.\n",
+ " \"\"\"\n",
+ "\n",
+ " auth_req = google.auth.transport.requests.Request()\n",
+ " try:\n",
+ " self.token = google.oauth2.id_token.fetch_id_token(auth_req, url)\n",
+ " except DefaultCredentialsError:\n",
+ " self.creds.refresh(auth_req)\n",
+ " self.token = self.creds.id_token\n",
+ "\n",
+ " self.expiry_time = time.time() + self.token_lifetime\n",
+ "\n",
+ "\n",
+ "class ChatOllamaWithAuth(ChatOllama):\n",
+ " class Config:\n",
+ " extra = Extra.allow\n",
+ "\n",
+ " def __init__(self, *args, **kwargs):\n",
+ " super().__init__(*args, **kwargs)\n",
+ " self._token_manager = TokenManager()\n",
+ " self.headers = {} if self.headers is None else self.headers\n",
+ "\n",
+ " def _generate(self, *args, **kwargs) -> str:\n",
+ " self.headers[\"Authorization\"] = f\"Bearer {self._token_manager.get_token()}\"\n",
+ " return super()._generate(*args, **kwargs)"
+ ]
+ },
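+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional quick check, you can exercise `TokenManager` on its own: it fetches an ID token for the service URL and reuses the cached value until the configured lifetime elapses."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: fetch a token and confirm the cached value is reused before expiry\n",
+ "token_manager = TokenManager(url=SERVICE_URL)\n",
+ "token = token_manager.get_token()\n",
+ "print(token[:20] + \"...\")  # print only a prefix to avoid leaking the token\n",
+ "assert token_manager.get_token() == token  # second call returns the cached token"
+ ]
+ },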
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GveBb6tLPUEc"
+ },
+ "outputs": [],
+ "source": [
+ "llm = ChatOllamaWithAuth(model=MODEL_NAME, base_url=SERVICE_URL, num_predict=300)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LjgBf0M4Mokn"
+ },
+ "source": [
+ "You can now use the `invoke` function as usual, with the token being refreshed automatically every hour.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NH0kWCloMlUe"
+ },
+ "outputs": [],
+ "source": [
+ "llm.invoke(\"Hi, testing a request\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "270LFAC0P3i4"
+ },
+ "source": [
+ "## Cleaning up\n",
+ "To clean up all Google Cloud resources, you can run the following cell to delete the Cloud Run service you created.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XJRe7I7KP2HM"
+ },
+ "outputs": [],
+ "source": [
+ "# Delete the Cloud Run service deployed above\n",
+ "\n",
+ "!gcloud run services delete $SERVICE_NAME --project $PROJECT_ID --region $LOCATION --quiet"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "name": "cloud_run_ollama_gemma2_rag_qa.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}