diff --git a/00_introduction.py b/00_introduction.py
new file mode 100644
index 0000000..dbc2361
--- /dev/null
+++ b/00_introduction.py
@@ -0,0 +1,168 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC
+# MAGIC
+# MAGIC # Semantic Cache Solution Accelerator
+# MAGIC
+# MAGIC Generative AI models are increasingly revolutionizing industries, with techniques like Retrieval Augmented Generation (RAG) and Compound AI systems leading the charge. These models empower organizations by enhancing capabilities such as information retrieval, decision-making, and content generation. However, the implementation of these systems is often accompanied by significant costs, especially in terms of computational resources. Despite these challenges, the rapid advancement of AI platforms and the development of more efficient algorithms are enabling businesses to optimize costs and scale AI-driven solutions more effectively.
+# MAGIC
+# MAGIC Semantic caching is a technique adopted by many enterprises to reduce the computational load of AI-driven systems. As generative AI models handle increasingly complex queries, there is often semantic overlap between queries, such as users asking variations of the same question. Without semantic caching, these systems would repeatedly perform resource-intensive computations, leading to inefficiencies. By storing previously processed queries and responses, semantic caching allows AI models to retrieve relevant information without recalculating, thereby reducing latency, lowering server load, and conserving computational resources. This becomes especially important as AI applications scale, ensuring cost-effectiveness and maintaining high performance, particularly in natural language processing, where nuanced query variations are frequent.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## How Semantic Caching works
+# MAGIC
+# MAGIC Semantic caching leverages a vector database to store and retrieve answers based on the meaning or semantics of a question rather than just its keywords. In this system, each question is embedded as a vector, and the cached answers are stored. When a new query is submitted, the system searches the database for similar vectors, returning a cached response when a match is found. When a suitable match is not found, the system proceeds to execute the standard pipeline to generate a response, and in turn persists the new question and answer pair in the database.
+# MAGIC
+# MAGIC This technique is particularly effective for handling high-volume, repetitive queries such as those often found in customer FAQs, where users frequently ask the same or similar questions. Some of the key business benefits of semantic cache are:
+# MAGIC
+# MAGIC - Reduce Costs: With fewer computationally expensive model calls, businesses will see significant cost savings. The system bypasses the need to generate new answers for questions that have already been asked, leading to reduced usage of cloud resources and lower operational costs.
+# MAGIC - Faster Response Time: Customer satisfaction is closely tied to how quickly they receive answers. With semantic caching, chatbots can instantly retrieve answers from the cache, dramatically reducing the time it takes to respond to queries.
+# MAGIC - Scalability: As businesses scale, so does the number of customer inquiries. Caching frequently asked questions ensures the chatbot can handle increased volumes without a corresponding increase in costs or latency.
+# MAGIC
+# MAGIC Some use cases we see in the market that are especially suitable for semantic caching include:
+# MAGIC
+# MAGIC - FAQs: Questions that customers frequently ask—such as product details, order statuses, or return policies—are prime candidates for caching. Businesses can quickly address these repetitive queries without taxing the system.
+# MAGIC - Support Tickets: For companies that manage large-scale customer support, semantic caching can be implemented to address commonly recurring issues.
+# MAGIC - Internal Knowledge Bases: Employees often ask the same internal queries, and caching these answers can improve productivity by providing instant access to stored knowledge.
+# MAGIC
+
+# COMMAND ----------
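+
+# MAGIC %md
+# MAGIC The flow described above can be summarized in a few lines of code. The sketch below is purely illustrative: the `embed`, `cache_lookup`, `cache_store`, and `run_rag_chain` callables are hypothetical placeholders for the components built later in this accelerator (the actual cache logic lives in `cache.py` and is used in the `03_rag_chatbot_with_cache` notebook), and the default threshold simply mirrors the value chosen in the exploratory analysis below.
+
+# COMMAND ----------
+
+# A minimal sketch of the semantic cache flow, assuming hypothetical callables for
+# the embedding model, the cache index, and the RAG pipeline.
+from typing import Callable, List, Optional, Tuple
+
+def answer_with_semantic_cache(
+    question: str,
+    embed: Callable[[str], List[float]],                                 # embeds a question into a vector
+    cache_lookup: Callable[[List[float]], Optional[Tuple[str, float]]],  # returns (cached answer, similarity score) or None
+    cache_store: Callable[[str, List[float], str], None],                # persists a new question/answer pair
+    run_rag_chain: Callable[[str], str],                                 # the full (expensive) RAG pipeline
+    threshold: float = 0.01,                                             # similarity threshold (see the analysis below)
+) -> str:
+    query_vector = embed(question)
+    hit = cache_lookup(query_vector)
+    if hit is not None and hit[1] >= threshold:
+        return hit[0]                                # cache hit: reuse the stored answer
+    answer = run_rag_chain(question)                 # cache miss: run the full pipeline
+    cache_store(question, query_vector, answer)      # persist the new pair for future queries
+    return answer
+
+# COMMAND ----------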
+
+# MAGIC %md
+# MAGIC ## Semantic Cache on Databricks Mosaic AI
+# MAGIC
+# MAGIC Databricks provides an optimal platform for building AI agents with semantic caching capabilities. With Databricks Mosaic AI, users have access to all necessary components such as a vector database, agent development framework, agent serving, and an agent evaluation framework on a unified, highly governed platform. This ensures that key assets, including data, vector indexes, models, agents, and endpoints, are centrally managed under robust governance.
+# MAGIC
+# MAGIC Mosaic AI also offers an open architecture, allowing users to experiment with various models for embeddings and generation. Leveraging the Mosaic AI Agent Framework and Evaluation tools, users can rapidly iterate on applications until they meet production-level standards. Once deployed, KPIs like hit ratios and latency can be monitored using MLflow traces, which are automatically logged in Inference Tables for easy tracking.
+# MAGIC
+# MAGIC If you're looking to implement semantic caching for your AI system on Databricks, we're excited to introduce the Semantic Cache Solution Accelerator. This accelerator is designed to help you get started quickly and efficiently, providing a streamlined path to implementation.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Case Study
+# MAGIC
+# MAGIC Imagine you're operating an AI chatbot on your products' public documentation page. This chatbot answers visitors' questions about your products using a retrieval-augmented generation (RAG) architecture. After reviewing the submitted user questions and responses, you notice a large number of redundant queries that are phrased differently but carry the same meaning. You're getting feedback from users that the chatbot's response time is too long, and you're also facing pressure from management to reduce the chatbot's operational costs.
+# MAGIC
+# MAGIC In the following notebooks, we'll explore how semantic caching can significantly lower both total cost and latency, with only a minor trade-off in response quality.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Exploratory Data Analysis
+# MAGIC
+# MAGIC The dataset we use for this solution accelerator was synthesized and is stored inside `./data/`. We first generated a list of 10 questions related to Databricks Machine Learning product features using [dbrx-instruct](https://docs.databricks.com/en/machine-learning/foundation-models/supported-models.html#dbrx-instruct). We then bootstrapped these questions to generate 100 questions by reformulating each of the 10 questions slightly, without changing their meaning. We used [Meta Llama 3.1 70B Instruct](https://docs.databricks.com/en/machine-learning/foundation-models/supported-models.html#meta-llama-31-70b-instruct) for this.
+# MAGIC
+# MAGIC
+# MAGIC The goal of this exploratory data analysis is to identify the optimal similarity score threshold that separates semantically similar questions from non-similar ones. This threshold should maximize the cache hit rate while minimizing false positives. A synthesized dataset is helpful as it provides ground truth labels, meaning that the questions bootstrapped from the same original question belong to the same semantic class. This is captured in the column `base` in the `data/synthetic_questions_100.csv` dataset. This dataset allows for accurate validation of the threshold's performance in separating similar and non-similar questions, as we will see in the following analysis.
+# MAGIC
+# MAGIC Let's first load in the configuration parameters (find more information about `Config` in the next notebook).
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We will read in the dataset as a pandas DataFrame and apply an embedding model to the questions.
+
+# COMMAND ----------
+
+import pandas as pd
+import mlflow.deployments
+from pyspark.sql.functions import udf, pandas_udf
+from pyspark.sql.types import StringType
+
+df = pd.read_csv('data/synthetic_questions_100.csv')[['base', 'question']]
+
+deploy_client = mlflow.deployments.get_deploy_client("databricks")
+def get_embedding(question):
+ response = deploy_client.predict(endpoint=config.EMBEDDING_MODEL_SERVING_ENDPOINT_NAME, inputs={"input": question})
+ return response.data[0]["embedding"]
+
+# Apply an embedding model to the 'question' column and create a new column 'embedding'
+df["embedding"] = df["question"].apply(lambda x: get_embedding(x))
+
+display(df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We will perform a cross join between all the questions to calculate the similarity score for every possible pair of questions, which will result in 10,000 rows.
+
+# COMMAND ----------
+
+df = df.merge(df, how='cross')
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC [Databricks Mosaic AI Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html) computes the similarity score from the L2 distance as:
+# MAGIC
+# MAGIC $$\frac{1}{(1 + dist(q,x)^2)}$$
+# MAGIC
+# MAGIC where dist is the Euclidean distance between the query q and the index entry x, defined as:
+# MAGIC
+# MAGIC $$dist(q,x) = \sqrt{(q_1-x_1)^2 + (q_2-x_2)^2 + \ldots + (q_d-x_d)^2}.$$
+# MAGIC
+# MAGIC We will calculate this metric for each combination of questions. The `similar` column shown below indicates whether both questions in the pair belong to the same semantic class.
+
+# COMMAND ----------
+
+import numpy as np
+
+def get_similarity_score(embedding_x, embedding_y):
+ l_norm = np.linalg.norm(np.array(embedding_x) - np.array(embedding_y))
+ score = 1.0/(1.0 + l_norm*l_norm)
+ return score
+
+# Calculate the similarity score for each pair of question embeddings
+df["score"] = df.apply(lambda x: get_similarity_score(x["embedding_x"], x["embedding_y"]), axis=1)
+df = df.loc[df["score"] != 1] # Exclude the self-similar pairs
+df["similar"] = df.apply(lambda x: x["base_x"] == x["base_y"], axis=1) # True if both questions share the same base question
+df = df[["similar", "score"]]
+
+display(df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Let's look at the summary statistics and the distribution of the similar and non-similar pairs.
+
+# COMMAND ----------
+
+df.groupby('similar').describe().T
+
+# COMMAND ----------
+
+df.groupby('similar')['score'].plot(
+ kind='hist',
+ bins=50,
+ alpha=0.65,
+ density=True,
+ figsize=(10, 6),
+ grid=True,
+ legend=True,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC The analysis shows that the similar and non-similar questions synthesized for this demo exhibit distinct distributions. However, there is a notable overlap between the two distributions, presenting a critical decision point for the solution. If we prioritize the hit rate and set a low similarity threshold (e.g., 0.005), we can achieve a recall of over 0.75, but this will come at the expense of precision. On the other hand, setting a higher threshold (e.g., 0.015) to prioritize precision will limit recall to around 0.25. This trade-off must be carefully evaluated by the team in collaboration with business stakeholders.
+# MAGIC
+# MAGIC In the following notebook, we will set the threshold to 0.01 as a balanced starting point.
+
+# COMMAND ----------
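+
+# MAGIC %md
+# MAGIC To make this trade-off concrete, we can sweep a few candidate thresholds over the labelled pairs and compute precision and recall directly from `df`. The thresholds below are only illustrative, and the exact numbers depend on the embedding model used.
+
+# COMMAND ----------
+
+# Precision and recall for a few candidate similarity thresholds.
+# A pair is predicted "similar" when its score is at or above the threshold.
+for threshold in [0.005, 0.01, 0.015]:
+    predicted = df["score"] >= threshold
+    tp = (predicted & df["similar"]).sum()
+    fp = (predicted & ~df["similar"]).sum()
+    fn = (~predicted & df["similar"]).sum()
+    precision = tp / (tp + fp) if (tp + fp) else float("nan")
+    recall = tp / (tp + fn) if (tp + fn) else float("nan")
+    print(f"threshold={threshold:.3f}  precision={precision:.2f}  recall={recall:.2f}")
+
+# COMMAND ----------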
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/01_data_preparation.py b/01_data_preparation.py
new file mode 100644
index 0000000..6f87be0
--- /dev/null
+++ b/01_data_preparation.py
@@ -0,0 +1,137 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #Set up Vector Search for RAG
+# MAGIC
+# MAGIC Our AI chatbot utilizes a retrieval-augmented generation (RAG) approach. Before implementing semantic caching, we’ll first set up the vector database that supports this RAG architecture. For this, we’ll use [Databricks Mosaic AI Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Cluster configuration
+# MAGIC We recommend using a cluster with the following specifications to run this solution accelerator:
+# MAGIC - Unity Catalog enabled cluster
+# MAGIC - Databricks Runtime 15.4 LTS ML or above
+# MAGIC - Single-node cluster: e.g. `m6id.2xlarge` on AWS or `Standard_D8ds_v4` on Azure Databricks.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We install the required packages from the `requirements.txt` file into the current session.
+
+# COMMAND ----------
+
+# DBTITLE 1,Install requirements
+# MAGIC %pip install -r requirements.txt --quiet
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC `config.py` is a key file that holds all the essential parameters for the application. Open the file and define the values for the parameters according to your specific setup, such as the embedding/generation model endpoint, catalog, schema, vector search endpoint, and more. The following cell will load these parameters into the `config` variable.
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In the next cell, we run the `99_init` notebook, which sets up the logging policy and downloads the chunked Databricks product documentation (if it doesn't already exist) into the specified tables under the catalog and schema you defined in `config.py`.
+
+# COMMAND ----------
+
+# DBTITLE 1,Run init notebook
+# MAGIC %run ./99_init $reset_all_data=false
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create Vector Search endpoint
+# MAGIC
+# MAGIC We create a Vector Search endpoint using custom functions defined in the `utils.py` script.
+
+# COMMAND ----------
+
+import utils
+from databricks.vector_search.client import VectorSearchClient
+
+# Instantiate the Vector Search Client
+vsc = VectorSearchClient(disable_notice=True)
+
+# Check if the endpoint exists, if not create it
+if not utils.vs_endpoint_exists(vsc, config.VECTOR_SEARCH_ENDPOINT_NAME):
+ utils.create_or_wait_for_endpoint(vsc, config.VECTOR_SEARCH_ENDPOINT_NAME)
+
+print(f"Endpoint named {config.VECTOR_SEARCH_ENDPOINT_NAME} is ready.")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create Vector Search index
+# MAGIC Create a Vector Search index from the chunks of documents loaded in the previous cell. We use custom functions defined in the `utils.py` script.
+
+# COMMAND ----------
+
+from databricks.sdk import WorkspaceClient
+
+# Check if the index exists, if not create it
+if not utils.index_exists(vsc, config.VECTOR_SEARCH_ENDPOINT_NAME, config.VS_INDEX_FULLNAME):
+
+ print(f"Creating index {config.VS_INDEX_FULLNAME} on endpoint {config.VECTOR_SEARCH_ENDPOINT_NAME}...")
+
+ # Create a delta sync index
+ vsc.create_delta_sync_index(
+ endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
+ index_name=config.VS_INDEX_FULLNAME,
+ source_table_name=config.SOURCE_TABLE_FULLNAME,
+ pipeline_type="TRIGGERED",
+ primary_key="id",
+ embedding_source_column='content', # The column containing our text
+ embedding_model_endpoint_name=config.EMBEDDING_MODEL_SERVING_ENDPOINT_NAME, #The embedding endpoint used to create the embeddings
+ )
+
+ # Let's wait for the index to be ready and all our embeddings to be created and indexed
+ utils.wait_for_index_to_be_ready(vsc, config.VECTOR_SEARCH_ENDPOINT_NAME, config.VS_INDEX_FULLNAME)
+else:
+ # Trigger a sync to update our vs content with the new data saved in the table
+ utils.wait_for_index_to_be_ready(vsc, config.VECTOR_SEARCH_ENDPOINT_NAME, config.VS_INDEX_FULLNAME)
+ vsc.get_index(config.VECTOR_SEARCH_ENDPOINT_NAME, config.VS_INDEX_FULLNAME).sync()
+
+print(f"index {config.VS_INDEX_FULLNAME} on table {config.SOURCE_TABLE_FULLNAME} is ready")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Query Vector Search index
+# MAGIC
+# MAGIC Let's see if we can run a similarity search against the index.
+
+# COMMAND ----------
+
+# Let's search for the chunks that are most relevant to the query "What is Model Serving?"
+results = vsc.get_index(
+ config.VECTOR_SEARCH_ENDPOINT_NAME,
+ config.VS_INDEX_FULLNAME
+ ).similarity_search(
+ query_text="What is Model Serving?",
+ columns=["url", "content"],
+ num_results=1)
+docs = results.get('result', {}).get('data_array', [])
+docs
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We have successfully set up the vector database for our RAG chatbot. In the next `02_rag_chatbot` notebook, we will build a standard RAG chatbot without semantic caching, which will serve as a benchmark. Later, in the `03_rag_chatbot_with_cache` notebook, we will introduce semantic caching and compare its performance.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/02_rag_chatbot.py b/02_rag_chatbot.py
new file mode 100644
index 0000000..d24be6f
--- /dev/null
+++ b/02_rag_chatbot.py
@@ -0,0 +1,237 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #Create and deploy a standard RAG chain
+# MAGIC
+# MAGIC In this notebook, we will build a standard RAG chatbot without semantic caching to serve as a benchmark. We will utilize the [Databricks Mosaic AI Agent Framework](https://www.databricks.com/product/machine-learning/retrieval-augmented-generation), which enables rapid prototyping of the initial application. In the following cells, we will define a chain, log and register it using MLflow and Unity Catalog, and finally deploy it behind a [Databricks Mosaic AI Model Serving](https://docs.databricks.com/en/machine-learning/model-serving/index.html) endpoint.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Cluster configuration
+# MAGIC We recommend using a cluster with the following specifications to run this solution accelerator:
+# MAGIC - Unity Catalog enabled cluster
+# MAGIC - Databricks Runtime 15.4 LTS ML or above
+# MAGIC - Single-node cluster: e.g. `m6id.2xlarge` on AWS or `Standard_D8ds_v4` on Azure Databricks.
+
+# COMMAND ----------
+
+# DBTITLE 1,Install requirements
+# MAGIC %pip install -r requirements.txt --quiet
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# DBTITLE 1,Run init notebook
+# MAGIC %run ./99_init $reset_all_data=false
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Here, we define the environment variables `HOST` and `TOKEN` for our Model Serving endpoint to authenticate against our Vector Search index.
+
+# COMMAND ----------
+
+# DBTITLE 1,Define environmental variables
+import os
+
+HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
+TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
+
+os.environ['HOST'] = HOST
+os.environ['TOKEN'] = TOKEN
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create and register a chain to MLflow
+# MAGIC
+# MAGIC The next cell defines our standard RAG chain using Langchain. When executed, it will write the content to the `chain/chain.py` file, which will then be used to log the chain in MLflow.
+
+# COMMAND ----------
+
+# MAGIC %%writefile chain/chain.py
+# MAGIC from databricks.vector_search.client import VectorSearchClient
+# MAGIC from langchain_core.prompts import ChatPromptTemplate
+# MAGIC from langchain_community.chat_models import ChatDatabricks
+# MAGIC from langchain_community.vectorstores import DatabricksVectorSearch
+# MAGIC from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+# MAGIC from langchain_core.output_parsers import StrOutputParser
+# MAGIC from operator import itemgetter
+# MAGIC from config import Config
+# MAGIC import mlflow
+# MAGIC import os
+# MAGIC
+# MAGIC ## Enable MLflow Tracing
+# MAGIC mlflow.langchain.autolog()
+# MAGIC
+# MAGIC # load parameters
+# MAGIC config = Config()
+# MAGIC
+# MAGIC # Connect to the Vector Search Index
+# MAGIC vs_index = VectorSearchClient(
+# MAGIC workspace_url=os.environ['HOST'],
+# MAGIC personal_access_token=os.environ['TOKEN'],
+# MAGIC disable_notice=True,
+# MAGIC ).get_index(
+# MAGIC endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
+# MAGIC index_name=config.VS_INDEX_FULLNAME,
+# MAGIC )
+# MAGIC
+# MAGIC # Turn the Vector Search index into a LangChain retriever
+# MAGIC vector_search_as_retriever = DatabricksVectorSearch(
+# MAGIC vs_index,
+# MAGIC text_column="content",
+# MAGIC columns=["id", "content", "url"],
+# MAGIC ).as_retriever(search_kwargs={"k": 3}) # Number of search results that the retriever returns
+# MAGIC # Enable the RAG Studio Review App and MLflow to properly track and display retrieved chunks for evaluation
+# MAGIC mlflow.models.set_retriever_schema(primary_key="id", text_column="content", doc_uri="url")
+# MAGIC
+# MAGIC # Method to format the docs returned by the retriever into the prompt (keep only the text from chunks)
+# MAGIC def format_context(docs):
+# MAGIC chunk_contents = [f"Passage: {d.page_content}\n" for d in docs]
+# MAGIC return "".join(chunk_contents)
+# MAGIC
+# MAGIC # Prompt template to be used to prompt the LLM
+# MAGIC prompt = ChatPromptTemplate.from_messages(
+# MAGIC [
+# MAGIC ("system", f"{config.LLM_PROMPT_TEMPLATE}"),
+# MAGIC ("user", "{question}"),
+# MAGIC ]
+# MAGIC )
+# MAGIC
+# MAGIC # Our foundation model answering the final prompt
+# MAGIC model = ChatDatabricks(
+# MAGIC endpoint=config.LLM_MODEL_SERVING_ENDPOINT_NAME,
+# MAGIC extra_params={"temperature": 0.01, "max_tokens": 500}
+# MAGIC )
+# MAGIC
+# MAGIC # Return the string contents of the most recent messages: [{...}] from the user to be used as input question
+# MAGIC def extract_user_query_string(chat_messages_array):
+# MAGIC return chat_messages_array[-1]["content"]
+# MAGIC
+# MAGIC # RAG Chain
+# MAGIC chain = (
+# MAGIC {
+# MAGIC "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
+# MAGIC "context": itemgetter("messages")
+# MAGIC | RunnableLambda(extract_user_query_string)
+# MAGIC | vector_search_as_retriever
+# MAGIC | RunnableLambda(format_context),
+# MAGIC }
+# MAGIC | prompt
+# MAGIC | model
+# MAGIC | StrOutputParser()
+# MAGIC )
+# MAGIC
+# MAGIC # Tell MLflow logging where to find your chain.
+# MAGIC mlflow.models.set_model(model=chain)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In this cell, we log the chain to MLflow. Note that we are passing `config.py` as a dependency, allowing the chain to load the necessary parameters when deployed to another compute environment or to a Model Serving endpoint. MLflow returns a trace of the inference that shows a detailed breakdown of the latency and the input/output of each step in the chain.
+
+# COMMAND ----------
+
+# Log the model to MLflow
+config_file_path = "config.py"
+
+# Log the chain, passing config.py as a code dependency
+with mlflow.start_run(run_name=f"rag_chatbot"):
+ logged_chain_info = mlflow.langchain.log_model(
+ lc_model=os.path.join(os.getcwd(), 'chain/chain.py'), # Chain code file e.g., /path/to/the/chain.py
+ artifact_path="chain", # Required by MLflow
+ input_example=config.INPUT_EXAMPLE, # MLflow will execute the chain before logging & capture its output schema.
+ code_paths = [config_file_path], # Include the config file in the model
+ )
+
+# Test the chain locally
+chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
+chain.invoke(config.INPUT_EXAMPLE)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC If we are happy with the logged chain, we will go ahead and register the chain in Unity Catalog.
+
+# COMMAND ----------
+
+# Register to UC
+uc_registered_model_info = mlflow.register_model(
+ model_uri=logged_chain_info.model_uri,
+ name=config.MODEL_FULLNAME
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Deploy the chain to a Model Serving endpoint
+# MAGIC
+# MAGIC We deploy the chain using custom functions defined in the `utils.py` script.
+
+# COMMAND ----------
+
+import utils
+utils.deploy_model_serving_endpoint(
+ spark,
+ config.MODEL_FULLNAME,
+ config.CATALOG,
+ config.LOGGING_SCHEMA,
+ config.ENDPOINT_NAME,
+ HOST,
+ TOKEN,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Wait until the endpoint is ready. This may take some time (~15 minutes), so grab a coffee!
+
+# COMMAND ----------
+
+utils.wait_for_model_serving_endpoint_to_be_ready(config.ENDPOINT_NAME)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Once the endpoint is up and running, let's send a request and see how it responds.
+
+# COMMAND ----------
+
+import utils
+data = {
+ "inputs": {
+ "messages": [
+ {
+ "content": "What is Model Serving?",
+ "role": "user"
+ }
+ ]
+ }
+}
+# Now, call the function with the correctly formatted data
+utils.send_request_to_endpoint(
+ config.ENDPOINT_NAME,
+ data,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In this notebook, we built a standard RAG chatbot without semantic caching. We will use this chain as a benchmark for the chain with semantic caching, which we will build in the next `03_rag_chatbot_with_cache` notebook.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/03_rag_chatbot_with_cache.py b/03_rag_chatbot_with_cache.py
new file mode 100644
index 0000000..9f52739
--- /dev/null
+++ b/03_rag_chatbot_with_cache.py
@@ -0,0 +1,348 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #Create and deploy a RAG chain with semantic caching
+# MAGIC
+# MAGIC In this notebook, we will build a RAG chatbot with semantic caching. To do this, we first need to create and warm up our cache. We’ll use [Mosaic AI Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html) for semantic caching, taking advantage of its high-performance similarity search. In the following cells, we will create and warm the cache, build a chain with a semantic caching layer, log and register it using MLflow and Unity Catalog, and finally deploy it behind a [Databricks Mosaic AI Model Serving](https://docs.databricks.com/en/machine-learning/model-serving/index.html) endpoint.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Cluster configuration
+# MAGIC We recommend using a cluster with the following specifications to run this solution accelerator:
+# MAGIC - Unity Catalog enabled cluster
+# MAGIC - Databricks Runtime 15.4 LTS ML or above
+# MAGIC - Single-node cluster: e.g. `m6id.2xlarge` on AWS or `Standard_D8ds_v4` on Azure Databricks.
+
+# COMMAND ----------
+
+# DBTITLE 1,Install requirements
+# MAGIC %pip install -r requirements.txt --quiet
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# DBTITLE 1,Run init notebook
+# MAGIC %run ./99_init $reset_all_data=false
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Here, we define the environment variables `HOST` and `TOKEN` for our Model Serving endpoint to authenticate against our Vector Search index.
+
+# COMMAND ----------
+
+# DBTITLE 1,Set environmental variables
+import os
+
+HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
+TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
+
+os.environ['HOST'] = HOST
+os.environ['TOKEN'] = TOKEN
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create and warm a cache
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We instantiate a Vector Search client to interact with a Vector Search endpoint to create a cache.
+
+# COMMAND ----------
+
+from databricks.vector_search.client import VectorSearchClient
+from cache import Cache
+
+# Create a Vector Search Client
+vsc = VectorSearchClient(
+ workspace_url=HOST,
+ personal_access_token=TOKEN,
+ disable_notice=True,
+ )
+
+# Initialize the cache
+semantic_cache = Cache(vsc, config)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We first delete the cache if it already exists.
+
+# COMMAND ----------
+
+semantic_cache.clear_cache()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We then create a cache.
+
+# COMMAND ----------
+
+semantic_cache.create_cache()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Finally, we load the cache with the predefined Q&A pairs stored in `data/synthetic_qa.txt`.
+
+# COMMAND ----------
+
+semantic_cache.warm_cache()
+
+# COMMAND ----------
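+
+# MAGIC %md
+# MAGIC Conceptually, warming the cache means reading the predefined Q&A pairs, embedding each question, and upserting the records into the cache index. The sketch below only illustrates that idea and is not the accelerator's actual `Cache.warm_cache` implementation (see `cache.py` for that); the `embed` and `upsert_to_cache_index` callables are hypothetical placeholders.
+
+# COMMAND ----------
+
+# Illustrative sketch of cache warming, assuming hypothetical callables for the
+# embedding model and the cache index writer.
+import json
+from typing import Callable, Dict, List
+
+def load_qa_pairs(path: str) -> List[Dict]:
+    """Read the predefined Q&A pairs, stored one JSON object per line."""
+    with open(path, "r") as f:
+        return [json.loads(line) for line in f]
+
+def warm_cache_sketch(
+    qa_path: str,
+    embed: Callable[[str], List[float]],            # e.g., a call to the embedding model endpoint
+    upsert_to_cache_index: Callable[[Dict], None],  # writes one record into the cache index
+) -> int:
+    pairs = load_qa_pairs(qa_path)
+    for pair in pairs:
+        upsert_to_cache_index({
+            "question": pair["question"],
+            "answer": pair["answer"],
+            "embedding": embed(pair["question"]),   # store the question embedding for similarity search
+        })
+    return len(pairs)                               # number of entries loaded into the cache
+
+# COMMAND ----------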
+
+# MAGIC %md
+# MAGIC ## Create and register a chain to MLflow
+# MAGIC
+# MAGIC The next cell defines our RAG chain with semantic cache using Langchain. When executed, it will write the content to the `chain/chain_cache.py` file, which will then be used to log the chain in MLflow.
+
+# COMMAND ----------
+
+# MAGIC %%writefile chain/chain_cache.py
+# MAGIC from databricks.vector_search.client import VectorSearchClient
+# MAGIC from langchain_community.vectorstores import DatabricksVectorSearch
+# MAGIC from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+# MAGIC from langchain_core.output_parsers import StrOutputParser
+# MAGIC from langchain_core.prompts import ChatPromptTemplate
+# MAGIC from langchain_community.chat_models import ChatDatabricks
+# MAGIC from operator import itemgetter
+# MAGIC from datetime import datetime
+# MAGIC from uuid import uuid4
+# MAGIC import os
+# MAGIC import mlflow
+# MAGIC from cache import Cache
+# MAGIC from config import Config
+# MAGIC
+# MAGIC
+# MAGIC # Set up logging
+# MAGIC import logging
+# MAGIC logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+# MAGIC logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
+# MAGIC logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)
+# MAGIC
+# MAGIC ## Enable MLflow Tracing
+# MAGIC mlflow.langchain.autolog()
+# MAGIC
+# MAGIC # Get configuration
+# MAGIC config = Config()
+# MAGIC
+# MAGIC # Connect to Vector Search
+# MAGIC vsc = VectorSearchClient(
+# MAGIC workspace_url=os.environ['HOST'],
+# MAGIC personal_access_token=os.environ['TOKEN'],
+# MAGIC disable_notice=True,
+# MAGIC )
+# MAGIC
+# MAGIC # Get the Vector Search index
+# MAGIC vs_index = vsc.get_index(
+# MAGIC index_name=config.VS_INDEX_FULLNAME,
+# MAGIC endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
+# MAGIC )
+# MAGIC
+# MAGIC # Instantiate a Cache object
+# MAGIC semantic_cache = Cache(vsc, config)
+# MAGIC
+# MAGIC # Turn the Vector Search index into a LangChain retriever
+# MAGIC vector_search_as_retriever = DatabricksVectorSearch(
+# MAGIC vs_index,
+# MAGIC text_column="content",
+# MAGIC columns=["id", "content", "url"],
+# MAGIC ).as_retriever(search_kwargs={"k": 3}) # Number of search results that the retriever returns
+# MAGIC
+# MAGIC # Method to retrieve the context from the Vector Search index
+# MAGIC def retrieve_context(qa):
+# MAGIC return vector_search_as_retriever.invoke(qa["question"])
+# MAGIC
+# MAGIC # Enable the RAG Studio Review App and MLflow to properly track and display retrieved chunks for evaluation
+# MAGIC mlflow.models.set_retriever_schema(primary_key="id", text_column="content", doc_uri="url")
+# MAGIC
+# MAGIC # Method to format the docs returned by the retriever into the prompt (keep only the text from chunks)
+# MAGIC def format_context(docs):
+# MAGIC chunk_contents = [f"Passage: {d.page_content}\n" for d in docs]
+# MAGIC return "".join(chunk_contents)
+# MAGIC
+# MAGIC # Create a prompt template for response generation
+# MAGIC prompt = ChatPromptTemplate.from_messages(
+# MAGIC [
+# MAGIC ("system", f"{config.LLM_PROMPT_TEMPLATE}"),
+# MAGIC ("user", "{question}"),
+# MAGIC ]
+# MAGIC )
+# MAGIC
+# MAGIC # Define our foundation model answering the final prompt
+# MAGIC model = ChatDatabricks(
+# MAGIC endpoint=config.LLM_MODEL_SERVING_ENDPOINT_NAME,
+# MAGIC extra_params={"temperature": 0.01, "max_tokens": 500}
+# MAGIC )
+# MAGIC
+# MAGIC # Call the foundation model
+# MAGIC def call_model(prompt):
+# MAGIC response = model.invoke(prompt)
+# MAGIC semantic_cache.store_in_cache(
+# MAGIC question = prompt.dict()['messages'][1]['content'],
+# MAGIC answer = response.content
+# MAGIC )
+# MAGIC return response
+# MAGIC
+# MAGIC # Return the string contents of the most recent messages: [{...}] from the user to be used as input question
+# MAGIC def extract_user_query_string(chat_messages_array):
+# MAGIC return chat_messages_array[-1]["content"]
+# MAGIC
+# MAGIC # Router to determine which subsequent step to be executed
+# MAGIC def router(qa):
+# MAGIC if qa["answer"] == "":
+# MAGIC return rag_chain
+# MAGIC else:
+# MAGIC return (qa["answer"])
+# MAGIC
+# MAGIC # RAG chain
+# MAGIC rag_chain = (
+# MAGIC {
+# MAGIC "question": lambda x: x["question"],
+# MAGIC "context": RunnablePassthrough()
+# MAGIC | RunnableLambda(retrieve_context)
+# MAGIC | RunnableLambda(format_context),
+# MAGIC }
+# MAGIC | prompt
+# MAGIC | RunnableLambda(call_model)
+# MAGIC )
+# MAGIC
+# MAGIC # Full chain with cache
+# MAGIC full_chain = (
+# MAGIC itemgetter("messages")
+# MAGIC | RunnableLambda(extract_user_query_string)
+# MAGIC | RunnableLambda(semantic_cache.get_from_cache)
+# MAGIC | RunnableLambda(router)
+# MAGIC | StrOutputParser()
+# MAGIC )
+# MAGIC
+# MAGIC # Tell MLflow logging where to find your chain.
+# MAGIC mlflow.models.set_model(model=full_chain)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In this cell, we log the chain to MLflow. Note that this time we are passing `cache.py` and `utils.py` along with `config.py` as dependencies, allowing the chain to load the custom classes and functions it needs when deployed to another compute environment or to a Model Serving endpoint. MLflow returns a trace of the inference that shows a detailed breakdown of the latency and the input/output of each step in the chain.
+
+# COMMAND ----------
+
+# Log the model to MLflow
+config_file_path = "config.py"
+cache_file_path = "cache.py"
+utils_file_path = "utils.py"
+
+with mlflow.start_run(run_name=f"rag_chatbot"):
+ logged_chain_info = mlflow.langchain.log_model(
+ lc_model=os.path.join(os.getcwd(), 'chain/chain_cache.py'), # Chain code file e.g., /path/to/the/chain.py
+ artifact_path="chain", # Required by MLflow
+ input_example=config.INPUT_EXAMPLE, # MLflow will execute the chain before logging & capture its output schema.
+ code_paths = [cache_file_path, config_file_path, utils_file_path],
+ )
+
+# Test the chain locally
+chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
+chain.invoke(config.INPUT_EXAMPLE)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Let's ask the chain a question for which we know no similar question has been asked before, and which therefore doesn't exist in the cache. We see in the trace that the entire chain is indeed executed.
+
+# COMMAND ----------
+
+chain.invoke({'messages': [{'content': "How does Databricks' feature Genie automate feature engineering for machine learning models?", 'role': 'user'}]})
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC If we reformulate the question without changing its meaning, we get the response from the cache. We can see this in the trace, and the execution time is less than half.
+
+# COMMAND ----------
+
+chain.invoke({'messages': [{'content': "What is the role of Databricks' feature Genie in automating feature engineering for machine learning models?", 'role': 'user'}]})
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Where to set the similarity threshold (`0.01` in this demo, defined in `config.py`) is arguably the most important design decision you need to make for your solution. A threshold that is too high will reduce the hit rate and undermine the benefit of semantic caching, while a threshold that is too low could generate many false positives. There is a fine balance to strike. To make an informed decision, refer to the exploratory data analysis performed in the `00_introduction` notebook.
+# MAGIC
+# MAGIC If we are happy with the chain, we will go ahead and register the chain in Unity Catalog.
+
+# COMMAND ----------
+
+# Register to UC
+uc_registered_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=config.MODEL_FULLNAME_CACHE)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Deploy the chain to a Model Serving endpoint
+# MAGIC
+# MAGIC We deploy the chain using custom functions defined in the `utils.py` script.
+
+# COMMAND ----------
+
+import utils
+utils.deploy_model_serving_endpoint(
+ spark,
+ config.MODEL_FULLNAME_CACHE,
+ config.CATALOG_CACHE,
+ config.LOGGING_SCHEMA_CACHE,
+ config.ENDPOINT_NAME_CACHE,
+ HOST,
+ TOKEN,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Wait until the endpoint is ready. This may take some time (~15 minutes), so grab a coffee!
+
+# COMMAND ----------
+
+utils.wait_for_model_serving_endpoint_to_be_ready(config.ENDPOINT_NAME_CACHE)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Once the endpoint is up and running, let's send a request and see how it responds.
+
+# COMMAND ----------
+
+import utils
+data = {
+ "inputs": {
+ "messages": [
+ {
+ "content": "What is Model Serving?",
+ "role": "user"
+ }
+ ]
+ }
+}
+# Now, call the function with the correctly formatted data
+utils.send_request_to_endpoint(config.ENDPOINT_NAME_CACHE, data)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In this notebook, we built a RAG chatbot with semantic caching. In the next `04_evaluate` notebook, we will compare the two chains we built.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/04_evaluate.py b/04_evaluate.py
new file mode 100644
index 0000000..9d8cf0e
--- /dev/null
+++ b/04_evaluate.py
@@ -0,0 +1,273 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #Evaluate the RAG chains with and without caching
+# MAGIC
+# MAGIC In the previous notebooks, we created and deployed RAG chains with and without semantic caching. Both are now up and running, ready to handle requests. In this notebook, we will conduct a benchmarking exercise to evaluate the latency reduction achieved by the cached chain and assess the trade-off in response quality.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Cluster configuration
+# MAGIC We recommend using a cluster with the following specifications to run this solution accelerator:
+# MAGIC - Unity Catalog enabled cluster
+# MAGIC - Databricks Runtime 15.4 LTS ML or above
+# MAGIC - Single-node cluster: e.g. `m6id.2xlarge` on AWS or `Standard_D8ds_v4` on Azure Databricks.
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Data preparation
+# MAGIC
+# MAGIC For the benchmarking exercise, we will use a hundred synthesized questions stored in `data/synthetic_questions_100.csv`. To create these, we first generated ten questions related to Databricks Machine Learning product features using [dbrx-instruct](https://docs.databricks.com/en/machine-learning/foundation-models/supported-models.html#dbrx-instruct). We then expanded these by reformulating each of the ten questions slightly, without changing their meaning, generating ten variations of each. This resulted in a hundred questions in total. For this process, we used [Meta Llama 3.1 70B Instruct](https://docs.databricks.com/en/machine-learning/foundation-models/supported-models.html#meta-llama-31-70b-instruct).
+# MAGIC
+# MAGIC We read this dataset in and save it into a delta table.
+
+# COMMAND ----------
+
+import pandas as pd
+df = pd.read_csv('data/synthetic_questions_100.csv') # this is a small sample of 100 questions
+df = spark.createDataFrame(df) # convert to a Spark DataFrame
+df.write.mode('overwrite').saveAsTable(f'{config.CATALOG}.{config.SCHEMA}.synthetic_questions_100') # save to a table
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Next, we will format the questions so that we can apply the chain directly later. We store the formatted dataset in another delta table.
+
+# COMMAND ----------
+
+spark.sql(f"""
+CREATE OR REPLACE TABLE {config.CATALOG}.{config.SCHEMA}.synthetic_questions_100_formatted AS
+SELECT STRUCT(ARRAY(STRUCT(question AS content, "user" AS role)) AS messages) AS question, base as base
+FROM {config.CATALOG}.{config.SCHEMA}.synthetic_questions_100;
+""")
+
+df = spark.table(f'{config.CATALOG}.{config.SCHEMA}.synthetic_questions_100_formatted')
+display(df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Test the standard RAG chain endpoint
+# MAGIC
+# MAGIC Now that we have our test dataset, we are going to go ahead and test the standard RAG chain endpoint. We will use [ai_query](https://docs.databricks.com/en/sql/language-manual/functions/ai_query.html) to apply the chain to the formatted table. We write the result out to another delta table.
+
+# COMMAND ----------
+
+# DBTITLE 1,Load testing standard RAG chain
+spark.sql(f"""
+CREATE OR REPLACE TABLE {config.CATALOG}.{config.SCHEMA}.standard_rag_chain_results AS
+SELECT question, ai_query(
+ 'standard_rag_chatbot',
+ question,
+ returnType => 'STRUCT<choices:ARRAY<STRING>>'
+ ) AS prediction, base
+FROM {config.CATALOG}.{config.SCHEMA}.synthetic_questions_100_formatted;
+""")
+
+standard_rag_chain_results = spark.table(f'{config.CATALOG}.{config.SCHEMA}.standard_rag_chain_results')
+display(standard_rag_chain_results)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Test the RAG chain with cache endpoint
+# MAGIC
+# MAGIC We are now going to test the RAG chain with cache endpoint.
+
+# COMMAND ----------
+
+# DBTITLE 1,Load testing RAG chain with cache
+spark.sql(f"""
+CREATE OR REPLACE TABLE {config.CATALOG}.{config.SCHEMA}.rag_chain_with_cache_results AS
+SELECT question, ai_query(
+ 'rag_chatbot_with_cache',
+ question,
+ returnType => 'STRUCT<choices:ARRAY<STRING>>'
+ ) AS prediction, base
+FROM {config.CATALOG}.{config.SCHEMA}.synthetic_questions_100_formatted;
+""")
+
+rag_chain_with_cache_results = spark.table(f'{config.CATALOG}.{config.SCHEMA}.rag_chain_with_cache_results')
+display(rag_chain_with_cache_results)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Just by looking at the execution time, we notice that the chain with cache ran more than 2x faster than the chain without.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Evaluate results using MLflow
+# MAGIC
+# MAGIC We will begin by evaluating the quality of the responses from both endpoints. Since the 100 questions were derived from the original 10 through reformulation (without changing their meaning), we can use the answers to the original questions as the ground truth for evaluating the responses to the 100 variations.
+
+# COMMAND ----------
+
+# DBTITLE 1,Reading in the original 10 questions and answers
+import json
+synthetic_qa = []
+with open('data/synthetic_qa.txt', 'r') as file:
+ for line in file:
+ synthetic_qa.append(json.loads(line))
+
+display(synthetic_qa)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We construct an evaluation dataset for the standard RAG chain and the chain with the cache. The `prediction` column stores the responses from the chain.
+
+# COMMAND ----------
+
+evaluation_standard = spark.table(f'{config.CATALOG}.{config.SCHEMA}.standard_rag_chain_results').toPandas()
+evaluation_cache = spark.table(f'{config.CATALOG}.{config.SCHEMA}.rag_chain_with_cache_results').toPandas()
+
+evaluation_standard["question"] = evaluation_standard["question"].apply(lambda x: x["messages"][0]["content"])
+evaluation_standard["prediction"] = evaluation_standard["prediction"].apply(lambda x: json.loads(x["choices"][0])["message"]["content"])
+
+evaluation_cache["question"] = evaluation_cache["question"].apply(lambda x: x["messages"][0]["content"])
+evaluation_cache["prediction"] = evaluation_cache["prediction"].apply(lambda x: json.loads(x["choices"][0])["message"]["content"])
+
+labels = pd.DataFrame(synthetic_qa).drop(["question"], axis=1)
+
+evaluation_standard = evaluation_standard.merge(labels, on='base')
+evaluation_cache = evaluation_cache.merge(labels, on='base')
+
+# COMMAND ----------
+
+evaluation_standard
+
+# COMMAND ----------
+
+evaluation_cache
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC To assess the quality of the responses, we will use [`mlflow.evaluate`](https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.evaluate).
+
+# COMMAND ----------
+
+import mlflow
+from mlflow.deployments import set_deployments_target
+
+set_deployments_target("databricks")
+judge_model = "endpoints:/databricks-meta-llama-3-1-70b-instruct" # this is the model endpoint you want to use as a judge
+
+# Run evaluation for the standard chain
+with mlflow.start_run(run_name="evaluation_standard"):
+ standard_results = mlflow.evaluate(
+ data=evaluation_standard,
+ targets="answer",
+ predictions="prediction",
+ model_type="question-answering",
+ extra_metrics=[
+ mlflow.metrics.genai.answer_similarity(model=judge_model),
+ mlflow.metrics.genai.answer_correctness(model=judge_model),
+ mlflow.metrics.genai.answer_relevance(model=judge_model),
+ ],
+ evaluator_config={
+ 'col_mapping': {'inputs': 'question'}
+ }
+ )
+
+# Run evaluation for the chain with cache
+with mlflow.start_run(run_name="evaluation_cache"):
+ cache_results = mlflow.evaluate(
+ data=evaluation_cache,
+ targets="answer",
+ predictions="prediction",
+ model_type="question-answering",
+ extra_metrics=[
+ mlflow.metrics.genai.answer_similarity(model=judge_model),
+ mlflow.metrics.genai.answer_correctness(model=judge_model),
+ mlflow.metrics.genai.answer_relevance(model=judge_model),
+ ],
+ evaluator_config={
+ 'col_mapping': {'inputs': 'question'}
+ }
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Let's print out the aggregated statistics of the quality metrics.
+
+# COMMAND ----------
+
+print(f"See aggregated evaluation results below: \n{standard_results.metrics}")
+
+# COMMAND ----------
+
+print(f"See aggregated evaluation results below: \n{cache_results.metrics}")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC The evaluation results show that the standard RAG chain performed slightly better on metrics like `answer_correctness/v1/mean` (scoring `4.82` vs. `4.69`) and `answer_relevance/v1/mean` (scoring `4.91` vs. `4.7`). These minor drops in performance are expected when responses are retrieved from the cache. The key takeaway is to assess whether these differences are acceptable given the cost and latency reductions provided by the caching solution. Ultimately, the decision should be based on how these trade-offs impact the business value of your use case.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Query the Inference tables
+# MAGIC
+# MAGIC Each request and response that hits the endpoint can be logged to an [inference table](https://docs.databricks.com/en/machine-learning/model-serving/inference-tables.html) along with its [trace](https://docs.databricks.com/en/mlflow/mlflow-tracing.html#use-mlflow-tracing-in-production). These tables are particularly useful for debugging and auditing. We will query the inference tables for both endpoints to gain insights into performance optimization.
+
+# COMMAND ----------
+
+# You can just query the inference table
+standard_log = spark.read.table(f"{config.CATALOG}.{config.LOGGING_SCHEMA}.standard_rag_chatbot_payload").toPandas()
+display(standard_log)
+
+# COMMAND ----------
+
+cache_log = spark.read.table(f"{config.CATALOG_CACHE}.{config.LOGGING_SCHEMA_CACHE}.rag_chatbot_with_cache_payload").toPandas()
+display(cache_log)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Let's calculate the mean execution time per query. We see a significant drop for the chain with cache, which translates directly into cost reduction.
+
+# COMMAND ----------
+
+print(f"standard rag chain mean execution time: {round(standard_log['execution_time_ms'].mean()/1000, 4)} seconds")
+print(f"rag chain with cache mean execution time: {round(cache_log['execution_time_ms'].mean()/1000, 4)} seconds")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC One of the important KPIs for a caching solution is the hit rate. We can retrieve this information from the traces stored in the inference table.
+
+# COMMAND ----------
+
+import json
+import numpy as np
+
+cache_trace = np.array(
+ cache_log["response"].apply(lambda x: 1 if len(json.loads(x)["databricks_output"]["trace"]["data"]["spans"]) == 6 else 0)
+)
+print(f"Number of times the query hit the cache: {cache_trace.sum()}/100")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC In this notebook, we conducted a benchmarking exercise to compare the solutions with and without semantic caching. For this specific dataset, we observed a significant reduction in both cost and latency, though with a slight trade-off in quality. It’s important to emphasize that every use case should carefully assess the impact of these gains and losses on business objectives before making a final decision.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/05_cache_eviction.py b/05_cache_eviction.py
new file mode 100644
index 0000000..96180e4
--- /dev/null
+++ b/05_cache_eviction.py
@@ -0,0 +1,126 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions).
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #Cache eviction
+# MAGIC
+# MAGIC This notebook walks you through some of the eviction strategies you can apply to your semantic cache.
+
+# COMMAND ----------
+
+# DBTITLE 1,Install requirements
+# MAGIC %pip install -r requirements.txt --quiet
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# DBTITLE 1,Load parameters
+from config import Config
+config = Config()
+
+# COMMAND ----------
+
+# DBTITLE 1,Set environmental variables
+import os
+
+HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
+TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
+
+os.environ['HOST'] = HOST
+os.environ['TOKEN'] = TOKEN
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Cleaning up the cache
+# MAGIC
+# MAGIC We instantiate a Vector Search client to interact with a Vector Search endpoint.
+
+# COMMAND ----------
+
+from databricks.vector_search.client import VectorSearchClient
+from cache import Cache
+
+vsc = VectorSearchClient(
+ workspace_url=HOST,
+ personal_access_token=TOKEN,
+ disable_notice=True,
+ )
+
+semantic_cache = Cache(vsc, config)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## FIFO (First-In-First-Out) Strategy
+# MAGIC
+# MAGIC **FIFO** (First-In-First-Out) removes the oldest cached items first. In a **semantic caching** context for **LLM responses**, it is useful when:
+# MAGIC - **Frequently changing queries**: If queries or questions tend to change frequently over time, older answers might become irrelevant quickly.
+# MAGIC - **Use Case**: Effective in scenarios where users query frequently changing topics (e.g., breaking news or real-time events).
+# MAGIC
+# MAGIC #### Pros:
+# MAGIC - Simple to implement.
+# MAGIC - Removes outdated or stale responses automatically.
+# MAGIC
+# MAGIC #### Cons:
+# MAGIC - Does not account for query popularity. Frequently asked questions might be evicted even if they are still relevant.
+# MAGIC - Not ideal for handling frequently recurring queries, as important cached answers could be removed.
+# MAGIC
+
+# COMMAND ----------
+
+semantic_cache.evict(strategy='FIFO', max_documents=4, batch_size=4)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## LRU (Least Recently Used) Strategy
+# MAGIC
+# MAGIC **LRU** (Least Recently Used) evicts items that haven't been accessed recently. This strategy works well in **semantic caching** for **LLM responses** when:
+# MAGIC - **Popular or recurring questions**: Frequently asked questions (FAQs) remain in the cache while infrequent or one-off queries are evicted.
+# MAGIC - **Use Case**: Best suited for systems handling recurring queries, such as customer support, FAQ systems, or educational queries where the same questions are asked repeatedly.
+# MAGIC
+# MAGIC #### Pros:
+# MAGIC - Ensures that frequently accessed answers stay in the cache.
+# MAGIC - Minimizes re-computation for common queries.
+# MAGIC
+# MAGIC #### Cons:
+# MAGIC - Higher overhead compared to FIFO, as it tracks access patterns.
+# MAGIC - May retain less relevant but frequently accessed responses, while important but less commonly asked answers could be evicted.
+# MAGIC
+
+# COMMAND ----------
+
+semantic_cache.evict(strategy='LRU', max_documents=49)
+
+# COMMAND ----------
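+
+# MAGIC %md
+# MAGIC To make the two strategies concrete, the sketch below shows one way a batched eviction loop could look. It is not the accelerator's actual `Cache.evict` implementation (see `cache.py` for that): the `count_documents`, `fetch_oldest_batch`, and `delete_documents` callables are hypothetical placeholders, and the only difference between FIFO and LRU here is the timestamp used for ordering (creation time vs. last-access time).
+
+# COMMAND ----------
+
+# Illustrative sketch of a batched eviction loop, assuming hypothetical callables
+# for counting, fetching, and deleting cached documents.
+from typing import Callable, Dict, List
+
+def evict_sketch(
+    strategy: str,                                         # 'FIFO' or 'LRU'
+    max_documents: int,                                    # target cache size after eviction
+    batch_size: int,
+    count_documents: Callable[[], int],                    # current number of cached entries
+    fetch_oldest_batch: Callable[[str, int], List[Dict]],  # oldest entries for the given sort key
+    delete_documents: Callable[[List[str]], None],         # deletes entries by primary key
+) -> int:
+    # FIFO orders by when an entry was written; LRU by when it was last read.
+    sort_key = "created_at" if strategy == "FIFO" else "last_accessed_at"
+    evicted = 0
+    while count_documents() > max_documents:
+        batch = fetch_oldest_batch(sort_key, batch_size)
+        if not batch:
+            break
+        delete_documents([doc["id"] for doc in batch])
+        evicted += len(batch)
+    return evicted
+
+# COMMAND ----------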
+
+# MAGIC %md
+# MAGIC ### **Limitations:**
+# MAGIC
+# MAGIC - **Sequential Batch Eviction:** Both FIFO and LRU rely on batch eviction that involves querying and removing documents iteratively. This sequential process could slow down as the number of documents increases.
+# MAGIC - **Full Cache Query:** The current implementation of `_evict_fifo` and `_evict_lru` fetches a batch of documents for each iteration, which requires a similarity search query each time. This may introduce latency for larger caches.
+# MAGIC - **Single-threaded Eviction:** The eviction process operates in a single thread, and as the number of documents grows, the time taken to query and delete entries will increase.
+# MAGIC
+# MAGIC **Potential Improvements:**
+# MAGIC
+# MAGIC - **Bulk Deletion:**
+# MAGIC - Instead of deleting documents in small batches (based on batch_size), consider implementing bulk deletion by gathering all the documents to be evicted in a single query and deleting them all at once.
+# MAGIC - **Parallelism/Concurrency:**
+# MAGIC - Use parallel or multi-threaded processing to speed up both the similarity search and deletion processes using Spark.
+# MAGIC - Implementing asynchronous operations can allow multiple batches to be processed concurrently, reducing overall eviction time.
+# MAGIC - **Optimize Batch Size:**
+# MAGIC - Fine-tune the batch_size dynamically based on the current system load or cache size. Larger batches may reduce the number of queries but may also consume more memory, so optimization here is key.
+# MAGIC - **Index Partitioning:**
+# MAGIC - If possible, partition the index based on time (for FIFO) or access time (for LRU). This would allow the search and eviction process to be more efficient, as it would target a specific partition instead of querying the entire cache.
+# MAGIC - **Cache Usage Statistics:**
+# MAGIC - Integrate a system to track the real-time size of the cache and update indexed_row_count without querying the entire cache each time. This would reduce the number of times you need to check the total cache size during eviction.
+# MAGIC
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC © 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.
diff --git a/99_init.py b/99_init.py
new file mode 100644
index 0000000..0f6d03b
--- /dev/null
+++ b/99_init.py
@@ -0,0 +1,72 @@
+# Databricks notebook source
+dbutils.widgets.text("reset_all_data", "false", "Reset Data")
+reset_all_data = dbutils.widgets.get("reset_all_data") == "true"
+
+# COMMAND ----------
+
+from pyspark.sql.functions import pandas_udf
+import pandas as pd
+import pyspark.sql.functions as F
+from pyspark.sql.functions import col, udf, length, pandas_udf
+import os
+import mlflow
+import yaml
+from typing import Iterator
+from mlflow import MlflowClient
+mlflow.set_registry_uri('databricks-uc')
+
+# Set up logging
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
+logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)
+logging.getLogger('mlflow').setLevel(logging.ERROR) # Disable MLflow warnings
+from urllib3.connectionpool import log as urllib3_log
+urllib3_log.setLevel(logging.ERROR)
+
+# Temporary workaround while a bug fix is in progress: disable Spark autologging
+mlflow.spark.autolog(disable=True)
+
+import warnings
+warnings.filterwarnings("ignore")
+
+# COMMAND ----------
+
+if reset_all_data:
+    print(f'Clearing schema `{config.CATALOG}`.`{config.SCHEMA}`')
+    _ = spark.sql(f"DROP DATABASE IF EXISTS `{config.CATALOG}`.`{config.SCHEMA}` CASCADE")
+
+# COMMAND ----------
+
+def use_and_create_db(CATALOG, SCHEMA, cloud_storage_path = None):
+ print(f"USE CATALOG `{CATALOG}`")
+ _ = spark.sql(f"USE CATALOG `{CATALOG}`")
+ _ = spark.sql(f"""CREATE DATABASE IF NOT EXISTS `{SCHEMA}` """)
+
+# If a catalog is specified in the config, switch to it, creating the catalog and schema if they do not exist.
+if len(config.CATALOG) > 0:
+ current_catalog = spark.sql("SELECT current_catalog()").collect()[0]['current_catalog()']
+ if current_catalog != config.CATALOG:
+ catalogs = [r['catalog'] for r in spark.sql("SHOW CATALOGS").collect()]
+ if config.CATALOG not in catalogs:
+ _ = spark.sql(f"CREATE CATALOG IF NOT EXISTS {config.CATALOG}")
+ use_and_create_db(config.CATALOG, config.SCHEMA)
+
+print(f"using catalog.database `{config.CATALOG}`.`{config.SCHEMA}`")
+_ = spark.sql(f"""USE `{config.CATALOG}`.`{config.SCHEMA}`""")
+
+# COMMAND ----------
+
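+# Create and load the source documentation and evaluation tables if they are missing or empty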
+if not spark.catalog.tableExists(config.SOURCE_TABLE_FULLNAME) or spark.table(config.SOURCE_TABLE_FULLNAME).isEmpty() or \
+ not spark.catalog.tableExists(config.EVALUATION_TABLE_FULLNAME) or spark.table(config.EVALUATION_TABLE_FULLNAME).isEmpty():
+ _ = spark.sql(f'''CREATE TABLE IF NOT EXISTS {config.SOURCE_TABLE_FULLNAME} (
+ id BIGINT GENERATED BY DEFAULT AS IDENTITY,
+ url STRING,
+ content STRING
+ ) TBLPROPERTIES (delta.enableChangeDataFeed = true)''')
+ (spark.createDataFrame(pd.read_parquet('https://notebooks.databricks.com/demos/dbdemos-dataset/llm/databricks-documentation/databricks_documentation.parquet'))
+ .drop('title').write.mode('overwrite').saveAsTable(config.SOURCE_TABLE_FULLNAME))
+ (spark.createDataFrame(pd.read_parquet('https://notebooks.databricks.com/demos/dbdemos-dataset/llm/databricks-documentation/databricks_doc_eval_set.parquet'))
+ .write.mode('overwrite').saveAsTable(config.EVALUATION_TABLE_FULLNAME))
+    # Make sure Change Data Feed stays enabled on the source table
+    _ = spark.sql(f'ALTER TABLE {config.SOURCE_TABLE_FULLNAME} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)')
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..4f202ad
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1 @@
+We happily welcome contributions to this project. We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes pursuant to a CLA.
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000..5122ccd
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,24 @@
+# DB license
+**Definitions.**
+
+Agreement: The agreement between Databricks, Inc., and you governing the use of the Databricks Services, as that term is defined in the Master Cloud Services Agreement (MCSA) located at www.databricks.com/legal/mcsa.
+
+Licensed Materials: The source code, object code, data, and/or other works to which this license applies.
+
+**Scope of Use.** You may not use the Licensed Materials except in connection with your use of the Databricks Services pursuant to the Agreement. Your use of the Licensed Materials must comply at all times with any restrictions applicable to the Databricks Services, generally, and must be used in accordance with any applicable documentation. You may view, use, copy, modify, publish, and/or distribute the Licensed Materials solely for the purposes of using the Licensed Materials within or connecting to the Databricks Services. If you do not agree to these terms, you may not view, use, copy, modify, publish, and/or distribute the Licensed Materials.
+
+**Redistribution.** You may redistribute and sublicense the Licensed Materials so long as all use is in compliance with these terms. In addition:
+
+- You must give any other recipients a copy of this License;
+- You must cause any modified files to carry prominent notices stating that you changed the files;
+- You must retain, in any derivative works that you distribute, all copyright, patent, trademark, and attribution notices, excluding those notices that do not pertain to any part of the derivative works; and
+- If a "NOTICE" text file is provided as part of its distribution, then any derivative works that you distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the derivative works.
+
+
+You may add your own copyright statement to your modifications and may provide additional license terms and conditions for use, reproduction, or distribution of your modifications, or for any such derivative works as a whole, provided your use, reproduction, and distribution of the Licensed Materials otherwise complies with the conditions stated in this License.
+
+**Termination.** This license terminates automatically upon your breach of these terms or upon the termination of your Agreement. Additionally, Databricks may terminate this license at any time on notice. Upon termination, you must permanently delete the Licensed Materials and all copies thereof.
+
+**DISCLAIMER; LIMITATION OF LIABILITY.**
+
+THE LICENSED MATERIALS ARE PROVIDED “AS-IS” AND WITH ALL FAULTS. DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY DISCLAIMS ALL WARRANTIES RELATING TO THE LICENSED MATERIALS, EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE LICENSED MATERIALS SHALL BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS OR THE USE OR OTHER DEALINGS IN THE LICENSED MATERIALS.
diff --git a/NOTICE.md b/NOTICE.md
new file mode 100644
index 0000000..defc5ec
--- /dev/null
+++ b/NOTICE.md
@@ -0,0 +1,4 @@
+Copyright (2024) Databricks, Inc.
+
+This Software includes software developed at Databricks (https://www.databricks.com/) and its use is subject to the included LICENSE file.
+By using this repository and the notebooks within, you consent to Databricks' collection and use of usage and tracking information in accordance with our privacy policy at www.databricks.com/privacypolicy.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3434164
--- /dev/null
+++ b/README.md
@@ -0,0 +1,26 @@
+
+
+[![DBR](https://img.shields.io/badge/DBR-CHANGE_ME-red?logo=databricks&style=for-the-badge)](https://docs.databricks.com/release-notes/runtime/CHANGE_ME.html)
+[![CLOUD](https://img.shields.io/badge/CLOUD-CHANGE_ME-blue?logo=googlecloud&style=for-the-badge)](https://databricks.com/try-databricks)
+
+## Business Problem
+WHAT IS THE BUSINESS PROBLEM ADDRESSED BY THIS SOLUTION
+
+## Reference Architecture
+IMAGE TO REFERENCE ARCHITECTURE
+
+## Authors
+
+
+## Project support
+
+Please note the code in this project is provided for your exploration only and is not formally supported by Databricks with Service Level Agreements (SLAs). It is provided AS-IS, and we do not make any guarantees of any kind. Please do not submit a support ticket relating to any issues arising from the use of this project. The source in this project is provided subject to the Databricks [License](./LICENSE.md). All included or referenced third party libraries are subject to the licenses set forth below.
+
+Any issues discovered through the use of this project should be filed as GitHub Issues on the Repo. They will be reviewed as time permits, but there are no formal SLAs for support.
+
+## License
+
+© 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License [https://databricks.com/db-license-source]. All included or referenced third party libraries are subject to the licenses set forth below.
+
+| library | description | license | source |
+|----------------------------------------|-------------------------|------------|-----------------------------------------------------|
diff --git a/RUNME.md b/RUNME.md
new file mode 100644
index 0000000..b9c4f3f
--- /dev/null
+++ b/RUNME.md
@@ -0,0 +1,3 @@
+# Run me...
+
+This guide walks users through implementation details. It may contain gotchas, code snippets, and/or general guidance expressed as markdown.
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000..b3483c5
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,6 @@
+# Security Policy
+
+## Reporting a Vulnerability
+
+Please email bugbounty@databricks.com to report any security vulnerabilities. We will acknowledge receipt of your vulnerability and strive to send you regular updates about our progress. If you're curious about the status of your disclosure please feel free to email us again. If you want to encrypt your disclosure email, you can use [this PGP key](https://keybase.io/arikfr/key.asc).
+
diff --git a/cache.py b/cache.py
new file mode 100644
index 0000000..843ace3
--- /dev/null
+++ b/cache.py
@@ -0,0 +1,215 @@
+import json
+import utils
+import mlflow
+import logging
+from uuid import uuid4
+from datetime import datetime
+from databricks.vector_search.client import VectorSearchClient
+
+class Cache:
+ def __init__(self, vsc, config):
+ mlflow.set_tracking_uri("databricks")
+ self.vsc = vsc
+ self.config = config
+
+ def create_cache(self):
+ # Create or wait for the endpoint
+ utils.create_or_wait_for_endpoint(self.vsc, self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE)
+ logging.info(f"Vector search endpoint '{self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE}' is ready")
+
+ # Create or update the main index
+ utils.create_or_update_direct_index(
+ self.vsc,
+ self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE,
+ self.config.VS_INDEX_FULLNAME_CACHE,
+ self.config.VECTOR_SEARCH_INDEX_SCHEMA_CACHE,
+ self.config.VECTOR_SEARCH_INDEX_CONFIG_CACHE,
+ )
+ logging.info(f"Main index '{self.config.VS_INDEX_FULLNAME_CACHE}' created/updated and is ready")
+ logging.info("Environment setup completed successfully")
+
+ @staticmethod
+ def load_data(file_path):
+ data = []
+ with open(file_path, 'r') as file:
+ for line in file:
+ data.append(json.loads(line))
+ return data
+
+ def get_embedding(self, text):
+ from mlflow.deployments import get_deploy_client
+ client = get_deploy_client("databricks")
+ response = client.predict(
+ endpoint=self.config.EMBEDDING_MODEL_SERVING_ENDPOINT_NAME,
+ inputs={"input": [text]})
+ return response.data[0]['embedding']
+
+ def warm_cache(self, batch_size=100):
+ vs_index_cache = self.vsc.get_index(
+ index_name=self.config.VS_INDEX_FULLNAME_CACHE,
+ endpoint_name=self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE,
+ )
+ # Load dataset
+ data = Cache.load_data(self.config.CACHE_WARMING_FILE_PATH)
+ logging.info(f"Loaded {len(data)} documents from {self.config.CACHE_WARMING_FILE_PATH}")
+ documents = []
+ for idx, item in enumerate(data):
+ if 'question' in item and 'answer' in item:
+ embedding = self.get_embedding(item['question'])
+ doc = {
+ "id": str(idx),
+ "creator": "system",
+ "question": item["question"],
+ "answer": item["answer"],
+ "access_level": 0,
+ "created_at": datetime.now().isoformat(),
+ "text_vector": embedding
+ }
+ documents.append(doc)
+
+ # Upsert when batch size is reached
+ if len(documents) >= batch_size:
+ try:
+ vs_index_cache.upsert(documents)
+ print(f"Successfully upserted batch of {len(documents)} documents.")
+ except Exception as e:
+ print(f"Error upserting batch: {str(e)}")
+ documents = [] # Clear the batch
+
+ # Upsert any remaining documents
+ if documents:
+ try:
+ vs_index_cache.upsert(documents)
+ print(f"Successfully upserted final batch of {len(documents)} documents.")
+ except Exception as e:
+ print(f"Error upserting final batch: {str(e)}")
+
+ logging.info("Index details:")
+ logging.info(f" Type: {type(vs_index_cache)}")
+ logging.info(f" Name: {vs_index_cache.name}")
+ logging.info(f" Endpoint name: {vs_index_cache.endpoint_name}")
+ logging.info(f"Finished loading documents into the index.")
+ logging.info("Cache warming completed successfully")
+
+ # Get response from cache
+ def get_from_cache(self, question, creator="user", access_level=0):
+ vs_index_cache = self.vsc.get_index(
+ index_name=self.config.VS_INDEX_FULLNAME_CACHE,
+ endpoint_name=self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE,
+ )
+ # Check if the question exists in the cache
+ qa = {"question": question, "answer": ""}
+ results = vs_index_cache.similarity_search(
+ query_vector=self.get_embedding(question),
+ columns=["id", "question", "answer"],
+ num_results=1
+ )
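+        # Each row of 'data_array' holds the requested columns (id, question, answer)
+        # followed by the similarity score, hence the positional indices used below.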
+ if results and results['result']['row_count'] > 0:
+ score = results['result']['data_array'][0][3] # Get the score
+ logging.info(f"Score: {score}")
+ try:
+ if float(score) >= self.config.SIMILARITY_THRESHOLD:
+ # Cache hit
+ qa["answer"] = results['result']['data_array'][0][2]
+ record_id = results['result']['data_array'][0][0] # Assuming 'id' is the first column
+ logging.info("Cache hit: True")
+ else:
+ logging.info("Cache hit: False")
+ except ValueError:
+                logging.warning(f"Invalid score value: {score}")
+ return qa
+
+ # Store response to the cache
+ def store_in_cache(self, question, answer, creator="user", access_level=0):
+ vs_index_cache = self.vsc.get_index(
+ index_name=self.config.VS_INDEX_FULLNAME_CACHE,
+ endpoint_name=self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE,
+ )
+ document = {
+ "id": str(uuid4()),
+ "creator": creator,
+ "question": question,
+ "answer": answer,
+ "access_level": access_level,
+ "created_at": datetime.now().isoformat(),
+ "text_vector": self.get_embedding(question),
+ }
+ vs_index_cache.upsert([document])
+
+ def evict(self, strategy='FIFO', max_documents=1000, batch_size=100):
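+        '''Evict documents until at most `max_documents` remain, using the given strategy ('FIFO' or 'LRU').'''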
+ total_docs = self.get_indexed_row_count()
+
+ if total_docs <= max_documents:
+ logging.info(f"Cache size ({total_docs}) is within limit ({max_documents}). No eviction needed.")
+ return
+
+ docs_to_remove = total_docs - max_documents
+ logging.info(f"Evicting {docs_to_remove} documents from cache using {strategy} strategy...")
+
+ index = self.vsc.get_index(
+ index_name=self.config.VS_INDEX_FULLNAME_CACHE,
+ endpoint_name=self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE
+ )
+
+ if strategy == 'FIFO':
+ self._evict_fifo(index, docs_to_remove, batch_size)
+ elif strategy == 'LRU':
+ self._evict_lru(index, docs_to_remove, batch_size)
+ else:
+ raise ValueError(f"Unknown eviction strategy: {strategy}")
+
+ logging.info("Cache eviction completed.")
+
+ def _evict_fifo(self, index, docs_to_remove, batch_size):
+ while docs_to_remove > 0:
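+            # Fetch a batch of candidate rows. The zero query vector simply pulls back a batch;
+            # Vector Search does not order results by 'created_at' here, so this approximates
+            # FIFO rather than strictly enforcing insertion order.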
+ results = index.similarity_search(
+ query_vector=[0] * self.config.EMBEDDING_DIMENSION,
+ columns=["id", "created_at"],
+ num_results=min(docs_to_remove, batch_size),
+ )
+
+ if not results or results['result']['row_count'] == 0:
+ break
+
+ ids_to_remove = [row[0] for row in results['result']['data_array']]
+ index.delete(ids_to_remove)
+
+ docs_to_remove -= len(ids_to_remove)
+ logging.info(f"Removed {len(ids_to_remove)} documents from cache (FIFO).")
+
+ def _evict_lru(self, index, docs_to_remove, batch_size):
+ while docs_to_remove > 0:
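+            # NOTE: this mirrors the FIFO path but requests 'last_accessed', which is not part of
+            # VECTOR_SEARCH_INDEX_SCHEMA_CACHE in config.py; the cache schema (and the read path)
+            # would need to maintain last-access timestamps for true LRU eviction.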
+ results = index.similarity_search(
+ query_vector=[0] * self.config.EMBEDDING_DIMENSION,
+ columns=["id", "last_accessed"],
+ num_results=min(docs_to_remove, batch_size),
+ )
+
+ if not results or results['result']['row_count'] == 0:
+ break
+
+ ids_to_remove = [row[0] for row in results['result']['data_array']]
+ index.delete(ids_to_remove)
+
+ docs_to_remove -= len(ids_to_remove)
+ logging.info(f"Removed {len(ids_to_remove)} documents from cache (LRU).")
+
+ def get_indexed_row_count(self):
+ index = self.vsc.get_index(
+ index_name=self.config.VS_INDEX_FULLNAME_CACHE,
+ endpoint_name=self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE,
+ )
+ description = index.describe()
+ return description.get('status', {}).get('indexed_row_count', 0)
+
+ def clear_cache(self):
+ logging.info(f"Cleaning cache on endpoint {self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE}...")
+ if utils.index_exists(self.vsc, self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE, self.config.VS_INDEX_FULLNAME_CACHE):
+ try:
+ self.vsc.delete_index(self.config.VECTOR_SEARCH_ENDPOINT_NAME_CACHE, self.config.VS_INDEX_FULLNAME_CACHE)
+ logging.info(f"Cache index {self.config.VS_INDEX_FULLNAME_CACHE} deleted successfully")
+ except Exception as e:
+ logging.error(f"Error deleting cache index {self.config.VS_INDEX_FULLNAME_CACHE}: {str(e)}")
+ else:
+ logging.info(f"Cache index {self.config.VS_INDEX_FULLNAME_CACHE} does not exist")
+
diff --git a/chain/chain.py b/chain/chain.py
new file mode 100644
index 0000000..115fd0b
--- /dev/null
+++ b/chain/chain.py
@@ -0,0 +1,75 @@
+from databricks.vector_search.client import VectorSearchClient
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatDatabricks
+from langchain_community.vectorstores import DatabricksVectorSearch
+from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
+from operator import itemgetter
+from config import Config
+import mlflow
+import os
+
+## Enable MLflow Tracing
+mlflow.langchain.autolog()
+
+# load parameters
+config = Config()
+
+# Connect to the Vector Search Index
+vs_index = VectorSearchClient(
+ workspace_url=os.environ['HOST'],
+ personal_access_token=os.environ['TOKEN'],
+ disable_notice=True,
+ ).get_index(
+ endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
+ index_name=config.VS_INDEX_FULLNAME,
+)
+
+# Turn the Vector Search index into a LangChain retriever
+vector_search_as_retriever = DatabricksVectorSearch(
+ vs_index,
+ text_column="content",
+ columns=["id", "content", "url"],
+).as_retriever(search_kwargs={"k": 3}) # Number of search results that the retriever returns
+# Enable the RAG Studio Review App and MLflow to properly track and display retrieved chunks for evaluation
+mlflow.models.set_retriever_schema(primary_key="id", text_column="content", doc_uri="url")
+
+# Method to format the docs returned by the retriever into the prompt (keep only the text from chunks)
+def format_context(docs):
+ chunk_contents = [f"Passage: {d.page_content}\n" for d in docs]
+ return "".join(chunk_contents)
+
+# Prompt template to be used to prompt the LLM
+prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", f"{config.LLM_PROMPT_TEMPLATE}"),
+ ("user", "{question}"),
+ ]
+)
+
+# Our foundation model answering the final prompt
+model = ChatDatabricks(
+ endpoint=config.LLM_MODEL_SERVING_ENDPOINT_NAME,
+ extra_params={"temperature": 0.01, "max_tokens": 500}
+)
+
+# Return the string contents of the most recent message from the user to be used as the input question
+def extract_user_query_string(chat_messages_array):
+ return chat_messages_array[-1]["content"]
+
+# RAG Chain
+chain = (
+ {
+ "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
+ "context": itemgetter("messages")
+ | RunnableLambda(extract_user_query_string)
+ | vector_search_as_retriever
+ | RunnableLambda(format_context),
+ }
+ | prompt
+ | model
+ | StrOutputParser()
+)
+
+# Tell MLflow logging where to find your chain.
+mlflow.models.set_model(model=chain)
diff --git a/chain/chain_cache.py b/chain/chain_cache.py
new file mode 100644
index 0000000..7a33bb9
--- /dev/null
+++ b/chain/chain_cache.py
@@ -0,0 +1,119 @@
+from databricks.vector_search.client import VectorSearchClient
+from langchain_community.vectorstores import DatabricksVectorSearch
+from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatDatabricks
+from operator import itemgetter
+from datetime import datetime
+from uuid import uuid4
+import os
+import mlflow
+from cache import Cache
+from config import Config
+
+
+# Set up logging
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
+logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)
+
+## Enable MLflow Tracing
+mlflow.langchain.autolog()
+
+# Get configuration
+config = Config()
+
+# Connect to Vector Search
+vsc = VectorSearchClient(
+ workspace_url=os.environ['HOST'],
+ personal_access_token=os.environ['TOKEN'],
+ disable_notice=True,
+)
+
+# Get the Vector Search index
+vs_index = vsc.get_index(
+ index_name=config.VS_INDEX_FULLNAME,
+ endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
+ )
+
+# Instantiate a Cache object
+semantic_cache = Cache(vsc, config)
+
+# Turn the Vector Search index into a LangChain retriever
+vector_search_as_retriever = DatabricksVectorSearch(
+ vs_index,
+ text_column="content",
+ columns=["id", "content", "url"],
+).as_retriever(search_kwargs={"k": 3}) # Number of search results that the retriever returns
+
+# Method to retrieve the context from the Vector Search index
+def retrieve_context(qa):
+ return vector_search_as_retriever.invoke(qa["question"])
+
+# Enable the RAG Studio Review App and MLflow to properly track and display retrieved chunks for evaluation
+mlflow.models.set_retriever_schema(primary_key="id", text_column="content", doc_uri="url")
+
+# Method to format the docs returned by the retriever into the prompt (keep only the text from chunks)
+def format_context(docs):
+ chunk_contents = [f"Passage: {d.page_content}\n" for d in docs]
+ return "".join(chunk_contents)
+
+# Create a prompt template for response generation
+prompt = ChatPromptTemplate.from_messages(
+ [
+ ("system", f"{config.LLM_PROMPT_TEMPLATE}"),
+ ("user", "{question}"),
+ ]
+)
+
+# Define our foundation model answering the final prompt
+model = ChatDatabricks(
+ endpoint=config.LLM_MODEL_SERVING_ENDPOINT_NAME,
+ extra_params={"temperature": 0.01, "max_tokens": 500}
+)
+
+# Call the foundation model and persist the new question/answer pair in the semantic cache
+def call_model(prompt):
+ response = model.invoke(prompt)
+ semantic_cache.store_in_cache(
+ question = prompt.dict()['messages'][1]['content'],
+ answer = response.content
+ )
+ return response
+
+# Return the string contents of the most recent message from the user to be used as the input question
+def extract_user_query_string(chat_messages_array):
+ return chat_messages_array[-1]["content"]
+
+# Route based on the cache lookup: an empty answer means a cache miss, so run the RAG chain;
+# otherwise return the cached answer directly
+def router(qa):
+    if qa["answer"] == "":
+        return rag_chain
+    else:
+        return qa["answer"]
+
+# RAG chain
+rag_chain = (
+ {
+ "question": lambda x: x["question"],
+ "context": RunnablePassthrough()
+ | RunnableLambda(retrieve_context)
+ | RunnableLambda(format_context),
+ }
+ | prompt
+ | RunnableLambda(call_model)
+)
+
+# Full chain with cache: extract the user question, look it up in the cache, then route to either the cached answer or the RAG chain
+full_chain = (
+ itemgetter("messages")
+ | RunnableLambda(extract_user_query_string)
+ | RunnableLambda(semantic_cache.get_from_cache)
+ | RunnableLambda(router)
+ | StrOutputParser()
+)
+
+# Tell MLflow logging where to find your chain.
+mlflow.models.set_model(model=full_chain)
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..e2596ed
--- /dev/null
+++ b/config.py
@@ -0,0 +1,50 @@
+class Config:
+ def __init__(self):
+
+ self.LLM_MODEL_SERVING_ENDPOINT_NAME = "databricks-dbrx-instruct"
+ self.EMBEDDING_MODEL_SERVING_ENDPOINT_NAME = "databricks-gte-large-en"
+
+ # For vector search index
+ self.CATALOG = "semantic_cache_solacc"
+ self.SCHEMA = "chatbot_rag"
+ self.SOURCE_TABLE_FULLNAME = f"{self.CATALOG}.{self.SCHEMA}.databricks_documentation"
+ self.EVALUATION_TABLE_FULLNAME = f"{self.CATALOG}.{self.SCHEMA}.eval_databricks_documentation"
+ self.VECTOR_SEARCH_ENDPOINT_NAME = "one-env-shared-endpoint-13"
+ self.VS_INDEX_FULLNAME = f"{self.CATALOG}.{self.SCHEMA}.databricks_documentation_vs_index"
+ self.MODEL_FULLNAME = f"{self.CATALOG}.{self.SCHEMA}.standard_rag_chatbot"
+ self.ENDPOINT_NAME = "standard_rag_chatbot"
+ self.LOGGING_SCHEMA = f"logging"
+
+ # For semantic cache
+ self.CATALOG_CACHE = "semantic_cache_solacc"
+ self.SCHEMA_CACHE = "chatbot_cache"
+ self.VECTOR_SEARCH_ENDPOINT_NAME_CACHE = "one-env-shared-endpoint-13"
+ self.VS_INDEX_FULLNAME_CACHE = f"{self.CATALOG}.{self.SCHEMA}.cache_vs_index"
+ self.VS_METRICS_INDEX_FULLNAME_CACHE = f"{self.CATALOG}.{self.SCHEMA}.metrics"
+ self.MODEL_FULLNAME_CACHE = f"{self.CATALOG}.{self.SCHEMA}.rag_chatbot_with_cache"
+ self.ENDPOINT_NAME_CACHE = "rag_chatbot_with_cache"
+ self.LOGGING_SCHEMA_CACHE = f"logging"
+ self.CACHE_WARMING_FILE_PATH="data/synthetic_qa.txt"
+ self.INDEX_NAME = "cache_index"
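+        # Minimum similarity score for a cached answer to be treated as a hit (see Cache.get_from_cache)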
+ self.SIMILARITY_THRESHOLD = 0.01
+
+ self.VECTOR_SEARCH_INDEX_SCHEMA_CACHE = {
+ "id": "string",
+ "creator": "string",
+ "question": "string",
+ "answer": "string",
+ "access_level": "int",
+ "created_at": "timestamp",
+ "text_vector": "array"
+ }
+
+ self.EMBEDDING_DIMENSION = 1024
+ self.VECTOR_SEARCH_INDEX_CONFIG_CACHE = {
+ "primary_key": "id",
+ "embedding_dimension": self.EMBEDDING_DIMENSION,
+ "embedding_vector_column": "text_vector"
+ }
+
+ self.LLM_PROMPT_TEMPLATE = """You are an assistant that answers questions. Use the following pieces of retrieved context to answer the question. Some pieces of context may be irrelevant, in which case you should not use them to form the answer.\n\nContext: {context}"""
+
+ self.INPUT_EXAMPLE = {"messages": [{"content": "What is Databricks Model Serving?", "role": "user"}]}
diff --git a/data/synthetic_qa.txt b/data/synthetic_qa.txt
new file mode 100644
index 0000000..122dc76
--- /dev/null
+++ b/data/synthetic_qa.txt
@@ -0,0 +1,10 @@
+{"base": 1, "answer": "Databricks' machine learning platform supports automated hyperparameter tuning for model optimization through the use of Hyperopt and Ray. Hyperopt is a Python library that facilitates distributed hyperparameter tuning and model selection, allowing you to scan a set of Python models while varying algorithms and hyperparameters across spaces that you define. Ray is an open-source framework that specializes in parallel compute processing for scaling ML workflows and AI applications, which can be used to distribute the hyperparameter tuning process.\n\nDatabricks Runtime ML includes both Hyperopt and Ray, and provides examples and documentation on how to use them for hyperparameter tuning. For instance, you can define an objective function to minimize, specify the hyperparameter search space, choose a search algorithm, and then run the Hyperopt function fmin() to identify the set of hyperparameters that minimizes the objective function.\n\nAdditionally, Databricks Runtime 5.4 ML introduced a new implementation of Hyperopt powered by Apache Spark to scale and simplify hyperparameter tuning. A new Trials class, SparkTrials, was implemented to distribute Hyperopt trial runs among multiple machines and nodes using Apache Spark. This feature also includes automated MLflow tracking, which logs all tuning experiments, along with the tuned hyperparameters and targeted metrics, to MLflow runs.\n\nDatabricks also provides tutorials and notebooks that demonstrate how to use Hyperopt for automated hyperparameter tuning, along with other tools such as MLflow tracking and autologging for model development, and Model Registry for model management.", "question": "How does Databricks' machine learning platform support automated hyperparameter tuning for model optimization?"}
+{"base": 2, "answer": "Yes, Databricks' machine learning platform can handle large-scale deep learning workloads. Here's how:\n\n1. **Databricks Runtime for Machine Learning (Databricks Runtime ML)**: Databricks Runtime ML is a runtime environment that is optimized for machine learning and deep learning workloads. It includes popular deep learning libraries like TensorFlow, PyTorch, and Keras, and supporting libraries such as Petastorm, Hyperopt, and Horovod. Databricks Runtime ML clusters also include pre-configured GPU support with drivers and supporting libraries.\n\n2. **Horovod and Hyperopt**: Databricks provides Horovod and Hyperopt to parallelize training and optimize hyperparameters, respectively. Horovod is an open-source project that scales deep learning training to multi-GPU or distributed computation, while Hyperopt provides adaptive hyperparameter tuning for machine learning.\n\n3. **Pandas UDFs for inference**: Pandas UDFs (User Defined Functions) can be used for inference, allowing you to perform inference on large datasets in a distributed manner.\n\n4. **A100 GPUs**: Databricks supports A100 GPUs on all clouds, which are an efficient choice for many deep learning tasks, such as training and tuning large language models, natural language processing, object detection and classification, and recommendation engines.\n\n5. **Delta Lake and Petastorm**: Databricks Runtime ML includes Delta Lake and Petastorm to optimize data throughput for deep learning applications. Delta Lake simplifies ETL and lets you access data efficiently, while Petastorm provides APIs that let you prepare data in parquet format for use by TensorFlow, Keras, or PyTorch.\n\n6. **Databricks Model Serving**: Databricks Model Serving enables creation of scalable GPU endpoints for deep learning models with no extra configuration.\n\n7. **Databricks Workflows**: Databricks Workflows can be used for automated workflows and production-ready ETL pipelines, allowing you to automate the training and deployment of your deep learning models.\n\n8. **Unity Catalog**: Unity Catalog provides governance, discovery, versioning, and access control for data, features, models, and functions, allowing you to manage your deep learning models and data in a centralized manner.\n\n9. **Lakehouse Monitoring**: Lakehouse Monitoring can be used to track model prediction quality and drift, allowing you to monitor the performance of your deep learning models in production.\n\n10. **Databricks Git folders**: D", "question": "Can Databricks' machine learning platform handle large-scale deep learning workloads, and if so, how?"}
+{"base": 3, "answer": "Databricks provides several feature engineering and data preprocessing capabilities as part of its machine learning platform. Here are some of the key features:\n\n1. **Databricks Feature Store**: This tool allows you to create new features, explore and re-use existing features, select features for training and scoring machine learning models, and publish features to low-latency online stores for real-time inference.\n\n2. **Spark SQL and MLlib**: On large datasets, you can use Spark SQL and MLlib for feature engineering. These tools provide a wide range of functions for data manipulation and transformation.\n\n3. **Third-party libraries**: Databricks Runtime ML includes third-party libraries such as scikit-learn, which provide useful helper methods for feature engineering.\n\n4. **Transfer learning**: For more complex deep learning feature processing, Databricks provides an example notebook that illustrates how to use transfer learning for featurization.\n\n5. **Data preprocessing**: Databricks supports data preprocessing using tools like Autoloader and Apache Spark for ingesting and transforming data, and Delta tables for tracking changes to data including versioning and lineage.\n\n6. **Exploratory data analysis and dashboards**: Databricks provides tools like Databricks SQL, Dashboards, and Databricks notebooks for exploratory data analysis and creating dashboards.\n\nThese capabilities are designed to help you preprocess and engineer features for your machine learning models in a scalable and efficient manner.", "question": "What kind of feature engineering and data preprocessing capabilities are available in Databricks' machine learning platform?"}
+{"base": 4, "answer": "Databricks' machine learning platform supports model interpretability and explainability techniques, such as SHAP and LIME, through various built-in tools and integrations. Here's how:\n\n1. **MLflow**: Databricks integrates with MLflow, an open-source platform for managing the end-to-end machine learning lifecycle. MLflow allows you to track, version, and share machine learning models and experiments. It also supports model interpretability techniques, such as SHAP and LIME, by enabling you to log and visualize feature importances and model explanations alongside your model's performance metrics.\n\n2. **Databricks Runtime for Machine Learning**: Databricks Runtime for Machine Learning includes popular machine learning libraries, such as scikit-learn, which has built-in support for SHAP and LIME. This allows you to use these interpretability techniques directly in your Databricks notebooks and workflows.\n\n3. **Integration with Open Source Tools**: Databricks has a strong commitment to the open-source community and supports various open-source libraries and tools that enable model interpretability and explainability, such as SHAP and LIME. You can install and use these tools directly in your Databricks notebooks and workflows.\n\n4. **Notebook-based Workflow**: Databricks provides a collaborative notebook-based environment that allows you to document and share your machine learning workflows, including model interpretability and explainability analyses. This helps ensure that your model's behavior and decision-making processes are transparent and understandable to stakeholders.\n\nIn summary, Databricks supports model interpretability and explainability techniques through its integration with MLflow, the inclusion of popular machine learning libraries in Databricks Runtime for Machine Learning, support for open-source tools, and a collaborative notebook-based workflow.", "question": "How does Databricks' machine learning platform support model interpretability and explainability techniques, such as SHAP and LIME?"}
+{"base": 5, "answer": "Yes, Databricks' machine learning platform can be used for natural language processing (NLP) tasks. Databricks supports popular open source libraries such as Spark ML and spark-nlp for performing NLP tasks. With Spark ML, you can create input features from text for model training algorithms directly in your Spark ML pipelines. It supports a range of text processors, including tokenization, stop-word processing, word2vec, and feature hashing. Spark NLP, on the other hand, allows you to scale out many deep learning methods for natural language processing on Spark. It supports standard NLP operations such as tokenizing, named entity recognition, and vectorization using the included annotators. You can also summarize, perform named entity recognition, translate, and generate text using many pre-trained deep learning models based on Spark NLP\u2019s transformers such as BERT and T5 Marion. Additionally, Databricks Runtime for Machine Learning includes libraries like Hugging Face Transformers and LangChain that allow you to integrate existing pre-trained models or other open-source libraries into your workflow.", "question": "Can Databricks' machine learning platform be used for natural language processing (NLP) tasks, and if so, what libraries and tools are supported?"}
+{"base": 6, "answer": "Databricks' machine learning platform supports real-time model serving and The service automatically scales up or down to meet demand changes, saving infrastructure costs while optimizing latency performance. This functionality uses serverless compute, which means that the endpoints and associated compute resources are managed and run in your Databricks account. Serverless Real-Time Inference offers high availability and scalability, with the ability to support up to 3000 queries-per-second (QPS). It also provides dashboards to monitor the health of your model endpoints using metrics such as QPS, latency, and error rate. Additionally, it offers feature store integration, allowing models to be packaged with feature metadata and incorporated in real-time as scoring requests are received. However, during the public preview, there are certain limitations such as a payload size limit of 16 MB per request, a default limit of 200 QPS of scoring requests per workspace enrolled, and best effort support on less than 100 millisecond latency overhead and availability. It is also important to note that Serverless Real-Time Inference endpoints are open to the internet for inbound traffic unless an IP allowlist is enabled in the workspace.", "question": "How does Databricks' machine learning platform support real-time model serving and inference for low-latency applications?"}
+{"base": 7, "answer": "Databricks' machine learning platform offers robust model monitoring and logging capabilities through Lakehouse Monitoring. This feature allows you to monitor statistical properties, such as data drift and model performance, of input data and model predictions. Here's what you can do with Lakehouse Monitoring:\n\n1. **Data Ingestion**: The pipeline reads in logs from batch, streaming, or online inference, allowing you to monitor various types of data inputs.\n\n2. **Check Accuracy and Data Drift**: Compute metrics about the input data, the model\u2019s predictions, and the infrastructure performance. Data scientists can specify data and model metrics during development, and ML engineers can specify infrastructure metrics. You can also define custom metrics with Lakehouse Monitoring.\n\n3. **Publish Metrics and Set Up Alerts**: The pipeline writes to tables in the production catalog for analysis and reporting. You can configure these tables to be readable from the development environment so data scientists have access for analysis. You can use Databricks SQL to create monitoring dashboards to track model performance, and set up the monitoring job or the dashboard tool to issue a notification when a metric exceeds a specified threshold.\n\n4. **Trigger Model Retraining**: When monitoring metrics indicate performance issues or changes in the input data, the data scientist may need to develop a new model version. You can set up SQL alerts to notify data scientists when this happens.\n\n5. **Retraining**: Databricks supports both scheduled and triggered retraining. Scheduled retraining can be set up if new data is available on a regular basis. Triggered retraining can be initiated when the monitoring pipeline identifies model performance issues and sends alerts.\n\n6. **Dashboard Creation**: You can use Databricks SQL to create monitoring dashboards to track model performance and set up alerts when a metric exceeds a specified threshold.\n\nThese capabilities help you ensure the quality and consistency of your data and models over time, and quickly identify and address any changes or issues.", "question": "What kind of model monitoring and logging capabilities are available in Databricks' machine learning platform for tracking model performance and data drift?"}
+{"base": 8, "answer": "Yes, Databricks' machine learning platform can be used for computer vision tasks, such as image classification and object detection. Databricks provides a reference solution for distributed image model inference, which can be found in the \"Image processing and computer vision\" article. This reference solution is based on a common setup shared by many real-world image applications.\n\nIn terms of libraries and tools, Databricks Runtime for Machine Learning includes popular deep learning libraries like TensorFlow, PyTorch, and Keras, which can be used for computer vision tasks. Additionally, Databricks provides pre-configured GPU support with drivers and supporting libraries, which can be beneficial for computer vision applications that require significant computational resources.\n\nFor more information on using Databricks for computer vision tasks, you can refer to the \"Image processing and computer vision\" article and the \"Reference solution for image applications\" provided in the context.", "question": "Can Databricks' machine learning platform be used for computer vision tasks, such as image classification and object detection, and if so, what libraries and tools are supported?"}
+{"base": 9, "answer": "Databricks' machine learning platform supports ensemble methods and stacking for combining multiple models through various built-in tools and features. Although the provided context does not explicitly mention ensemble methods or stacking, it does discuss the tools and capabilities that enable these techniques. Here's how Databricks supports ensemble methods and stacking:\n\n1. **Databricks AutoML**: Databricks AutoML can automatically train and tune multiple models, making it easier to create an ensemble of models. You can then use techniques like stacking or voting to combine the predictions of these models.\n\n2. **MLflow**: MLflow is a platform for managing the end-to-end machine learning lifecycle. It includes tools for tracking, packaging, and deploying models. With MLflow, you can easily manage and deploy multiple models, which is essential for ensemble methods and stacking.\n\n3. **Databricks Feature Store**: The Databricks Feature Store allows you to store, manage, and discover features for model training. By having a centralized feature store, you can ensure that all your models use the same features, which is important for ensemble methods and stacking.\n\n4. **Lakehouse Monitoring**: Lakehouse Monitoring helps you track model prediction quality and drift. This is crucial for ensemble methods and stacking, as it allows you to monitor the performance of each individual model and the ensemble as a whole.\n\n5. **Databricks Workflows**: Databricks Workflows enable you to create automated workflows and production-ready ETL pipelines. You can use workflows to automate the process of training, evaluating, and combining multiple models using ensemble methods or stacking.\n\n6. **Databricks Git folders**: Databricks Git folders provide code management and Git integration. This enables you to manage the code for each individual model and the ensemble or stacking code in a version-controlled environment.\n\n7. **Support for popular machine learning libraries**: Databricks Runtime for Machine Learning includes popular machine learning libraries like TensorFlow, PyTorch, and Keras. These libraries can be used to implement various ensemble methods and stacking techniques.\n\nIn summary, Databricks' machine learning platform provides a comprehensive set of tools and features that support ensemble methods and stacking for combining multiple models. These tools include Databricks AutoML, MLflow, Databricks Feature Store, Lakehouse Monitoring, Databricks Workflows, Databricks Git folders, and support for popular machine learning libraries.", "question": "How does Databricks' machine learning platform support ensemble methods and stacking for combining multiple models?"}
+{"base": 10, "answer": "Yes, Databricks' machine learning platform can be used for time series forecasting and anomaly detection. For time series forecasting, Databricks AutoML supports the ARIMA model in addition to Prophet. You can set up a forecasting problem using the AutoML UI and specify the time column and forecast horizon. For anomaly detection, Databricks does not explicitly mention any specific algorithms or techniques in the provided context, but it does mention that you can use existing feature tables from Databricks Feature Store to expand the input training dataset for your classification and regression problems, which could potentially be used for anomaly detection. Additionally, Databricks supports the use of popular machine learning libraries such as TensorFlow, PyTorch, and XGBoost, which could be used for anomaly detection.", "question": "Can Databricks' machine learning platform be used for time series forecasting and anomaly detection, and if so, what algorithms and techniques are supported?"}
diff --git a/data/synthetic_questions_100.csv b/data/synthetic_questions_100.csv
new file mode 100644
index 0000000..0cc8f35
--- /dev/null
+++ b/data/synthetic_questions_100.csv
@@ -0,0 +1,101 @@
+question,base
+What automated hyperparameter tuning methods are available on Databricks' machine learning platform?,1
+How does Databricks' platform optimize hyperparameters for machine learning models?,1
+Can you describe the hyperparameter tuning process on Databricks' machine learning platform?,1
+What tools and libraries does Databricks' platform provide for automated hyperparameter tuning?,1
+How does Databricks' platform support Bayesian optimization for hyperparameter tuning?,1
+What is the approach used by Databricks' platform for grid search and random search hyperparameter tuning?,1
+Can you outline the automated hyperparameter tuning workflow on Databricks' machine learning platform?,1
+How does Databricks' platform integrate with other libraries and frameworks for hyperparameter tuning?,1
+What are the benefits of using Databricks' platform for automated hyperparameter tuning?,1
+Can you provide an example of how to use Databricks' platform for automated hyperparameter tuning in a machine learning project?,1
+What capabilities does Databricks' machine learning platform offer for large-scale deep learning?,2
+How does Databricks' platform support the training and deployment of deep learning models at scale?,2
+Can Databricks' platform handle the computational demands of large-scale deep learning workloads?,2
+What distributed computing capabilities does Databricks' platform offer for deep learning?,2
+How does Databricks' platform optimize deep learning workloads for performance and scalability?,2
+What is the architecture of Databricks' platform for supporting large-scale deep learning?,2
+"Can you describe the deep learning workflow on Databricks' platform, from data preparation to model deployment?",2
+"How does Databricks' platform integrate with popular deep learning frameworks, such as TensorFlow and PyTorch?",2
+"What are the benefits of using Databricks' platform for large-scale deep learning, compared to other solutions?",2
+Can you provide an example of a large-scale deep learning project that was successfully deployed on Databricks' platform?,2
+What feature engineering techniques are supported by Databricks' machine learning platform?,3
+How does Databricks' platform handle data preprocessing and feature extraction?,3
+What data transformation and normalization capabilities are available in Databricks' machine learning platform?,3
+Can you describe the data preprocessing workflow on Databricks' platform?,3
+What feature selection and dimensionality reduction techniques are supported by Databricks' platform?,3
+How does Databricks' platform handle missing data and outliers in feature engineering?,3
+What data encoding and scaling capabilities are available in Databricks' machine learning platform?,3
+Can you outline the feature engineering and data preprocessing tools available in Databricks' platform?,3
+How does Databricks' platform support the creation of custom feature engineering and data preprocessing pipelines?,3
+"What are the benefits of using Databricks' platform for feature engineering and data preprocessing, compared to other solutions?",3
+What model interpretability and explainability techniques are supported by Databricks' machine learning platform?,4
+How does Databricks' platform provide insights into model behavior and decision-making processes?,4
+Can you describe the SHAP and LIME implementation on Databricks' machine learning platform?,4
+What visualization tools are available on Databricks' platform for model interpretability and explainability?,4
+How does Databricks' platform support the use of feature importance and partial dependence plots for model interpretability?,4
+Can you outline the model explainability workflow on Databricks' machine learning platform?,4
+"What are the benefits of using Databricks' platform for model interpretability and explainability, compared to other solutions?",4
+"How does Databricks' platform integrate with popular model interpretability and explainability libraries, such as SHAP and LIME?",4
+Can you provide an example of how to use Databricks' platform for model interpretability and explainability in a real-world scenario?,4
+"What are the limitations and challenges of using Databricks' platform for model interpretability and explainability, and how can they be addressed?",4
+What NLP capabilities are available on Databricks' machine learning platform?,5
+Can Databricks' platform be used for text analysis and processing tasks?,5
+What NLP libraries and frameworks are supported by Databricks' machine learning platform?,5
+How does Databricks' platform support the use of word embeddings and language models for NLP tasks?,5
+Can you describe the NLP workflow on Databricks' machine learning platform?,5
+"What are the benefits of using Databricks' platform for NLP tasks, compared to other solutions?",5
+"How does Databricks' platform integrate with popular NLP libraries, such as NLTK and spaCy?",5
+"Can you provide an example of how to use Databricks' platform for NLP tasks, such as text classification and sentiment analysis?",5
+"What are the limitations and challenges of using Databricks' platform for NLP tasks, and how can they be addressed?",5
+"How does Databricks' platform support the use of deep learning models for NLP tasks, such as recurrent neural networks (RNNs) and transformers?",5
+What real-time model serving capabilities are available on Databricks' machine learning platform?,6
+How does Databricks' platform support low-latency model inference for real-time applications?,6
+Can you describe the architecture of Databricks' platform for real-time model serving and inference?,6
+"What are the benefits of using Databricks' platform for real-time model serving and inference, compared to other solutions?",6
+How does Databricks' platform support the use of GPU acceleration for real-time model inference?,6
+Can you outline the workflow for deploying and managing real-time models on Databricks' platform?,6
+What are the latency and throughput characteristics of Databricks' platform for real-time model serving and inference?,6
+"How does Databricks' platform integrate with popular model serving frameworks, such as TensorFlow Serving and AWS SageMaker?",6
+Can you provide an example of how to use Databricks' platform for real-time model serving and inference in a production environment?,6
+What are the scalability and reliability features of Databricks' platform for real-time model serving and inference?,6
+What model monitoring and logging features are available in Databricks' machine learning platform?,7
+How does Databricks' platform support the tracking of model performance and data drift?,7
+Can you describe the model monitoring workflow on Databricks' machine learning platform?,7
+What kind of metrics and logs are available for model monitoring and logging on Databricks' platform?,7
+How does Databricks' platform support the use of dashboards and visualizations for model monitoring and logging?,7
+Can you outline the steps for setting up model monitoring and logging on Databricks' platform?,7
+"What are the benefits of using Databricks' platform for model monitoring and logging, compared to other solutions?",7
+"How does Databricks' platform integrate with popular model monitoring and logging tools, such as Prometheus and Grafana?",7
+Can you provide an example of how to use Databricks' platform for model monitoring and logging in a production environment?,7
+What are the scalability and reliability features of Databricks' platform for model monitoring and logging?,7
+What computer vision capabilities are available on Databricks' machine learning platform?,8
+Can Databricks' platform be used for image classification and object detection tasks?,8
+What libraries and frameworks are supported by Databricks' platform for computer vision tasks?,8
+How does Databricks' platform support the use of deep learning models for computer vision tasks?,8
+Can you describe the computer vision workflow on Databricks' machine learning platform?,8
+"What are the benefits of using Databricks' platform for computer vision tasks, compared to other solutions?",8
+"How does Databricks' platform integrate with popular computer vision libraries, such as OpenCV and Pillow?",8
+"Can you provide an example of how to use Databricks' platform for computer vision tasks, such as image classification and object detection?",8
+What are the scalability and reliability features of Databricks' platform for computer vision tasks?,8
+How does Databricks' platform support the use of transfer learning and pre-trained models for computer vision tasks?,8
+What ensemble methods are supported by Databricks' machine learning platform?,9
+Can Databricks' platform be used to combine multiple models using stacking?,9
+How does Databricks' platform support the use of bagging and boosting for ensemble methods?,9
+What are the benefits of using ensemble methods and stacking on Databricks' machine learning platform?,9
+Can you describe the workflow for creating and deploying ensemble models on Databricks' platform?,9
+How does Databricks' platform support the use of hyperparameter tuning for ensemble methods?,9
+What are the scalability and reliability features of Databricks' platform for ensemble methods and stacking?,9
+Can you provide an example of how to use Databricks' platform for ensemble methods and stacking in a real-world scenario?,9
+"How does Databricks' platform integrate with popular ensemble libraries, such as scikit-learn and XGBoost?",9
+"What are the limitations and challenges of using ensemble methods and stacking on Databricks' machine learning platform, and how can they be addressed?",9
+What time series forecasting capabilities are available on Databricks' machine learning platform?,10
+Can Databricks' platform be used for anomaly detection in time series data?,10
+What algorithms and techniques are supported by Databricks' platform for time series forecasting and anomaly detection?,10
+"How does Databricks' platform support the use of ARIMA, SARIMA, and other traditional time series forecasting methods?",10
+Can you describe the workflow for creating and deploying time series forecasting models on Databricks' platform?,10
+"What are the benefits of using Databricks' platform for time series forecasting and anomaly detection, compared to other solutions?",10
+"How does Databricks' platform support the use of deep learning models, such as LSTM and GRU, for time series forecasting?",10
+Can you provide an example of how to use Databricks' platform for time series forecasting and anomaly detection in a real-world scenario?,10
+What are the scalability and reliability features of Databricks' platform for time series forecasting and anomaly detection?,10
+"How does Databricks' platform integrate with popular time series libraries, such as pandas and statsmodels?",10
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..66b4995
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+mlflow>=2.16.2
+cloudpickle==2.2.1
+configparser==5.2.0
+cryptography==41.0.3
+databricks-agents==0.5.0
+databricks-vectorsearch==0.40
+google-cloud-storage==2.10.0
+grpcio-status==1.60.0
+langchain-community==0.2.4
+langchain==0.2.1
+numpy==1.23.5
+pandas==1.5.3
+pyarrow==14.0.1
+pydantic==1.10.6
+pyspark==3.5.0
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..3aa349b
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,220 @@
+from pyspark.sql.functions import pandas_udf
+import pandas as pd
+import pyspark.sql.functions as F
+from pyspark.sql.functions import col, udf, length, pandas_udf
+import os
+import mlflow
+import yaml
+import time
+from typing import Iterator
+from mlflow import MlflowClient
+mlflow.set_registry_uri('databricks-uc')
+
+
+########################################################################
+###### Functions for setting up vector search index for RAG and cache
+########################################################################
+def vs_endpoint_exists(vsc, vs_endpoint_name):
+ '''Check if a vector search endpoint exists'''
+ try:
+ return vs_endpoint_name in [e['name'] for e in vsc.list_endpoints().get('endpoints', [])]
+ except Exception as e:
+ #Temp fix for potential REQUEST_LIMIT_EXCEEDED issue
+ if "REQUEST_LIMIT_EXCEEDED" in str(e):
+ print("WARN: couldn't get endpoint status due to REQUEST_LIMIT_EXCEEDED error. The demo will consider it exists")
+ return True
+ else:
+ raise e
+
+
+def create_or_wait_for_endpoint(vsc, vs_endpoint_name):
+ '''Create a vector search endpoint if it doesn't exist. If it does exist, wait for it to be ready'''
+ if not vs_endpoint_exists(vsc, vs_endpoint_name):
+ vsc.create_endpoint(name=vs_endpoint_name, endpoint_type="STANDARD")
+ wait_for_vs_endpoint_to_be_ready(vsc, vs_endpoint_name)
+
+
+def wait_for_vs_endpoint_to_be_ready(vsc, vs_endpoint_name):
+ '''Wait for a vector search endpoint to be ready'''
+ for i in range(180):
+ try:
+ endpoint = vsc.get_endpoint(vs_endpoint_name)
+ except Exception as e:
+ #Temp fix for potential REQUEST_LIMIT_EXCEEDED issue
+ if "REQUEST_LIMIT_EXCEEDED" in str(e):
+ print("WARN: couldn't get endpoint status due to REQUEST_LIMIT_EXCEEDED error. Please manually check your endpoint status")
+ return
+ else:
+ raise e
+ status = endpoint.get("endpoint_status", endpoint.get("status"))["state"].upper()
+ if "ONLINE" in status:
+ return endpoint
+        elif "PROVISIONING" in status or i < 6:
+ if i % 20 == 0:
+ print(f"Waiting for endpoint to be ready, this can take a few min... {endpoint}")
+ time.sleep(10)
+ else:
+ raise Exception(f'''Error with the endpoint {vs_endpoint_name}. - this shouldn't happen: {endpoint}.\n Please delete it and re-run the previous cell: vsc.delete_endpoint("{vs_endpoint_name}")''')
+ raise Exception(f"Timeout, your endpoint isn't ready yet: {vsc.get_endpoint(vs_endpoint_name)}")
+
+
+def delete_endpoint(vsc, vs_endpoint_name):
+ '''Delete a vector search endpoint'''
+ print(f"Deleting endpoint {vs_endpoint_name}...")
+ try:
+ vsc.delete_endpoint(vs_endpoint_name)
+ print(f"Endpoint {vs_endpoint_name} deleted successfully")
+ except Exception as e:
+ print(f"Error deleting endpoint {vs_endpoint_name}: {str(e)}")
+
+
+def index_exists(vsc, vs_endpoint_name, vs_index_name):
+    '''Check if a vector search index exists'''
+    try:
+        vsc.get_index(vs_endpoint_name, vs_index_name).describe()
+        return True
+    except Exception as e:
+        if 'RESOURCE_DOES_NOT_EXIST' not in str(e):
+            print('Unexpected error describing the index. This could be a permission issue.')
+            raise e
+        return False
+
+
+def wait_for_index_to_be_ready(vsc, vs_endpoint_name, vs_index_fullname):
+ '''Wait for a vector search index to be ready'''
+ for i in range(180):
+ idx = vsc.get_index(vs_endpoint_name, vs_index_fullname).describe()
+ index_status = idx.get('status', idx.get('index_status', {}))
+ status = index_status.get('detailed_state', index_status.get('status', 'UNKNOWN')).upper()
+ url = index_status.get('index_url', index_status.get('url', 'UNKNOWN'))
+ if "ONLINE" in status:
+ return
+ if "UNKNOWN" in status:
+ print(f"Can't get the status - will assume index is ready {idx} - url: {url}")
+ return
+ elif "PROVISIONING" in status:
+ if i % 40 == 0: print(f"Waiting for index to be ready, this can take a few min... {index_status} - pipeline url:{url}")
+ time.sleep(10)
+ else:
+            raise Exception(f'''Error with the index - this shouldn't happen. DLT pipeline might have been killed.\n Please delete it and re-run the previous cell: vsc.delete_index("{vs_endpoint_name}", "{vs_index_fullname}") \nIndex details: {idx}''')
+    raise Exception(f"Timeout, your index isn't ready yet: {vsc.get_index(vs_endpoint_name, vs_index_fullname)}")
+
+
+def create_or_update_direct_index(vsc, vs_endpoint_name, vs_index_fullname, vector_search_index_schema, vector_search_index_config):
+ '''Create a direct access vector search index if it doesn't exist. If it does exist, update it.'''
+ try:
+ vsc.create_direct_access_index(
+ endpoint_name=vs_endpoint_name,
+ index_name=vs_index_fullname,
+ schema=vector_search_index_schema,
+ **vector_search_index_config
+ )
+    except Exception as e:
+        if 'RESOURCE_ALREADY_EXISTS' not in str(e):
+            print('Unexpected error creating the direct access index.')
+            raise e
+ wait_for_index_to_be_ready(vsc, vs_endpoint_name, vs_index_fullname)
+ print(f"index {vs_index_fullname} is ready")
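+
+
+# Illustrative sketch (kept as a comment so nothing runs on import): one way the
+# helpers above can be combined. The endpoint name, index name, schema, and
+# embedding dimension below are placeholders assumed for this example, and the
+# config keys are assumed to match the keyword arguments that
+# create_direct_access_index expects.
+#
+#   from databricks.vector_search.client import VectorSearchClient
+#
+#   vsc = VectorSearchClient()
+#   create_or_wait_for_endpoint(vsc, "semantic_cache_vs_endpoint")
+#   create_or_update_direct_index(
+#       vsc,
+#       "semantic_cache_vs_endpoint",
+#       "main.semantic_cache.cache_index",
+#       vector_search_index_schema={
+#           "id": "string", "question": "string",
+#           "answer": "string", "embedding": "array<float>",
+#       },
+#       vector_search_index_config={
+#           "primary_key": "id",
+#           "embedding_dimension": 1024,
+#           "embedding_vector_column": "embedding",
+#       },
+#   )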
+
+
+#######################################################################
+###### Functions for deploying a chain in Model Serving
+#######################################################################
+def get_latest_model_version(model_name):
+ '''Get the latest model version for a given model name'''
+ mlflow_client = MlflowClient(registry_uri="databricks-uc")
+ latest_version = 1
+ for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
+ version_int = int(mv.version)
+ if version_int > latest_version:
+ latest_version = version_int
+ return latest_version
+
+
+def deploy_model_serving_endpoint(
+ spark,
+ model_full_name,
+ catalog,
+ logging_schema,
+ endpoint_name,
+ host,
+ token,
+ ):
+ '''Deploy a model serving endpoint'''
+ from mlflow.deployments import get_deploy_client
+ client = get_deploy_client("databricks")
+ _config = {
+ "served_models": [{
+ "model_name": model_full_name,
+ "model_version": get_latest_model_version(model_full_name),
+ "workload_type": "CPU",
+ "workload_size": "Small",
+ "scale_to_zero_enabled": "true",
+ "environment_vars": {
+ "HOST": host,
+ "TOKEN": token,
+ "ENABLE_MLFLOW_TRACING": "true",
+ }
+ }],
+ "auto_capture_config": {
+ "catalog_name": catalog,
+ "schema_name": logging_schema,
+ "table_name_prefix": endpoint_name,
+ }
+ }
+    try:
+        r = client.get_endpoint(endpoint_name)
+        endpoint = client.update_endpoint(
+            endpoint=endpoint_name,
+            config=_config,
+        )
+    except Exception:
+        # Make sure the schema for the inference table exists
+        _ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{logging_schema}")
+        # Make sure to drop the inference table if it exists
+        _ = spark.sql(f"DROP TABLE IF EXISTS {catalog}.{logging_schema}.`{endpoint_name}_payload`")
+
+ endpoint = client.create_endpoint(
+ name = endpoint_name,
+ config = _config,
+ )
+
+
+def wait_for_model_serving_endpoint_to_be_ready(endpoint_name):
+ '''Wait for a model serving endpoint to be ready'''
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.service.serving import EndpointStateReady, EndpointStateConfigUpdate
+ import time
+
+ # Wait for it to be ready
+ w = WorkspaceClient()
+ state = ""
+ for i in range(200):
+ state = w.serving_endpoints.get(endpoint_name).state
+ if state.config_update == EndpointStateConfigUpdate.IN_PROGRESS:
+ if i % 40 == 0:
+ print(f"Waiting for endpoint to deploy {endpoint_name}. Current state: {state}")
+ time.sleep(10)
+ elif state.ready == EndpointStateReady.READY:
+ print('endpoint ready.')
+ return
+ else:
+ break
+ raise Exception(f"Couldn't start the endpoint, timeout, please check your endpoint for more details: {state}")
+
+
+def send_request_to_endpoint(endpoint_name, data):
+ '''Send a request to a model serving endpoint'''
+ from mlflow.deployments import get_deploy_client
+ client = get_deploy_client("databricks")
+ response = client.predict(endpoint=endpoint_name, inputs=data)
+ return response
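+
+# Illustrative sketch (comment only): a typical deploy-and-query flow using the
+# helpers above. The model name, catalog, schema, endpoint name, and request
+# payload are placeholders assumed for this example; the actual payload shape
+# depends on the signature of the logged chain.
+#
+#   deploy_model_serving_endpoint(
+#       spark, "main.semantic_cache.rag_chain", "main", "serving_logs",
+#       "semantic_cache_chat", host, token,
+#   )
+#   wait_for_model_serving_endpoint_to_be_ready("semantic_cache_chat")
+#   send_request_to_endpoint(
+#       "semantic_cache_chat",
+#       {"messages": [{"role": "user", "content": "What is Databricks Vector Search?"}]},
+#   )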
+
+
+def delete_model_serving_endpoint(endpoint_name):
+ '''Delete a model serving endpoint'''
+ from mlflow.deployments import get_deploy_client
+ client = get_deploy_client("databricks")
+ r = client.delete_endpoint(endpoint_name)
+
+
\ No newline at end of file