diff --git a/docs/docs/integrations/text_embedding/databricks.ipynb b/docs/docs/integrations/text_embedding/databricks.ipynb index b5438845823c8..f447d6c71345e 100644 --- a/docs/docs/integrations/text_embedding/databricks.ipynb +++ b/docs/docs/integrations/text_embedding/databricks.ipynb @@ -1,22 +1,34 @@ { "cells": [ { - "attachments": {}, + "cell_type": "raw", + "id": "afaf8039", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: Databricks\n", + "---" + ] + }, + { "cell_type": "markdown", + "id": "9a3d6f34", "metadata": {}, "source": [ - "# Databricks\n", + "# DatabricksEmbeddings\n", "\n", "> [Databricks](https://www.databricks.com/) Lakehouse Platform unifies data, analytics, and AI on one platform.\n", "\n", - "This notebook provides a quick overview for getting started with Databricks [embedding models](/docs/concepts/#embedding-models). For detailed documentation of all DatabricksEmbeddings features and configurations head to the [API reference](https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.databricks.DatabricksEmbeddings.html).\n", + "This notebook provides a quick overview for getting started with Databricks [embedding models](/docs/concepts/#embedding-models). For detailed documentation of all `DatabricksEmbeddings` features and configurations head to the [API reference](https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.databricks.DatabricksEmbeddings.html).\n", "\n", "\n", "\n", "## Overview\n", + "### Integration details\n", "\n", - "`DatabricksEmbeddings` class wraps an embedding model endpoint hosted on [Databricks Model Serving](https://docs.databricks.com/en/machine-learning/model-serving/index.html). This example notebook shows how to wrap your serving endpoint and use it as a embedding model in your LangChain application.\n", - "\n", + "| Class | Package |\n", + "| :--- | :--- |\n", + "| [DatabricksEmbeddings](https://api.python.langchain.com/en/latest/embeddings/langchain_databricks.embeddings.DatabricksEmbeddings.html) | [langchain-databricks](https://api.python.langchain.com/en/latest/databricks_api_reference.html) |\n", "\n", "### Supported Methods\n", "\n", @@ -30,13 +42,9 @@ "1. Foundation Models - Curated list of state-of-the-art foundation models such as BAAI General Embedding (BGE). These endpoint are ready to use in your Databricks workspace without any set up.\n", "2. Custom Models - You can also deploy custom embedding models to a serving endpoint via MLflow with\n", "your choice of framework such as LangChain, Pytorch, Transformers, etc.\n", - "3. External Models - Databricks endpoints can serve models that are hosted outside Databricks as a proxy, such as proprietary model service like OpenAI text-embedding-3.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ + "3. 
External Models - Databricks endpoints can serve models that are hosted outside Databricks as a proxy, such as a proprietary model service like OpenAI's text-embedding-3.\n", + "\n", + "\n", + "## Setup\n", + "\n", + "To access Databricks models you'll need to create a Databricks account, set up credentials (only if you are outside the Databricks workspace), and install required packages.\n", @@ -51,6 +59,7 @@ { "cell_type": "code", "execution_count": null, + "id": "36521c2a", "metadata": {}, "outputs": [], "source": [ @@ -63,33 +72,27 @@ }, { "cell_type": "markdown", + "id": "d9664366", "metadata": {}, "source": [ "### Installation\n", "\n", - "The LangChain Databricks integration lives in the `langchain-community` package. Also, `mlflow >= 2.9 ` is required to run the code in this notebook." + "The LangChain Databricks integration lives in the `langchain-databricks` package:" ] }, { "cell_type": "code", "execution_count": null, + "id": "64853226", "metadata": {}, "outputs": [], "source": [ - "%pip install -qU langchain-community mlflow>=2.9.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We first demonstrates how to query BGE model hosted as Foundation Models endpoint with `DatabricksEmbeddings`.\n", - "\n", - "For other type of endpoints, there are some difference in how to set up the endpoint itself, however, once the endpoint is ready, there is no difference in how to query it." + "%pip install -qU langchain-databricks" ] }, { "cell_type": "markdown", + "id": "45dd1724", "metadata": {}, "source": [ "## Instantiation" ] }, { "cell_type": "code", "execution_count": null, + "id": "9ea7a09b", "metadata": {}, "outputs": [], "source": [ - "from langchain_community.embeddings import DatabricksEmbeddings\n", + "from langchain_databricks import DatabricksEmbeddings\n", "\n", "embeddings = DatabricksEmbeddings(\n", " endpoint=\"databricks-bge-large-en\",\n", @@ -113,65 +117,131 @@ }, { "cell_type": "markdown", + "id": "77d271b6", + "metadata": {}, + "source": [ + "## Indexing and Retrieval\n", + "\n", + "Embedding models are often used in retrieval-augmented generation (RAG) flows, both as part of indexing data and later retrieving it. For more detailed instructions, please see our RAG tutorials under the [working with external knowledge tutorials](/docs/tutorials/#working-with-external-knowledge).\n", + "\n", + "Below, see how to index and retrieve data using the `embeddings` object we initialized above. In this example, we will index and retrieve a sample document in the `InMemoryVectorStore`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d817716b", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a vector store with a sample text\n", + "from langchain_core.vectorstores import InMemoryVectorStore\n", + "\n", + "text = \"LangChain is the framework for building context-aware reasoning applications\"\n", + "\n", + "vectorstore = InMemoryVectorStore.from_texts(\n", + " [text],\n", + " embedding=embeddings,\n", + ")\n", + "\n", + "# Use the vectorstore as a retriever\n", + "retriever = vectorstore.as_retriever()\n", + "\n", + "# Retrieve the most similar text\n", + "retrieved_document = retriever.invoke(\"What is LangChain?\")\n", + "\n", + "# show the retrieved document's content\n", + "retrieved_document[0].page_content" + ] + }, + { + "cell_type": "markdown", + "id": "e02b9855", "metadata": {}, "source": [ - "## Embed single text" + "## Direct Usage\n", + "\n", + "Under the hood, the vectorstore and retriever implementations are calling `embeddings.embed_documents(...)` and `embeddings.embed_query(...)` to create embeddings for the text(s) used in `from_texts` and retrieval `invoke` operations, respectively.\n", + "\n", + "You can directly call these methods to get embeddings for your own use cases.\n", + "\n", + "### Embed single texts\n", + "\n", + "You can embed single texts or documents with `embed_query`:" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "id": "0d2befcd", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.051055908203125, 0.007221221923828125, 0.003879547119140625]\n" - ] - } - ], + "outputs": [], "source": [ - "embeddings.embed_query(\"hello\")[:3]" + "single_vector = embeddings.embed_query(text)\n", + "print(str(single_vector)[:100]) # Show the first 100 characters of the vector" ] }, { "cell_type": "markdown", + "id": "1b5a7d03", "metadata": {}, "source": [ - "## Embed documents" + "### Embed multiple texts\n", + "\n", + "You can embed multiple texts with `embed_documents`:" ] }, { "cell_type": "code", "execution_count": null, + "id": "2f4d6e97", "metadata": {}, "outputs": [], "source": [ - "documents = [\"This is a dummy document.\", \"This is another dummy document.\"]\n", - "response = embeddings.embed_documents(documents)\n", - "print([e[:3] for e in response]) # Show first 3 elements of each embedding" + "text2 = (\n", + " \"LangGraph is a library for building stateful, multi-actor applications with LLMs\"\n", + ")\n", + "two_vectors = embeddings.embed_documents([text, text2])\n", + "for vector in two_vectors:\n", + " print(str(vector)[:100]) # Show the first 100 characters of the vector" ] }, { "cell_type": "markdown", + "id": "98785c12", "metadata": {}, "source": [ - "## Wrapping Other Types of Endpoints\n", + "### Async Usage\n", + "\n", + "You can also use `aembed_query` and `aembed_documents` for producing embeddings asynchronously:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c3bef91", + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "\n", + "\n", + "async def async_example():\n", + " single_vector = await embeddings.aembed_query(text)\n", + " print(str(single_vector)[:100]) # Show the first 100 characters of the vector\n", "\n", - "The example above uses an embedding model hosted as a Foundation Models API. To learn about how to use the other endpoint types, please refer to the documentation for `ChatDatabricks`. 
While the model type is different, required steps are the same.\n", "\n", - "* [Custom Model Endpoint](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#wrapping-custom-model-endpoint)\n", - "* [External Models](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#wrapping-external-models)" + "asyncio.run(async_example())" ] }, { "cell_type": "markdown", + "id": "0d053b64", "metadata": {}, "source": [ - "## API reference\n", + "## API Reference\n", "\n", - "For detailed documentation of all ChatDatabricks features and configurations head to the API reference: https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.databricks.DatabricksEmbeddings.html" + "For detailed documentation on `DatabricksEmbeddings` features and configuration options, please refer to the [API reference](https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.databricks.DatabricksEmbeddings.html).\n" ] } ], @@ -191,9 +261,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.5" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/docs/docs/integrations/vectorstores/databricks_vector_search.ipynb b/docs/docs/integrations/vectorstores/databricks_vector_search.ipynb index dad7b5d0c7449..626e34568f296 100644 --- a/docs/docs/integrations/vectorstores/databricks_vector_search.ipynb +++ b/docs/docs/integrations/vectorstores/databricks_vector_search.ipynb @@ -1,139 +1,185 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "raw", + "id": "1957f5cb", "metadata": {}, "source": [ - "# Databricks Vector Search\n", - "\n", - "Databricks Vector Search is a serverless similarity search engine that allows you to store a vector representation of your data, including metadata, in a vector database. With Vector Search, you can create auto-updating vector search indexes from Delta tables managed by Unity Catalog and query them with a simple API to return the most similar vectors.\n", - "\n", - "This notebook shows how to use LangChain with Databricks Vector Search." + "---\n", + "sidebar_label: Databricks\n", + "---" ] }, { "cell_type": "markdown", + "id": "ef1f0986", "metadata": {}, "source": [ - "Install `databricks-vectorsearch` and related Python packages used in this notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet langchain-core databricks-vectorsearch langchain-openai tiktoken" + "# DatabricksVectorSearch\n", + "\n", + "[Databricks Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html) is a serverless similarity search engine that allows you to store a vector representation of your data, including metadata, in a vector database. With Vector Search, you can create auto-updating vector search indexes from Delta tables managed by Unity Catalog and query them with a simple API to return the most similar vectors.\n", + "\n", + "This notebook shows how to use LangChain with Databricks Vector Search." ] }, { "cell_type": "markdown", + "id": "36fdc060", "metadata": {}, "source": [ - "Use `OpenAIEmbeddings` for the embeddings." 
+ "## Setup\n", + "\n", + "To access Databricks models you'll need to create a Databricks account, set up credentials (only if you are outside Databricks workspace), and install required packages.\n", + "\n", + "### Credentials (only if you are outside Databricks)\n", + "\n", + "If you are running LangChain app inside Databricks, you can skip this step.\n", + "\n", + "Otherwise, you need manually set the Databricks workspace hostname and personal access token to `DATABRICKS_HOST` and `DATABRICKS_TOKEN` environment variables, respectively. See [Authentication Documentation](https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-personal-access-tokens) for how to get an access token." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, + "id": "5fb2788f", "metadata": {}, "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + "os.environ[\"DATABRICKS_HOST\"] = \"https://your-databricks-workspace\"\n", + "os.environ[\"DATABRICKS_TOKEN\"] = getpass.getpass(\"Enter your Databricks access token: \")" ] }, { "cell_type": "markdown", + "id": "93df377e", "metadata": {}, "source": [ - "Split documents and get embeddings." + "### Installation\n", + "\n", + "The LangChain Databricks integration lives in the `langchain-databricks` package." ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "b03d22f1", + "metadata": { + "vscode": { + "languageId": "shellscript" + } + }, "outputs": [], "source": [ - "from langchain_community.document_loaders import TextLoader\n", - "from langchain_openai import OpenAIEmbeddings\n", - "from langchain_text_splitters import CharacterTextSplitter\n", + "%pip install -qU langchain-databricks" + ] + }, + { + "cell_type": "markdown", + "id": "08c6ef75", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "### Create a Vector Search Endpoint and Index (if you haven't already)\n", "\n", - "loader = TextLoader(\"../../how_to/state_of_the_union.txt\")\n", - "documents = loader.load()\n", - "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", - "docs = text_splitter.split_documents(documents)\n", + "In this section, we will create a Databricks Vector Search endpoint and an index using the client SDK.\n", "\n", - "embeddings = OpenAIEmbeddings()\n", - "emb_dim = len(embeddings.embed_query(\"hello\"))" + "If you already have an endpoint and an index, you can skip the section and go straight to \"Instantiation\" section." ] }, { "cell_type": "markdown", - "metadata": {}, + "id": "db62918b", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, "source": [ - "## Setup Databricks Vector Search client" + "First, instantiate the Databricks VectorSearch client:" ] }, { "cell_type": "code", "execution_count": null, + "id": "c0f2957b", "metadata": {}, "outputs": [], "source": [ "from databricks.vector_search.client import VectorSearchClient\n", "\n", - "vsc = VectorSearchClient()" + "client = VectorSearchClient()" ] }, { "cell_type": "markdown", + "id": "31311046", "metadata": {}, "source": [ - "## Create a Vector Search Endpoint\n", - "This endpoint is used to create and access vector search indexes." + "Next, we will create a new VectorSearch endpoint." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, + "id": "be8f7d3a", "metadata": {}, "outputs": [], "source": [ - "vsc.create_endpoint(name=\"vector_search_demo_endpoint\", endpoint_type=\"STANDARD\")" + "endpoint_name = \"\"\n", + "\n", + "client.create_endpoint(name=endpoint_name, endpoint_type=\"STANDARD\")" ] }, { "cell_type": "markdown", + "id": "63498435", + "metadata": {}, + "source": [ + "Lastly, we will create an index that can be queried on the endpoint. There are two types of indexes in Databricks Vector Search, and the `DatabricksVectorSearch` class supports both use cases.\n", + "\n", + "* **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes.\n", + "\n", + "* **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.\n", + "\n", + "Also, for a delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes)." ] }, { "cell_type": "markdown", + "id": "863d7218", "metadata": {}, "source": [ - "## Create Direct Vector Access Index\n", - "Direct Vector Access Index supports direct read and write of embedding vectors and metadata through a REST API or an SDK. For this index, you manage embedding vectors and index updates yourself." + "The following code creates a **direct-access** index. Please refer to the [Databricks documentation](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html) for instructions on how to create the other types of indexes." ] }, { "cell_type": "code", "execution_count": null, + "id": "474aea5c", "metadata": {}, "outputs": [], "source": [ - "vector_search_endpoint_name = \"vector_search_demo_endpoint\"\n", - "index_name = \"vector_search_demo.vector_search.state_of_the_union_index\"\n", + "index_name = \"\" # Format: \"<catalog>.<schema>.<index-name>\"\n", "\n", - "index = vsc.create_direct_access_index(\n", - " endpoint_name=vector_search_endpoint_name,\n", + "index = client.create_direct_access_index(\n", + " endpoint_name=endpoint_name,\n", " index_name=index_name,\n", " primary_key=\"id\",\n", - " embedding_dimension=emb_dim,\n", + " # Dimension of the embeddings. Please change according to the embedding model you are using.\n", + " embedding_dimension=3072,\n", + " # A column to store the embedding vectors for the text data\n", + " embedding_vector_column=\"text_vector\",\n", " schema={\n", " \"id\": \"string\",\n", " \"text\": \"string\",\n", " \"text_vector\": \"array<float>\",\n", + " # Optional metadata columns\n", " \"source\": \"string\",\n", " },\n", ")\n", @@ -141,90 +187,333 @@ "index.describe()" ] }, + { + "cell_type": "markdown", + "id": "979bea9b", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "## Instantiation\n", + "\n", + "The instantiation of `DatabricksVectorSearch` is a bit different depending on whether your index uses Databricks-managed embeddings or self-managed embeddings, i.e., a LangChain `Embeddings` object of your choice."
+ ] + }, + { + "cell_type": "markdown", + "id": "d34c1b01", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "If you are using a delta-sync index with Databricks-managed embeddings:" + ] + }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "id": "dc37144c-208d-4ab3-9f3a-0407a69fe052", + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "from langchain_community.vectorstores import DatabricksVectorSearch\n", + "from langchain_databricks.vectorstores import DatabricksVectorSearch\n", "\n", - "dvs = DatabricksVectorSearch(\n", - " index, text_column=\"text\", embedding=embeddings, columns=[\"source\"]\n", + "vector_store = DatabricksVectorSearch(\n", + " endpoint=endpoint_name,\n", + " index_name=index_name,\n", ")" ] }, { "cell_type": "markdown", + "id": "f48e4e85", "metadata": {}, "source": [ - "## Add docs to the index" + "If you are using a direct-access index or a delta-sync index with self-managed embeddings,\n", + "you also need to provide the embedding model and text column in your source table to\n", + "use for the embeddings:\n", + "\n", + "```{=mdx}\n", + "import EmbeddingTabs from \"@theme/EmbeddingTabs\";\n", + "\n", + "<EmbeddingTabs/>\n", + "\n", + "```" ] }, { "cell_type": "code", "execution_count": null, + "id": "ec6288a7", "metadata": {}, "outputs": [], "source": [ - "dvs.add_documents(docs)" + "# | output: false\n", + "# | echo: false\n", + "from langchain_openai import OpenAIEmbeddings\n", + "\n", + "embeddings = OpenAIEmbeddings(model=\"text-embedding-3-large\")" ] }, { "cell_type": "code", "execution_count": null, + "id": "0b1bdbdf", + "metadata": {}, + "outputs": [], + "source": [ + "vector_store = DatabricksVectorSearch(\n", + " endpoint=endpoint_name,\n", + " index_name=index_name,\n", + " embedding=embeddings,\n", + " # The column name in the index that contains the text data to be embedded\n", + " text_column=\"document_content\",\n", + ")" ] }, { "cell_type": "markdown", + "id": "ac6071d4", "metadata": {}, "source": [ - "## Similarity search\n", - "Optional keyword arguments to similarity_search include specifying k number of documents to retrive, \n", - "a filters dictionary for metadata filtering based on [this syntax](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#use-filters-on-queries),\n", - "as well as the [query_type](https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search) which can be ANN or HYBRID " + "## Manage vector store\n", + "\n", + "### Add items to vector store\n", + "\n", + "Note: Adding items to the vector store via the `add_documents` method is only supported for a **direct-access** index."
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, + "id": "17f5efc0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['1', '2', '3']" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.documents import Document\n", + "\n", + "document_1 = Document(page_content=\"foo\", metadata={\"source\": \"https://example.com\"})\n", + "\n", + "document_2 = Document(page_content=\"bar\", metadata={\"source\": \"https://example.com\"})\n", + "\n", + "document_3 = Document(page_content=\"baz\", metadata={\"source\": \"https://example.com\"})\n", + "\n", + "documents = [document_1, document_2, document_3]\n", + "\n", + "vector_store.add_documents(documents=documents, ids=[\"1\", \"2\", \"3\"])" ] }, { "cell_type": "markdown", + "id": "dcf1b905", "metadata": {}, - "outputs": [], "source": [ - "query = \"What did the president say about Ketanji Brown Jackson\"\n", - "dvs.similarity_search(query)\n", - "print(docs[0].page_content)" + "### Delete items from vector store\n", + "\n", + "Note: Deleting items from the vector store via the `delete` method is only supported for a **direct-access** index." ] }, { "cell_type": "code", + "execution_count": 21, + "id": "ef61e188", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vector_store.delete(ids=[\"3\"])" ] }, { "cell_type": "markdown", + "id": "c3620501", "metadata": {}, "source": [ - "## Work with Delta Sync Index\n", + "## Query vector store\n", + "\n", + "Once your vector store has been created and the relevant documents have been added, you will most likely wish to query it during the running of your chain or agent. \n", "\n", - "You can also use `DatabricksVectorSearch` to search in a Delta Sync Index. Delta Sync Index automatically syncs from a Delta table. You don't need to call `add_text`/`add_documents` manually. See [Databricks documentation page](https://docs.databricks.com/en/generative-ai/vector-search.html#delta-sync-index-with-managed-embeddings) for more details." + "### Query directly\n", + "\n", + "Performing a simple similarity search can be done as follows:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, + "id": "aa0a16fa", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* foo [{'id': '1'}]\n" + ] + } + ], "source": [ - "delta_sync_index = vsc.create_delta_sync_index(\n", - " endpoint_name=vector_search_endpoint_name,\n", - " source_table_name=\"vector_search_demo.vector_search.state_of_the_union\",\n", - " index_name=\"vector_search_demo.vector_search.state_of_the_union_index\",\n", - " pipeline_type=\"TRIGGERED\",\n", - " primary_key=\"id\",\n", - " embedding_source_column=\"text\",\n", - " embedding_model_endpoint_name=\"e5-small-v2\",\n", + "results = vector_store.similarity_search(\n", + " query=\"thud\", k=1, filter={\"source\": \"https://example.com\"}\n", + ")\n", + "for doc in results:\n", + " print(f\"* {doc.page_content} [{doc.metadata}]\")" ] }, { "cell_type": "markdown", + "id": "562056dd", + "metadata": {}, + "source": [ + "Note: By default, similarity search only returns the primary key and text column. 
If you want to retrieve the custom metadata associated with the document, pass the additional columns in the `columns` parameter when initializing the vector store." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "a1c746a2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* foo [{'source': 'https://example.com', 'id': '1'}]\n" + ] + } + ], + "source": [ + "vector_store = DatabricksVectorSearch(\n", + " endpoint=endpoint_name,\n", + " index_name=index_name,\n", + " embedding=embeddings,\n", + " text_column=\"text\",\n", + " columns=[\"source\"],\n", + ")\n", + "\n", + "results = vector_store.similarity_search(query=\"thud\", k=1)\n", + "for doc in results:\n", + " print(f\"* {doc.page_content} [{doc.metadata}]\")" ] }, { "cell_type": "markdown", + "id": "3ed9d733", + "metadata": {}, + "source": [ + "If you want to execute a similarity search and receive the corresponding scores, you can run:" ] }, { "cell_type": "code", + "execution_count": 36, + "id": "5efd2eaa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* [SIM=0.414035] foo [{'source': 'https://example.com', 'id': '1'}]\n" + ] + } + ], + "source": [ + "results = vector_store.similarity_search_with_score(\n", + " query=\"thud\", k=1, filter={\"source\": \"https://example.com\"}\n", ")\n", + "for doc, score in results:\n", + " print(f\"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]\")" ] }, { "cell_type": "markdown", + "id": "0c235cdc", + "metadata": {}, + "source": [ + "### Query by turning into retriever\n", + "\n", + "You can also transform the vector store into a retriever for easier usage in your chains. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "f3460093", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(metadata={'source': 'https://example.com', 'id': '1'}, page_content='foo')]" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever = vector_store.as_retriever(search_type=\"mmr\", search_kwargs={\"k\": 1})\n", + "retriever.invoke(\"thud\")" + ] + }, + { + "cell_type": "markdown", + "id": "901c75dc", + "metadata": {}, + "source": [ + "## Usage for retrieval-augmented generation\n", + "\n", + "For guides on how to use this vector store for retrieval-augmented generation (RAG), see the following sections:\n", + "\n", + "- [Tutorials: working with external knowledge](https://python.langchain.com/v0.2/docs/tutorials/#working-with-external-knowledge)\n", + "- [How-to: Question and answer with RAG](https://python.langchain.com/v0.2/docs/how_to/#qa-with-rag)\n", + "- [Retrieval conceptual docs](https://python.langchain.com/v0.2/docs/concepts/#retrieval)" + ] + }, + { + "cell_type": "markdown", + "id": "8a27244f", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all DatabricksVectorSearch features and configurations head to the API reference: https://api.python.langchain.com/en/latest/vectorstores/langchain_databricks.vectorstores.DatabricksVectorSearch.html" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "langchain-dev", "language": "python", - "name": "python3" + "name": "langchain-dev" }, "language_info": { "codemirror_mode": { @@ -236,9 +525,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/libs/partners/databricks/langchain_databricks/__init__.py b/libs/partners/databricks/langchain_databricks/__init__.py index 3bd93de6c3ae2..7d3130780e014 100644 --- a/libs/partners/databricks/langchain_databricks/__init__.py +++ b/libs/partners/databricks/langchain_databricks/__init__.py @@ -1,6 +1,8 @@ from importlib import metadata from langchain_databricks.chat_models import ChatDatabricks +from langchain_databricks.embeddings import DatabricksEmbeddings +from langchain_databricks.vectorstores import DatabricksVectorSearch try: __version__ = metadata.version(__package__) @@ -11,5 +13,7 @@ __all__ = [ "ChatDatabricks", + "DatabricksEmbeddings", + "DatabricksVectorSearch", "__version__", ] diff --git a/libs/partners/databricks/langchain_databricks/chat_models.py b/libs/partners/databricks/langchain_databricks/chat_models.py index fa24c08415f6c..2528e97668983 100644 --- a/libs/partners/databricks/langchain_databricks/chat_models.py +++ b/libs/partners/databricks/langchain_databricks/chat_models.py @@ -15,7 +15,6 @@ Type, Union, ) -from urllib.parse import urlparse from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import BaseChatModel @@ -50,6 +49,8 @@ from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_databricks.utils import get_deployment_client + logger = logging.getLogger(__name__) @@ -230,25 +231,7 @@ class GetPopulation(BaseModel): def __init__(self, **kwargs: Any): super().__init__(**kwargs) - self._validate_uri() - try: - from mlflow.deployments import get_deploy_client # type: ignore - - 
self._client = get_deploy_client(self.target_uri) - except ImportError as e: - raise ImportError( - "Failed to create the client. Please run `pip install mlflow` to " - "install required dependencies." - ) from e - - def _validate_uri(self) -> None: - if self.target_uri == "databricks": - return - - if urlparse(self.target_uri).scheme != "databricks": - raise ValueError( - "Invalid target URI. The target URI must be a valid databricks URI." - ) + self._client = get_deployment_client(self.target_uri) @property def _default_params(self) -> Dict[str, Any]: diff --git a/libs/partners/databricks/langchain_databricks/embeddings.py b/libs/partners/databricks/langchain_databricks/embeddings.py new file mode 100644 index 0000000000000..52113763e5d3f --- /dev/null +++ b/libs/partners/databricks/langchain_databricks/embeddings.py @@ -0,0 +1,91 @@ +from typing import Any, Dict, Iterator, List + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, PrivateAttr + +from langchain_databricks.utils import get_deployment_client + + +class DatabricksEmbeddings(Embeddings, BaseModel): + """Databricks embedding model integration. + + Setup: + Install ``langchain-databricks``. + + .. code-block:: bash + + pip install -U langchain-databricks + + If you are outside Databricks, set the Databricks workspace + hostname and personal access token to environment variables: + + .. code-block:: bash + + export DATABRICKS_HOST="https://your-databricks-workspace" + export DATABRICKS_TOKEN="your-personal-access-token" + + Key init args — embedding params: + endpoint: str + Name of Databricks Model Serving endpoint to query. + target_uri: str + The target URI to use. Defaults to ``databricks``. + query_params: Dict[str, str] + The parameters to use for queries. + documents_params: Dict[str, str] + The parameters to use for documents. + + Instantiate: + .. code-block:: python + from langchain_databricks import DatabricksEmbeddings + embed = DatabricksEmbeddings( + endpoint="databricks-bge-large-en", + ) + + Embed single text: + .. code-block:: python + input_text = "The meaning of life is 42" + embed.embed_query(input_text) + + .. code-block:: python + [ + 0.01605224609375, + -0.0298309326171875, + ... 
+ ] + + """ + + endpoint: str + """The endpoint to use.""" + target_uri: str = "databricks" + """The target URI to use.""" + query_params: Dict[str, Any] = {} + """The parameters to use for queries.""" + documents_params: Dict[str, Any] = {} + """The parameters to use for documents.""" + _client: Any = PrivateAttr() + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self._client = get_deployment_client(self.target_uri) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return self._embed(texts, params=self.documents_params) + + def embed_query(self, text: str) -> List[float]: + return self._embed([text], params=self.query_params)[0] + + def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]: + embeddings: List[List[float]] = [] + # Query the endpoint in batches of 20 texts at a time. + for txt in _chunk(texts, 20): + resp = self._client.predict( + endpoint=self.endpoint, + inputs={"input": txt, **params}, # type: ignore[arg-type] + ) + embeddings.extend(r["embedding"] for r in resp["data"]) + return embeddings + + +def _chunk(texts: List[str], size: int) -> Iterator[List[str]]: + for i in range(0, len(texts), size): + yield texts[i : i + size] diff --git a/libs/partners/databricks/langchain_databricks/utils.py b/libs/partners/databricks/langchain_databricks/utils.py new file mode 100644 index 0000000000000..33e160a05bedc --- /dev/null +++ b/libs/partners/databricks/langchain_databricks/utils.py @@ -0,0 +1,101 @@ +from typing import Any, List, Union +from urllib.parse import urlparse + +import numpy as np + + +def get_deployment_client(target_uri: str) -> Any: + if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"): + raise ValueError( + "Invalid target URI. The target URI must be a valid databricks URI." + ) + + try: + from mlflow.deployments import get_deploy_client # type: ignore[import-untyped] + + return get_deploy_client(target_uri) + except ImportError as e: + raise ImportError( + "Failed to create the client. " + "Please run `pip install mlflow` to install " + "required dependencies." + ) from e + + +# Utility function for Maximal Marginal Relevance (MMR) reranking. +# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency +Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] + + +def maximal_marginal_relevance( + query_embedding: np.ndarray, + embedding_list: list, + lambda_mult: float = 0.5, + k: int = 4, +) -> List[int]: + """Calculate maximal marginal relevance. + + Args: + query_embedding: Query embedding. + embedding_list: List of embeddings to select from. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of indices of embeddings selected by maximal marginal relevance. 
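+ + Example: + A minimal sketch with small illustrative 2-D vectors (not taken from any real embedding model); with a low ``lambda_mult``, diversity dominates and the candidate orthogonal to the first pick is selected second: + + .. code-block:: python + + import numpy as np + + query = np.array([1.0, 0.0]) + candidates = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]] + # Picks index 0 (most similar), then index 2 (most diverse). + maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2) + # -> [0, 2] 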
+ """ + if min(k, len(embedding_list)) <= 0: + return [] + if query_embedding.ndim == 1: + query_embedding = np.expand_dims(query_embedding, axis=0) + similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): + best_score = -np.inf + idx_to_add = -1 + similarity_to_selected = cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): + if i in idxs: + continue + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) + if equation_score > best_score: + best_score = equation_score + idx_to_add = i + idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + return idxs + + +def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: + """Row-wise cosine similarity between two equal-width matrices. + + Raises: + ValueError: If the number of columns in X and Y are not the same. + """ + if len(X) == 0 or len(Y) == 0: + return np.array([]) + + X = np.array(X) + Y = np.array(Y) + if X.shape[1] != Y.shape[1]: + raise ValueError( + "Number of columns in X and Y must be the same. X has shape" + f"{X.shape} " + f"and Y has shape {Y.shape}." + ) + + X_norm = np.linalg.norm(X, axis=1) + Y_norm = np.linalg.norm(Y, axis=1) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity diff --git a/libs/partners/databricks/langchain_databricks/vectorstores.py b/libs/partners/databricks/langchain_databricks/vectorstores.py new file mode 100644 index 0000000000000..7359dcf9ab50e --- /dev/null +++ b/libs/partners/databricks/langchain_databricks/vectorstores.py @@ -0,0 +1,837 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from enum import Enum +from functools import partial +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, +) + +import numpy as np +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VST, VectorStore + +from langchain_databricks.utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + + +class IndexType(str, Enum): + DIRECT_ACCESS = "DIRECT_ACCESS" + DELTA_SYNC = "DELTA_SYNC" + + +_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index." +_NON_MANAGED_EMB_ONLY_MSG = ( + "`%s` is not supported for index with Databricks-managed embeddings." +) + + +class DatabricksVectorSearch(VectorStore): + """Databricks vector store integration. + + Setup: + Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages. + + .. code-block:: bash + + pip install -U langchain-databricks databricks-vectorsearch + + If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html + + If you are outside Databricks, set the Databricks workspace + hostname and personal access token to environment variables: + + .. 
code-block:: bash + + export DATABRICKS_HOST="https://your-databricks-workspace" + export DATABRICKS_TOKEN="your-personal-access-token" + + Key init args — indexing params: + + endpoint: The name of the Databricks Vector Search endpoint. + index_name: The name of the index to use. Format: "catalog.schema.index". + embedding: The embedding model. + Required for direct-access index or delta-sync index + with self-managed embeddings. + text_column: The name of the text column to use for the embeddings. + Required for direct-access index or delta-sync index + with self-managed embeddings. + Make sure the text column specified is in the index. + columns: The list of column names to get when doing the search. + Defaults to ``[primary_key, text_column]``. + + Instantiate: + + `DatabricksVectorSearch` supports two types of indexes: + + * **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes. + + * **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK. + + Also, for a delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes). + + If you are using a delta-sync index with Databricks-managed embeddings: + + .. code-block:: python + + from langchain_databricks.vectorstores import DatabricksVectorSearch + + vector_store = DatabricksVectorSearch( + endpoint="<endpoint-name>", + index_name="<index-name>" + ) + + If you are using a direct-access index or a delta-sync index with self-managed embeddings, + you also need to provide the embedding model and text column in your source table to + use for the embeddings: + + .. code-block:: python + + from langchain_openai import OpenAIEmbeddings + + vector_store = DatabricksVectorSearch( + endpoint="<endpoint-name>", + index_name="<index-name>", + embedding=OpenAIEmbeddings(), + text_column="document_content" + ) + + Add Documents: + .. code-block:: python + from langchain_core.documents import Document + + document_1 = Document(page_content="foo", metadata={"baz": "bar"}) + document_2 = Document(page_content="thud", metadata={"bar": "baz"}) + document_3 = Document(page_content="i will be deleted :(") + documents = [document_1, document_2, document_3] + ids = ["1", "2", "3"] + vector_store.add_documents(documents=documents, ids=ids) + + Delete Documents: + .. code-block:: python + vector_store.delete(ids=["3"]) + + .. note:: + + The `delete` method is only supported for a direct-access index. + + Search: + .. code-block:: python + results = vector_store.similarity_search(query="thud", k=1) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * thud [{'id': '2'}] + + .. note:: + + By default, similarity search only returns the primary key and text column. + If you want to retrieve the custom metadata associated with the document, + pass the additional columns in the `columns` parameter when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="<endpoint-name>", + index_name="<index-name>", + columns=["baz", "bar"], + ) + + vector_store.similarity_search(query="thud", k=1) + # Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}] + + Search with filter: + .. code-block:: python + results = vector_store.similarity_search(query="thud", k=1, filter={"bar": "baz"}) + for doc in results: + print(f"* {doc.page_content} [{doc.metadata}]") + ..
code-block:: python + * thud [{'id': '2'}] + + Search with score: + .. code-block:: python + results = vector_store.similarity_search_with_score(query="qux", k=1) + for doc, score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * [SIM=0.748804] foo [{'id': '1'}] + + Async: + .. code-block:: python + # add documents + await vector_store.aadd_documents(documents=documents, ids=ids) + # delete documents + await vector_store.adelete(ids=["3"]) + # search + results = await vector_store.asimilarity_search(query="thud", k=1) + # search with score + results = await vector_store.asimilarity_search_with_score(query="qux", k=1) + for doc, score in results: + print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]") + .. code-block:: python + * [SIM=0.748807] foo [{'id': '1'}] + + Use as Retriever: + .. code-block:: python + retriever = vector_store.as_retriever( + search_type="mmr", + search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5}, + ) + retriever.invoke("thud") + .. code-block:: python + [Document(metadata={'id': '2'}, page_content='thud')] + """ # noqa: E501 + + def __init__( + self, + endpoint: str, + index_name: str, + embedding: Optional[Embeddings] = None, + text_column: Optional[str] = None, + columns: Optional[List[str]] = None, + ): + try: + from databricks.vector_search.client import ( # type: ignore[import] + VectorSearchClient, + ) + except ImportError as e: + raise ImportError( + "Could not import databricks-vectorsearch python package. " + "Please install it with `pip install databricks-vectorsearch`." + ) from e + + self.index = VectorSearchClient().get_index(endpoint, index_name) + self._index_details = IndexDetails(self.index) + + _validate_embedding(embedding, self._index_details) + self._embeddings = embedding + self._text_column = _validate_and_get_text_column( + text_column, self._index_details + ) + self._columns = _validate_and_get_return_columns( + columns or [], self._text_column, self._index_details + ) + self._primary_key = self._index_details.primary_key + + @property + def embeddings(self) -> Optional[Embeddings]: + """Access the query embedding object if available.""" + return self._embeddings + + @classmethod + def from_texts( + cls: Type[VST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VST: + raise NotImplementedError( + "`from_texts` is not supported. " + "Use `add_texts` to add to an existing direct-access index." + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict]] = None, + ids: Optional[List[Any]] = None, + **kwargs: Any, + ) -> List[str]: + """Add texts to the index. + + .. note:: + + This method is only supported for a direct-access index. + + Args: + texts: List of texts to add. + metadatas: List of metadata for each text. Defaults to None. + ids: List of ids for each text. Defaults to None. + If not provided, a random uuid will be generated for each text. + + Returns: + List of ids from adding the texts into the index. 
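+ + Example: + A minimal sketch; assumes ``vector_store`` wraps a direct-access index with a self-managed embedding model configured: + + .. code-block:: python + + # Texts are embedded locally and upserted into the index. + ids = vector_store.add_texts( + texts=["foo", "bar"], + metadatas=[{"source": "https://example.com"}] * 2, + ids=["1", "2"], + ) 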
+ """ + if self._index_details.is_delta_sync_index(): + raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts") + + # Wrap to list if input texts is a single string + if isinstance(texts, str): + texts = [texts] + texts = list(texts) + vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr] + ids = ids or [str(uuid.uuid4()) for _ in texts] + metadatas = metadatas or [{} for _ in texts] + + updates = [ + { + self._primary_key: id_, + self._text_column: text, + self._index_details.embedding_vector_column["name"]: vector, + **metadata, + } + for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas) + ] + + upsert_resp = self.index.upsert(updates) + if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"): + failed_ids = upsert_resp.get("result", dict()).get( + "failed_primary_keys", [] + ) + if upsert_resp.get("status") == "FAILURE": + logger.error("Failed to add texts to the index.") + else: + logger.warning("Some texts failed to be added to the index.") + return [id_ for id_ in ids if id_ not in failed_ids] + + return ids + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.add_texts, **kwargs), texts, metadatas + ) + + def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]: + """Delete documents from the index. + + .. note:: + + This method is only supported for a direct-access index. + + Args: + ids: List of ids of documents to delete. + + Returns: + True if successful. + """ + if self._index_details.is_delta_sync_index(): + raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete") + + if ids is None: + raise ValueError("ids must be provided.") + self.index.delete(ids) + return True + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding. + """ + docs_with_score = self.similarity_search_with_score( + query=query, + k=k, + filter=filter, + query_type=query_type, + **kwargs, + ) + return [doc for doc, _ in docs_with_score] + + async def asimilarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search, query, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query, along with scores. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". 
+ + Returns: + List of Documents most similar to the embedding and score for each. + """ + if self._index_details.is_databricks_managed_embeddings(): + query_text = query + query_vector = None + else: + # The value for `query_text` needs to be specified only for hybrid search. + if query_type is not None and query_type.upper() == "HYBRID": + query_text = query + else: + query_text = None + query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + + search_resp = self.index.similarity_search( + columns=self._columns, + query_text=query_text, + query_vector=query_vector, + filters=filter, + num_results=k, + query_type=query_type, + ) + return self._parse_search_response(search_resp) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + Databricks Vector search uses a normalized score 1/(1+d) where d + is the L2 distance. Hence, we simply return the identity function. + """ + return lambda score: score + + async def asimilarity_search_with_score( + self, *args: Any, **kwargs: Any + ) -> List[Tuple[Document, float]]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search_with_score, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + query: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + + Returns: + List of Documents most similar to the embedding. + """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector" + ) + + docs_with_score = self.similarity_search_by_vector_with_score( + embedding=embedding, + k=k, + filter=filter, + query_type=query_type, + query=query, + **kwargs, + ) + return [doc for doc, _ in docs_with_score] + + async def asimilarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def similarity_search_by_vector_with_score( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + query: Optional[str] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to embedding vector, along with scores. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". 
+ + Returns: + List of Documents most similar to the embedding and score for each. + """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score" + ) + + if query_type is not None and query_type.upper() == "HYBRID": + if query is None: + raise ValueError( + "A value for `query` must be specified for hybrid search." + ) + query_text = query + else: + if query is not None: + raise ValueError( + ( + "Cannot specify both `embedding` and " + '`query` unless `query_type="HYBRID"`' + ) + ) + query_text = None + + search_resp = self.index.similarity_search( + columns=self._columns, + query_vector=embedding, + query_text=query_text, + filters=filter, + num_results=k, + query_type=query_type, + ) + return self._parse_search_response(search_resp) + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, Any]] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + Returns: + List of Documents selected by maximal marginal relevance. + """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search" + ) + + query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr] + docs = self.max_marginal_relevance_search_by_vector( + query_vector, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + query_type=query_type, + ) + return docs + + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + # This is a temporary workaround to make the similarity search + # asynchronous. The proper solution is to make the similarity search + # asynchronous in the vector store implementations. + func = partial( + self.max_marginal_relevance_search, + query, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + **kwargs, + ) + return await asyncio.get_event_loop().run_in_executor(None, func) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Any] = None, + *, + query_type: Optional[str] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + .. note:: + + This method is not supported for index with Databricks-managed embeddings. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. 
+ fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter: Filters to apply to the query. Defaults to None. + query_type: The type of this query. Supported values are "ANN" and "HYBRID". + Returns: + List of Documents selected by maximal marginal relevance. + """ + if self._index_details.is_databricks_managed_embeddings(): + raise NotImplementedError( + _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector" + ) + + embedding_column = self._index_details.embedding_vector_column["name"] + search_resp = self.index.similarity_search( + columns=list(set(self._columns + [embedding_column])), + query_text=None, + query_vector=embedding, + filters=filter, + num_results=fetch_k, + query_type=query_type, + ) + + embeddings_result_index = ( + search_resp.get("manifest").get("columns").index({"name": embedding_column}) + ) + embeddings = [ + doc[embeddings_result_index] + for doc in search_resp.get("result").get("data_array") + ] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embeddings, + k=k, + lambda_mult=lambda_mult, + ) + + ignore_cols: List = ( + [embedding_column] if embedding_column not in self._columns else [] + ) + candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols) + selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + async def amax_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + raise NotImplementedError + + def _parse_search_response( + self, search_resp: Dict, ignore_cols: Optional[List[str]] = None + ) -> List[Tuple[Document, float]]: + """Parse the search response into a list of Documents with score.""" + if ignore_cols is None: + ignore_cols = [] + + columns = [ + col["name"] + for col in search_resp.get("manifest", dict()).get("columns", []) + ] + docs_with_score = [] + for result in search_resp.get("result", dict()).get("data_array", []): + doc_id = result[columns.index(self._primary_key)] + text_content = result[columns.index(self._text_column)] + ignore_cols = [self._primary_key, self._text_column] + ignore_cols + metadata = { + col: value + for col, value in zip(columns[:-1], result[:-1]) + if col not in ignore_cols + } + metadata[self._primary_key] = doc_id + score = result[-1] + doc = Document(page_content=text_content, metadata=metadata) + docs_with_score.append((doc, score)) + return docs_with_score + + +def _validate_and_get_text_column( + text_column: Optional[str], index_details: IndexDetails +) -> str: + if index_details.is_databricks_managed_embeddings(): + index_source_column: str = index_details.embedding_source_column["name"] + # check if input text column matches the source column of the index + if text_column is not None: + raise ValueError( + f"The index '{index_details.name}' has the source column configured as " + f"'{index_source_column}'. Do not pass the `text_column` parameter." 
+            )
+        return index_source_column
+    else:
+        if text_column is None:
+            raise ValueError("The `text_column` parameter is required for this index.")
+        return text_column
+
+
+def _validate_and_get_return_columns(
+    columns: List[str], text_column: str, index_details: IndexDetails
+) -> List[str]:
+    """
+    Get a list of columns to retrieve from the index.
+
+    If the index is a direct-access index, validate the given columns against the schema.
+    """
+    # add primary key column and source column if not in columns
+    if index_details.primary_key not in columns:
+        columns.append(index_details.primary_key)
+    if text_column and text_column not in columns:
+        columns.append(text_column)
+
+    # Validate specified columns are in the index
+    if index_details.is_direct_access_index() and (
+        index_schema := index_details.schema
+    ):
+        if missing_columns := [c for c in columns if c not in index_schema]:
+            raise ValueError(
+                "Some columns specified in `columns` are not "
+                f"in the index schema: {missing_columns}"
+            )
+    return columns
+
+
+def _validate_embedding(
+    embedding: Optional[Embeddings], index_details: IndexDetails
+) -> None:
+    if index_details.is_databricks_managed_embeddings():
+        if embedding is not None:
+            raise ValueError(
+                f"The index '{index_details.name}' uses Databricks-managed embeddings. "
+                "Do not pass the `embedding` parameter when initializing the vector store."
+            )
+    else:
+        if not embedding:
+            raise ValueError(
+                "The `embedding` parameter is required for a direct-access index "
+                "or delta-sync index with self-managed embeddings."
+            )
+        _validate_embedding_dimension(embedding, index_details)
+
+
+def _validate_embedding_dimension(
+    embeddings: Embeddings, index_details: IndexDetails
+) -> None:
+    """Validate that the embedding dimension matches the index's configuration."""
+    if index_embedding_dimension := index_details.embedding_vector_column.get(
+        "embedding_dimension"
+    ):
+        # Infer the embedding dimension from the embedding function.
+        actual_dimension = len(embeddings.embed_query("test"))
+        if actual_dimension != index_embedding_dimension:
+            raise ValueError(
+                f"The specified embedding model's dimension '{actual_dimension}' does "
+                f"not match the index configuration '{index_embedding_dimension}'."
+            )
+
+
+class IndexDetails:
+    """A utility class to store the configuration details of an index."""
+
+    def __init__(self, index: Any):
+        self._index_details = index.describe()
+
+    @property
+    def name(self) -> str:
+        return self._index_details["name"]
+
+    @property
+    def schema(self) -> Optional[Dict]:
+        if self.is_direct_access_index():
+            schema_json = self.index_spec.get("schema_json")
+            if schema_json is not None:
+                return json.loads(schema_json)
+        return None
+
+    @property
+    def primary_key(self) -> str:
+        return self._index_details["primary_key"]
+
+    @property
+    def index_spec(self) -> Dict:
+        return (
+            self._index_details.get("delta_sync_index_spec", {})
+            if self.is_delta_sync_index()
+            else self._index_details.get("direct_access_index_spec", {})
+        )
+
+    @property
+    def embedding_vector_column(self) -> Dict:
+        if vector_columns := self.index_spec.get("embedding_vector_columns"):
+            return vector_columns[0]
+        return {}
+
+    @property
+    def embedding_source_column(self) -> Dict:
+        if source_columns := self.index_spec.get("embedding_source_columns"):
+            return source_columns[0]
+        return {}
+
+    def is_delta_sync_index(self) -> bool:
+        return self._index_details["index_type"] == IndexType.DELTA_SYNC.value
+
+    def is_direct_access_index(self) -> bool:
+        return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value
+
+    def is_databricks_managed_embeddings(self) -> bool:
+        return (
+            self.is_delta_sync_index()
+            and self.embedding_source_column.get("name") is not None
+        )
diff --git a/libs/partners/databricks/poetry.lock b/libs/partners/databricks/poetry.lock
index 866aaa1a2f6bb..6c7ae53c5e2ea 100644
--- a/libs/partners/databricks/poetry.lock
+++ b/libs/partners/databricks/poetry.lock
@@ -339,6 +339,22 @@ requests = ">=2.28.1,<3"
 dev = ["autoflake", "databricks-connect", "ipython", "ipywidgets", "isort", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-rerunfailures", "pytest-xdist", "requests-mock", "wheel", "yapf"]
 notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"]
 
+[[package]]
+name = "databricks-vectorsearch"
+version = "0.40"
+description = "Databricks Vector Search Client"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "databricks_vectorsearch-0.40-py3-none-any.whl", hash = "sha256:c684291e1b0472ece8f6df8c6ff7982f49ce7075e1df5b93459e148dea1d70d7"},
+]
+
+[package.dependencies]
+deprecation = ">=2"
+mlflow-skinny = ">=2.11.3,<3"
+protobuf = ">=3.12.0,<5"
+requests = ">=2"
+
 [[package]]
 name = "deprecated"
 version = "1.2.14"
@@ -356,6 +372,20 @@ wrapt = ">=1.10,<2"
 [package.extras]
 dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
 
+[[package]]
+name = "deprecation"
+version = "2.1.0"
+description = "A library to handle automated deprecations"
+optional = false
+python-versions = "*"
+files = [
+    {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
+    {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
+]
+
+[package.dependencies]
+packaging = "*"
+
 [[package]]
 name = "docker"
 version = "7.1.0"
@@ -1469,8 +1499,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.20.3", markers = "python_version < \"3.10\""},
-    {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
     {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
+    {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < 
\"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1613,22 +1643,22 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "protobuf" -version = "5.27.3" +version = "4.25.4" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-5.27.3-cp310-abi3-win32.whl", hash = "sha256:dcb307cd4ef8fec0cf52cb9105a03d06fbb5275ce6d84a6ae33bc6cf84e0a07b"}, - {file = "protobuf-5.27.3-cp310-abi3-win_amd64.whl", hash = "sha256:16ddf3f8c6c41e1e803da7abea17b1793a97ef079a912e42351eabb19b2cffe7"}, - {file = "protobuf-5.27.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:68248c60d53f6168f565a8c76dc58ba4fa2ade31c2d1ebdae6d80f969cdc2d4f"}, - {file = "protobuf-5.27.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:b8a994fb3d1c11156e7d1e427186662b64694a62b55936b2b9348f0a7c6625ce"}, - {file = "protobuf-5.27.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:a55c48f2a2092d8e213bd143474df33a6ae751b781dd1d1f4d953c128a415b25"}, - {file = "protobuf-5.27.3-cp38-cp38-win32.whl", hash = "sha256:043853dcb55cc262bf2e116215ad43fa0859caab79bb0b2d31b708f128ece035"}, - {file = "protobuf-5.27.3-cp38-cp38-win_amd64.whl", hash = "sha256:c2a105c24f08b1e53d6c7ffe69cb09d0031512f0b72f812dd4005b8112dbe91e"}, - {file = "protobuf-5.27.3-cp39-cp39-win32.whl", hash = "sha256:c84eee2c71ed83704f1afbf1a85c3171eab0fd1ade3b399b3fad0884cbcca8bf"}, - {file = "protobuf-5.27.3-cp39-cp39-win_amd64.whl", hash = "sha256:af7c0b7cfbbb649ad26132e53faa348580f844d9ca46fd3ec7ca48a1ea5db8a1"}, - {file = "protobuf-5.27.3-py3-none-any.whl", hash = "sha256:8572c6533e544ebf6899c360e91d6bcbbee2549251643d32c52cf8a5de295ba5"}, - {file = "protobuf-5.27.3.tar.gz", hash = "sha256:82460903e640f2b7e34ee81a947fdaad89de796d324bcbc38ff5430bcdead82c"}, + {file = "protobuf-4.25.4-cp310-abi3-win32.whl", hash = "sha256:db9fd45183e1a67722cafa5c1da3e85c6492a5383f127c86c4c4aa4845867dc4"}, + {file = "protobuf-4.25.4-cp310-abi3-win_amd64.whl", hash = "sha256:ba3d8504116a921af46499471c63a85260c1a5fc23333154a427a310e015d26d"}, + {file = "protobuf-4.25.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:eecd41bfc0e4b1bd3fa7909ed93dd14dd5567b98c941d6c1ad08fdcab3d6884b"}, + {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:4c8a70fdcb995dcf6c8966cfa3a29101916f7225e9afe3ced4395359955d3835"}, + {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3319e073562e2515c6ddc643eb92ce20809f5d8f10fead3332f71c63be6a7040"}, + {file = "protobuf-4.25.4-cp38-cp38-win32.whl", hash = "sha256:7e372cbbda66a63ebca18f8ffaa6948455dfecc4e9c1029312f6c2edcd86c4e1"}, + {file = "protobuf-4.25.4-cp38-cp38-win_amd64.whl", hash = "sha256:051e97ce9fa6067a4546e75cb14f90cf0232dcb3e3d508c448b8d0e4265b61c1"}, + {file = "protobuf-4.25.4-cp39-cp39-win32.whl", hash = "sha256:90bf6fd378494eb698805bbbe7afe6c5d12c8e17fca817a646cd6a1818c696ca"}, + {file = "protobuf-4.25.4-cp39-cp39-win_amd64.whl", hash = "sha256:ac79a48d6b99dfed2729ccccee547b34a1d3d63289c71cef056653a846a2240f"}, + {file = "protobuf-4.25.4-py3-none-any.whl", hash = "sha256:bfbebc1c8e4793cfd58589acfb8a1026be0003e852b9da7db5a4285bde996978"}, + {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, ] [[package]] @@ -2492,4 +2522,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "6da52f0e39bc7da1a80cc181bdd481e57f4644daf2f3c6da6a6b0ead2e813be9" +content-hash = 
"857f47603d9dd6fe8882c7525613a54a54ee459a9ee012f3d19e510c5477f3db" diff --git a/libs/partners/databricks/pyproject.toml b/libs/partners/databricks/pyproject.toml index cdea854df91a8..22b6554cfd094 100644 --- a/libs/partners/databricks/pyproject.toml +++ b/libs/partners/databricks/pyproject.toml @@ -26,6 +26,7 @@ scipy = [ {version = ">=1.11", python = ">=3.12"}, {version = "<2", python = "<3.12"} ] +databricks-vectorsearch = "^0.40" [tool.poetry.group.test] optional = true diff --git a/libs/partners/databricks/tests/unit_tests/test_embeddings.py b/libs/partners/databricks/tests/unit_tests/test_embeddings.py new file mode 100644 index 0000000000000..655add03fe543 --- /dev/null +++ b/libs/partners/databricks/tests/unit_tests/test_embeddings.py @@ -0,0 +1,69 @@ +"""Test Together AI embeddings.""" + +from typing import Any, Dict, Generator +from unittest import mock + +import pytest +from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped] + +from langchain_databricks import DatabricksEmbeddings + + +def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]: + return { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": list(range(1536)), + "index": 0, + } + for _ in inputs["input"] + ], + "model": "text-embedding-3-small", + "usage": {"prompt_tokens": 8, "total_tokens": 8}, + } + + +@pytest.fixture +def mock_client() -> Generator: + client = mock.MagicMock() + client.predict.side_effect = _mock_embeddings + with mock.patch("mlflow.deployments.get_deploy_client", return_value=client): + yield client + + +@pytest.fixture +def embeddings() -> DatabricksEmbeddings: + return DatabricksEmbeddings( + endpoint="text-embedding-3-small", + documents_params={"fruit": "apple"}, + query_params={"fruit": "banana"}, + ) + + +def test_embed_documents( + mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings +) -> None: + documents = ["foo"] * 30 + output = embeddings.embed_documents(documents) + assert len(output) == 30 + assert len(output[0]) == 1536 + assert mock_client.predict.call_count == 2 + assert all( + call_arg[1]["inputs"]["fruit"] == "apple" + for call_arg in mock_client().predict.call_args_list + ) + + +def test_embed_query( + mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings +) -> None: + query = "foo bar" + output = embeddings.embed_query(query) + assert len(output) == 1536 + mock_client.predict.assert_called_once() + assert mock_client.predict.call_args[1] == { + "endpoint": "text-embedding-3-small", + "inputs": {"input": [query], "fruit": "banana"}, + } diff --git a/libs/partners/databricks/tests/unit_tests/test_imports.py b/libs/partners/databricks/tests/unit_tests/test_imports.py index 579123a8bbb06..dfcdfaa1ded84 100644 --- a/libs/partners/databricks/tests/unit_tests/test_imports.py +++ b/libs/partners/databricks/tests/unit_tests/test_imports.py @@ -2,6 +2,8 @@ EXPECTED_ALL = [ "ChatDatabricks", + "DatabricksEmbeddings", + "DatabricksVectorSearch", "__version__", ] diff --git a/libs/partners/databricks/tests/unit_tests/test_vectorstore.py b/libs/partners/databricks/tests/unit_tests/test_vectorstore.py new file mode 100644 index 0000000000000..ed8654e787036 --- /dev/null +++ b/libs/partners/databricks/tests/unit_tests/test_vectorstore.py @@ -0,0 +1,629 @@ +import uuid +from typing import Any, Dict, Generator, List, Optional, Set +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.embeddings import Embeddings + +from 
langchain_databricks.vectorstores import DatabricksVectorSearch
+
+INPUT_TEXTS = ["foo", "bar", "baz"]
+DEFAULT_VECTOR_DIMENSION = 4
+
+
+class FakeEmbeddings(Embeddings):
+    """Fake embeddings functionality for testing."""
+
+    def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
+        super().__init__()
+        self.dimension = dimension
+
+    def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings."""
+        return [
+            [float(1.0)] * (self.dimension - 1) + [float(i)]
+            for i in range(len(embedding_texts))
+        ]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return simple embeddings."""
+        return [float(1.0)] * (self.dimension - 1) + [float(0.0)]
+
+
+EMBEDDING_MODEL = FakeEmbeddings()
+
+
+### Dummy similarity_search() Response ###
+EXAMPLE_SEARCH_RESPONSE = {
+    "manifest": {
+        "column_count": 4,
+        "columns": [
+            {"name": "id"},
+            {"name": "text"},
+            {"name": "text_vector"},
+            {"name": "score"},
+        ],
+    },
+    "result": {
+        "row_count": len(INPUT_TEXTS),
+        "data_array": sorted(
+            [
+                [str(uuid.uuid4()), s, e, 0.5]
+                for s, e in zip(
+                    INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
+                )
+            ],
+            key=lambda x: x[2],  # type: ignore
+            reverse=True,
+        ),
+    },
+    "next_page_token": "",
+}
+
+
+### Dummy Indices ###
+
+ENDPOINT_NAME = "test-endpoint"
+DIRECT_ACCESS_INDEX = "test-direct-access-index"
+DELTA_SYNC_INDEX = "test-delta-sync-index"
+DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
+ALL_INDEX_NAMES = {
+    DIRECT_ACCESS_INDEX,
+    DELTA_SYNC_INDEX,
+    DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
+}
+
+INDEX_DETAILS = {
+    DELTA_SYNC_INDEX: {
+        "name": DELTA_SYNC_INDEX,
+        "endpoint_name": ENDPOINT_NAME,
+        "index_type": "DELTA_SYNC",
+        "primary_key": "id",
+        "delta_sync_index_spec": {
+            "source_table": "ml.llm.source_table",
+            "pipeline_type": "CONTINUOUS",
+            "embedding_source_columns": [
+                {
+                    "name": "text",
+                    "embedding_model_endpoint_name": "openai-text-embedding",
+                }
+            ],
+        },
+    },
+    DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
+        "name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
+        "endpoint_name": ENDPOINT_NAME,
+        "index_type": "DELTA_SYNC",
+        "primary_key": "id",
+        "delta_sync_index_spec": {
+            "source_table": "ml.llm.source_table",
+            "pipeline_type": "CONTINUOUS",
+            "embedding_vector_columns": [
+                {
+                    "name": "text_vector",
+                    "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
+                }
+            ],
+        },
+    },
+    DIRECT_ACCESS_INDEX: {
+        "name": DIRECT_ACCESS_INDEX,
+        "endpoint_name": ENDPOINT_NAME,
+        "index_type": "DIRECT_ACCESS",
+        "primary_key": "id",
+        "direct_access_index_spec": {
+            "embedding_vector_columns": [
+                {
+                    "name": "text_vector",
+                    "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
+                }
+            ],
+            "schema_json": (
+                '{"id": "int", "feat1": "str", "feat2": "float", '
+                '"text": "string", "text_vector": "array<float>"}'
+            ),
+        },
+    },
+}
+
+
+@pytest.fixture(autouse=True)
+def mock_vs_client() -> Generator:
+    def _get_index(endpoint: str, index_name: str) -> MagicMock:
+        from databricks.vector_search.client import VectorSearchIndex  # type: ignore
+
+        if endpoint != ENDPOINT_NAME:
+            raise ValueError(f"Unknown endpoint: {endpoint}")
+
+        index = MagicMock(spec=VectorSearchIndex)
+        index.describe.return_value = INDEX_DETAILS[index_name]
+        index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
+        return index
+
+    mock_client = MagicMock()
+    mock_client.get_index.side_effect = _get_index
+    with mock.patch(
+        "databricks.vector_search.client.VectorSearchClient",
+        return_value=mock_client,
+ ): + yield + + +def init_vector_search( + index_name: str, columns: Optional[List[str]] = None +) -> DatabricksVectorSearch: + kwargs: Dict[str, Any] = { + "endpoint": ENDPOINT_NAME, + "index_name": index_name, + "columns": columns, + } + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "embedding": EMBEDDING_MODEL, + "text_column": "text", + } + ) + return DatabricksVectorSearch(**kwargs) # type: ignore[arg-type] + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_init(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + assert vectorsearch.index.describe() == INDEX_DETAILS[index_name] + + +def test_init_fail_text_column_mismatch() -> None: + with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"): + DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=DELTA_SYNC_INDEX, + text_column="some_other_column", + ) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_init_fail_no_text_column(index_name: str) -> None: + with pytest.raises(ValueError, match="The `text_column` parameter is required"): + DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=index_name, + embedding=EMBEDDING_MODEL, + ) + + +def test_init_fail_columns_not_in_schema() -> None: + columns = ["some_random_column"] + with pytest.raises(ValueError, match="Some columns specified in `columns`"): + init_vector_search(DIRECT_ACCESS_INDEX, columns=columns) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_init_fail_no_embedding(index_name: str) -> None: + with pytest.raises(ValueError, match="The `embedding` parameter is required"): + DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=index_name, + text_column="text", + ) + + +def test_init_fail_embedding_already_specified_in_source() -> None: + with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"): + DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=DELTA_SYNC_INDEX, + embedding=EMBEDDING_MODEL, + ) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_init_fail_embedding_dim_mismatch(index_name: str) -> None: + with pytest.raises( + ValueError, match="embedding model's dimension '1000' does not match" + ): + DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=index_name, + text_column="text", + embedding=FakeEmbeddings(1000), + ) + + +def test_from_texts_not_supported() -> None: + with pytest.raises(NotImplementedError, match="`from_texts` is not supported"): + DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX}) +def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + with pytest.raises( + NotImplementedError, + match="`add_texts` is only supported for direct-access index.", + ): + vectorsearch.add_texts(INPUT_TEXTS) + + +def is_valid_uuid(val: str) -> bool: + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + + +def test_add_texts() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + ids = [idx for idx, i in enumerate(INPUT_TEXTS)] + vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS) + + added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids) + vectorsearch.index.upsert.assert_called_once_with( + [ + { + "id": id_, + "text": text, + "text_vector": vector, + } + for text, vector, id_ in 
zip(INPUT_TEXTS, vectors, ids) + ] + ) + assert len(added_ids) == len(INPUT_TEXTS) + assert added_ids == ids + + +def test_add_texts_handle_single_text() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS) + + added_ids = vectorsearch.add_texts(INPUT_TEXTS[0]) + vectorsearch.index.upsert.assert_called_once_with( + [ + { + "id": id_, + "text": text, + "text_vector": vector, + } + for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids) + ] + ) + assert len(added_ids) == 1 + assert is_valid_uuid(added_ids[0]) + + +def test_add_texts_with_default_id() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS) + + added_ids = vectorsearch.add_texts(INPUT_TEXTS) + vectorsearch.index.upsert.assert_called_once_with( + [ + { + "id": id_, + "text": text, + "text_vector": vector, + } + for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids) + ] + ) + assert len(added_ids) == len(INPUT_TEXTS) + assert all([is_valid_uuid(id_) for id_ in added_ids]) + + +def test_add_texts_with_metadata() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS) + metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))] + + added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas) + vectorsearch.index.upsert.assert_called_once_with( + [ + { + "id": id_, + "text": text, + "text_vector": vector, + **metadata, # type: ignore[arg-type] + } + for text, vector, id_, metadata in zip( + INPUT_TEXTS, vectors, added_ids, metadatas + ) + ] + ) + assert len(added_ids) == len(INPUT_TEXTS) + assert all([is_valid_uuid(id_) for id_ in added_ids]) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_embeddings_property(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + assert vectorsearch.embeddings == EMBEDDING_MODEL + + +def test_delete() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + vectorsearch.delete(["some id"]) + vectorsearch.index.delete.assert_called_once_with(["some id"]) + + +def test_delete_fail_no_ids() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + with pytest.raises(ValueError, match="ids must be provided."): + vectorsearch.delete() + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX}) +def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + with pytest.raises( + NotImplementedError, match="`delete` is only supported for direct-access" + ): + vectorsearch.delete(["some id"]) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("query_type", [None, "ANN"]) +def test_similarity_search(index_name: str, query_type: Optional[str]) -> None: + vectorsearch = init_vector_search(index_name) + query = "foo" + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search( + query, k=limit, filter=filters, query_type=query_type + ) + if index_name == DELTA_SYNC_INDEX: + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_text=query, + query_vector=None, + filters=filters, + num_results=limit, + query_type=query_type, + ) + else: + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_text=None, + 
query_vector=EMBEDDING_MODEL.embed_query(query), + filters=filters, + num_results=limit, + query_type=query_type, + ) + assert len(search_result) == len(INPUT_TEXTS) + assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS) + assert all(["id" in d.metadata for d in search_result]) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_similarity_search_hybrid(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + query = "foo" + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search( + query, k=limit, filter=filters, query_type="HYBRID" + ) + if index_name == DELTA_SYNC_INDEX: + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_text=query, + query_vector=None, + filters=filters, + num_results=limit, + query_type="HYBRID", + ) + else: + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_text=query, + query_vector=EMBEDDING_MODEL.embed_query(query), + filters=filters, + num_results=limit, + query_type="HYBRID", + ) + assert len(search_result) == len(INPUT_TEXTS) + assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS) + assert all(["id" in d.metadata for d in search_result]) + + +def test_similarity_search_both_filter_and_filters_passed() -> None: + vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX) + query = "foo" + filter = {"some filter": True} + filters = {"some other filter": False} + + vectorsearch.similarity_search(query, filter=filter, filters=filters) + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_vector=EMBEDDING_MODEL.embed_query(query), + # `filter` should prevail over `filters` + filters=filter, + num_results=4, + query_text=None, + query_type=None, + ) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +@pytest.mark.parametrize( + "columns, expected_columns", + [ + (None, {"id"}), + (["id", "text", "text_vector"], {"text_vector", "id"}), + ], +) +def test_mmr_search( + index_name: str, columns: Optional[List[str]], expected_columns: Set[str] +) -> None: + vectorsearch = init_vector_search(index_name, columns=columns) + + query = INPUT_TEXTS[0] + filters = {"some filter": True} + limit = 1 + + search_result = vectorsearch.max_marginal_relevance_search( + query, k=limit, filters=filters + ) + assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]] + assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns] + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_mmr_parameters(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + + query = INPUT_TEXTS[0] + limit = 1 + fetch_k = 3 + lambda_mult = 0.25 + filters = {"some filter": True} + + with patch( + "langchain_databricks.vectorstores.maximal_marginal_relevance" + ) as mock_mmr: + mock_mmr.return_value = [2] + retriever = vectorsearch.as_retriever( + search_type="mmr", + search_kwargs={ + "k": limit, + "fetch_k": fetch_k, + "lambda_mult": lambda_mult, + "filter": filters, + }, + ) + search_result = retriever.invoke(query) + + mock_mmr.assert_called_once() + assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult + assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k + assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters + assert len(search_result) == limit + + 
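+# A minimal sketch, not part of the original change: it exercises the async MMR
+# wrapper shown in the vector store diff above, which simply delegates to the
+# synchronous implementation via `run_in_executor`. It reuses the fixtures and
+# helpers defined earlier in this file.
+@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
+def test_amax_marginal_relevance_search(index_name: str) -> None:
+    import asyncio
+
+    vectorsearch = init_vector_search(index_name)
+
+    # The async variant should return the same result as the sync path covered
+    # by `test_mmr_search` above.
+    search_result = asyncio.run(
+        vectorsearch.amax_marginal_relevance_search(INPUT_TEXTS[0], k=1)
+    )
+    assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
+
+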
+@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8]) +def test_similarity_score_threshold(index_name: str, threshold: float) -> None: + query = INPUT_TEXTS[0] + limit = len(INPUT_TEXTS) + + vectorsearch = init_vector_search(index_name) + retriever = vectorsearch.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"k": limit, "score_threshold": threshold}, + ) + search_result = retriever.invoke(query) + if threshold <= 0.5: + assert len(search_result) == len(INPUT_TEXTS) + else: + assert len(search_result) == 0 + + +def test_standard_params() -> None: + vectorstore = init_vector_search(DIRECT_ACCESS_INDEX) + retriever = vectorstore.as_retriever() + ls_params = retriever._get_ls_params() + assert ls_params == { + "ls_retriever_name": "vectorstore", + "ls_vector_store_provider": "DatabricksVectorSearch", + "ls_embedding_provider": "FakeEmbeddings", + } + + vectorstore = init_vector_search(DELTA_SYNC_INDEX) + retriever = vectorstore.as_retriever() + ls_params = retriever._get_ls_params() + assert ls_params == { + "ls_retriever_name": "vectorstore", + "ls_vector_store_provider": "DatabricksVectorSearch", + } + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +@pytest.mark.parametrize("query_type", [None, "ANN"]) +def test_similarity_search_by_vector( + index_name: str, query_type: Optional[str] +) -> None: + vectorsearch = init_vector_search(index_name) + query_embedding = EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filter=filters, query_type=query_type + ) + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_vector=query_embedding, + filters=filters, + num_results=limit, + query_type=query_type, + query_text=None, + ) + assert len(search_result) == len(INPUT_TEXTS) + assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS) + assert all(["id" in d.metadata for d in search_result]) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX}) +def test_similarity_search_by_vector_hybrid(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + query_embedding = EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo" + ) + vectorsearch.index.similarity_search.assert_called_once_with( + columns=["id", "text"], + query_vector=query_embedding, + filters=filters, + num_results=limit, + query_type="HYBRID", + query_text="foo", + ) + assert len(search_result) == len(INPUT_TEXTS) + assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS) + assert all(["id" in d.metadata for d in search_result]) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_similarity_search_empty_result(index_name: str) -> None: + vectorsearch = init_vector_search(index_name) + vectorsearch.index.similarity_search.return_value = { + "manifest": { + "column_count": 3, + "columns": [ + {"name": "id"}, + {"name": "text"}, + {"name": "score"}, + ], + }, + "result": { + "row_count": 0, + "data_array": [], + }, + "next_page_token": "", + } + + search_result = vectorsearch.similarity_search("foo") + assert len(search_result) == 0 + + +def 
test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None: + vectorsearch = init_vector_search(DELTA_SYNC_INDEX) + query_embedding = EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + with pytest.raises( + NotImplementedError, match="`similarity_search_by_vector` is not supported" + ): + vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filters=filters + )
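+
+
+# A minimal sketch, not part of the original change: it covers the two
+# validation branches of `similarity_search_by_vector_with_score` shown at the
+# top of the vector store diff, assuming the method accepts `query_type` and
+# `query` as keyword arguments. Hybrid search requires `query`, while plain
+# vector search rejects it.
+def test_similarity_search_by_vector_with_score_query_validation() -> None:
+    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
+    query_embedding = EMBEDDING_MODEL.embed_query("foo")
+
+    with pytest.raises(ValueError, match="A value for `query` must be specified"):
+        vectorsearch.similarity_search_by_vector_with_score(
+            query_embedding, query_type="HYBRID"
+        )
+
+    with pytest.raises(ValueError, match="Cannot specify both"):
+        vectorsearch.similarity_search_by_vector_with_score(
+            query_embedding, query="foo"
+        )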