From ba3b41169c1a72daf23325981ec2083c7fe1de00 Mon Sep 17 00:00:00 2001
From: Sanjay Nadhavajhala
Date: Fri, 9 Feb 2024 00:06:34 -0800
Subject: [PATCH] init vector db api

---
 services/backend/vector_db_api/Dockerfile     |  25 +
 services/backend/vector_db_api/README.md      |  11 +
 .../vector_db_api/milvus/CustomEmbeddings.py  |  55 ++
 services/backend/vector_db_api/milvus/main.py | 119 +++
 .../vector_db_api/milvus/query_milvus.py      | 825 ++++++++++++++++++
 .../backend/vector_db_api/requirements.txt    |   8 +
 6 files changed, 1043 insertions(+)
 create mode 100644 services/backend/vector_db_api/Dockerfile
 create mode 100644 services/backend/vector_db_api/README.md
 create mode 100644 services/backend/vector_db_api/milvus/CustomEmbeddings.py
 create mode 100644 services/backend/vector_db_api/milvus/main.py
 create mode 100644 services/backend/vector_db_api/milvus/query_milvus.py
 create mode 100644 services/backend/vector_db_api/requirements.txt

diff --git a/services/backend/vector_db_api/Dockerfile b/services/backend/vector_db_api/Dockerfile
new file mode 100644
index 0000000..da8c1be
--- /dev/null
+++ b/services/backend/vector_db_api/Dockerfile
@@ -0,0 +1,25 @@
+FROM python:3.8-slim-buster
+
+# Create a separate directory for dependencies
+WORKDIR /app/requirements
+
+# Copy requirements.txt into the container
+COPY requirements.txt .
+
+# Install the required packages
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the FastAPI app code into the container
+COPY milvus ./milvus
+
+
+# Expose the default port for FastAPI
+EXPOSE 8010
+
+ENV HOST=host.docker.internal
+
+# Run the FastAPI app using uvicorn
+CMD ["uvicorn", "milvus.main:app", "--host", "0.0.0.0", "--port", "8010"]
diff --git a/services/backend/vector_db_api/README.md b/services/backend/vector_db_api/README.md
new file mode 100644
index 0000000..f50676a
--- /dev/null
+++ b/services/backend/vector_db_api/README.md
@@ -0,0 +1,11 @@
+# Vector Search Service
+
+This service provides a FastAPI application that interfaces with a vector database for text embedding and similarity search operations. It is designed to be flexible and allows for the potential integration of other vector databases in the future.
+
+The service provides the following endpoints:
+
+- `POST /add_texts`: Add texts to a specified Milvus collection
+- `DELETE /delete_docs`: Delete documents from a specified Milvus collection based on an expression
+- `POST /similarity_match`: Perform a similarity search and return the top-k similar documents
+- `GET /ping`: Health check endpoint
+
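A quick sketch of how these endpoints might be called once the container is running on port 8010 (per the Dockerfile). The collection name, texts, and primary keys are placeholders, and the payload layout is an assumption based on how FastAPI maps the handler signatures in `milvus/main.py` (scalar arguments become query parameters; list/dict arguments and the `Query` model travel in the JSON body):

```python
# Sketch of exercising the service; host, collection name, and texts are assumptions.
import requests

BASE = "http://localhost:8010"
COLLECTION = "user123_assistant456"  # hypothetical "<user id><assistant id>" collection

# Add two documents (texts/metadatas in the body, collection_name as a query parameter).
requests.post(
    f"{BASE}/add_texts",
    params={"collection_name": COLLECTION},
    json={
        "texts": ["Milvus is a vector database.", "FastAPI serves this API."],
        "metadatas": [{"source": "docs"}, {"source": "docs"}],
    },
)

# Ask for the top-3 most similar documents.
resp = requests.post(
    f"{BASE}/similarity_match",
    json={"text": "What is Milvus?", "collection_name": COLLECTION, "topk": 3},
)
print(resp.json()["response"])

# Milvus deletes by primary-key expressions; the pk values here are placeholders
# standing in for keys returned at insert time.
requests.delete(
    f"{BASE}/delete_docs",
    params={"collection_name": COLLECTION, "expr": "pk in [1, 2]"},
)

# Health check.
print(requests.get(f"{BASE}/ping").json())
```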
+ """ + headers = {"accept": "application/json", "Content-Type": "application/json"} + data = json.dumps(texts) + + response = requests.post(EMBEDDING_URL, headers=headers, data=data) + response = response.json() + + return response["embeddings"] + + +class CustomEmbeddings(Embeddings): + """Custom embeddings class that uses a remote service for embedding.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed a list of documents. + + Args: + texts (List[str]): List of documents to be embedded. + + Returns: + List[List[float]]: List of embedded documents. + """ + return embed_text(texts) + + def embed_query(self, text: str) -> List[float]: + """Embed a single query. + + Args: + text (str): Query to be embedded. + + Returns: + List[float]: Embedded query. + """ + return embed_text([text])[0] diff --git a/services/backend/vector_db_api/milvus/main.py b/services/backend/vector_db_api/milvus/main.py new file mode 100644 index 0000000..e3922d5 --- /dev/null +++ b/services/backend/vector_db_api/milvus/main.py @@ -0,0 +1,119 @@ +# Standard library imports +# Standard Library +import os +from typing import List, Optional + +# Third Party +# Third party imports +from fastapi import FastAPI +from pydantic import BaseModel + +# Local application imports +from .CustomEmbeddings import CustomEmbeddings +from .query_milvus import Milvus + +MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost") + +model = {} +top_re_rank = 5 +top_k_match = 10 +app = FastAPI() + + +class Query(BaseModel): + text: str + collection_name: str + topk: int = top_k_match + rerank: bool = False + topr: int = top_re_rank + + +@app.on_event("startup") +async def app_startup(): + pass + + +def drop_collection(collection_name: str): + load_collection(collection_name).drop_collection() + + +def load_collection(collection_name: str) -> Milvus: + return Milvus( + embedding_function=CustomEmbeddings(), + collection_name=collection_name, + connection_args={ + "host": MILVUS_HOST, + "port": "19530", + "user": "username", + "password": "password", + }, + index_params={ + "metric_type": "IP", + "index_type": "FLAT", + "params": {"nlist": 16384}, + }, + search_params={"metric_type": "IP", "params": {"nprobe": 32}}, + ) + + +@app.post("/add_texts") +async def add_texts_embeddings( + collection_name: str, + texts: List[str], + metadatas: Optional[List[dict]] = None, +): + """_summary_ + + Args: + texts (List[str]): _description_ + connlection_name (str): this should reflect user's random id + the assistant_id they created. 
+ """ + pks = add_texts(collection_name, texts, metadatas) + + +def add_texts( + collection_name: str, + texts: List[str], + metadatas: Optional[List[dict]] = None, +): + c = load_collection(collection_name) + pks = c.add_texts(texts=texts, metadatas=metadatas) + print(pks) + return pks + + +@app.delete("/delete_docs") +async def delete_docs_api(collection_name: str, expr: str): + delete_docs(collection_name, expr) + + +def delete_docs(collection_name: str, expr: str): + c = load_collection(collection_name) + c.delete_entities(expr=expr) + + +def get_top_k_biencoder_match_milvus(query: Query): + c = load_collection(query.collection_name) + + docs = c.similarity_search(query.text, k=query.topk) + res = [] + for i, d in enumerate(docs): + thisd = {"id": i, "metadata": d.metadata, "text": d.page_content} + res.append(thisd) + return res + + +def get_similar_match(query, biencoder_match_method: str, rerank: bool = False): + query_biencoder_matches = get_top_k_biencoder_match_milvus(query) + return query_biencoder_matches[: query.topr] + + +@app.post("/similarity_match") +def text_similarity_match(query: Query): + res = get_similar_match(query, biencoder_match_method="milvus", rerank=query.rerank) + return {"response": res} + + +@app.get("/ping") +def ping(): + return {"response": "Pong!"} diff --git a/services/backend/vector_db_api/milvus/query_milvus.py b/services/backend/vector_db_api/milvus/query_milvus.py new file mode 100644 index 0000000..4e7779c --- /dev/null +++ b/services/backend/vector_db_api/milvus/query_milvus.py @@ -0,0 +1,825 @@ +"""Wrapper around the Milvus vector database.""" + +# Standard Library +import logging +import os +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from uuid import uuid4 + +# Third Party +import numpy as np +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + +DEFAULT_MILVUS_CONNECTION = { + "host": os.getenv("MILVUS_HOST", "localhost"), + "port": "19530", + "user": "", + "password": "", + "secure": False, +} + + +class Milvus(VectorStore): + """Wrapper around the Milvus vector database.""" + + def __init__( + self, + embedding_function: Embeddings, + collection_name: str = "DefaultCollection", + connection_args: Optional[Dict[str, Any]] = None, + consistency_level: str = "Session", + index_params: Optional[Dict[str, Any]] = None, + search_params: Optional[Dict[str, Any]] = None, + drop_old: Optional[bool] = False, + ): + """Initialize wrapper around the milvus vector database. + + In order to use this you need to have `pymilvus` installed and a + running Milvus/Zilliz Cloud instance. + + See the following documentation for how to run a Milvus instance: + https://milvus.io/docs/install_standalone-docker.md + + If looking for a hosted Milvus, take a looka this documentation: + https://zilliz.com/cloud + + IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. + + The connection args used for this class comes in the form of a dict, + here are a few of the options: + address (str): The actual address of Milvus + instance. Example address: "localhost:19530" + uri (str): The uri of Milvus instance. Example uri: + "http://randomwebsite:19530", + "tcp:foobarsite:19530", + "https://ok.s3.south.com:19530". + host (str): The host of Milvus instance. 
Default at "localhost", + PyMilvus will fill in the default host if only port is provided. + port (str/int): The port of Milvus instance. Default at 19530, PyMilvus + will fill in the default port if only host is provided. + user (str): Use which user to connect to Milvus instance. If user and + password are provided, we will add related header in every RPC call. + password (str): Required when user is provided. The password + corresponding to the user. + secure (bool): Default is false. If set to true, tls will be enabled. + client_key_path (str): If use tls two-way authentication, need to + write the client.key path. + client_pem_path (str): If use tls two-way authentication, need to + write the client.pem path. + ca_pem_path (str): If use tls two-way authentication, need to write + the ca.pem path. + server_pem_path (str): If use tls one-way authentication, need to + write the server.pem path. + server_name (str): If use tls, need to write the common name. + + Args: + embedding_function (Embeddings): Function used to embed the text. + collection_name (str): Which Milvus collection to use. Defaults to + "LangChainCollection". + connection_args (Optional[dict[str, any]]): The arguments for connection to + Milvus/Zilliz instance. Defaults to DEFAULT_MILVUS_CONNECTION. + consistency_level (str): The consistency level to use for a collection. + Defaults to "Session". + index_params (Optional[dict]): Which index params to use. Defaults to + HNSW/AUTOINDEX depending on service. + search_params (Optional[dict]): Which search params to use. Defaults to + default of index. + drop_old (Optional[bool]): Whether to drop the current collection. Defaults + to False. + """ + try: + # Third Party + from pymilvus import Collection, utility + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) + + # Default search params when one is not provided. 
+        # Default search params when one is not provided.
+        self.default_search_params = {
+            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
+            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
+            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
+            "AUTOINDEX": {"metric_type": "L2", "params": {}},
+        }
+
+        self.embedding_func = embedding_function
+        self.collection_name = collection_name
+        self.index_params = index_params
+        self.search_params = search_params
+        self.consistency_level = consistency_level
+
+        # In order for a collection to be compatible, pk needs to be auto-id and int
+        self._primary_field = "pk"
+        # In order for compatibility, the text field will need to be called "text"
+        self._text_field = "text"
+        # In order for compatibility, the vector field needs to be called "vector"
+        self._vector_field = "vector"
+        self.fields: List[str] = []
+        # Create the connection to the server
+        if connection_args is None:
+            connection_args = DEFAULT_MILVUS_CONNECTION
+        self.alias = self._create_connection_alias(connection_args)
+        self.col: Optional[Collection] = None
+
+        # Grab the existing collection if it exists
+        if utility.has_collection(self.collection_name, using=self.alias):
+            self.col = Collection(
+                self.collection_name,
+                using=self.alias,
+            )
+        # If need to drop old, drop it
+        if drop_old and isinstance(self.col, Collection):
+            self.col.drop()
+            self.col = None
+
+        # Initialize the vector store
+        self._init()
+
+    def drop_collection(self):
+        # Third Party
+        from pymilvus import utility
+
+        utility.drop_collection(collection_name=self.collection_name, using=self.alias)
+
+    def _create_connection_alias(self, connection_args: dict) -> str:
+        """Create the connection to the Milvus server."""
+        # Third Party
+        from pymilvus import MilvusException, connections
+
+        # Grab the connection arguments that are used for checking existing connection
+        host: str = connection_args.get("host", None)
+        port: Union[str, int] = connection_args.get("port", None)
+        address: str = connection_args.get("address", None)
+        uri: str = connection_args.get("uri", None)
+        user = connection_args.get("user", None)
+
+        # Order of use is host/port, uri, address
+        if host is not None and port is not None:
+            given_address = str(host) + ":" + str(port)
+        elif uri is not None:
+            given_address = uri.split("https://")[1]
+        elif address is not None:
+            given_address = address
+        else:
+            given_address = None
+            logger.debug("Missing standard address type for reuse attempt")
+
+        # User defaults to empty string when getting connection info
+        if user is not None:
+            tmp_user = user
+        else:
+            tmp_user = ""
+
+        # If a valid address was given, then check if a connection exists
+        if given_address is not None:
+            for con in connections.list_connections():
+                addr = connections.get_connection_addr(con[0])
+                if (
+                    con[1]
+                    and ("address" in addr)
+                    and (addr["address"] == given_address)
+                    and ("user" in addr)
+                    and (addr["user"] == tmp_user)
+                ):
+                    logger.debug("Using previous connection: %s", con[0])
+                    return con[0]
+
+        # Generate a new connection if one doesn't exist
+        alias = uuid4().hex
+        try:
+            connections.connect(alias=alias, **connection_args)
+            logger.debug("Created new connection using: %s", alias)
+            return alias
+        except MilvusException as e:
+            logger.error("Failed to create new connection using: %s", alias)
+            raise e
+
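For reference, the constructor accepts the connection as host/port, as a uri, or as a pre-joined address, and it reuses an existing pymilvus connection when the address and user match. The dicts below are illustrative only; hosts and credentials are placeholders:

```python
# Illustrative connection_args forms; all values are placeholders.
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

by_host_port = {"host": "localhost", "port": "19530", "user": "", "password": ""}
by_uri = {"uri": "https://example-endpoint.zillizcloud.com:19530", "user": "db_user", "password": "..."}
by_address = {"address": "localhost:19530"}

store = Milvus(
    embedding_function=CustomEmbeddings(),
    collection_name="user123_assistant456",
    connection_args=by_host_port,  # a second instance with the same args reuses the connection
)
```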
+    def _init(
+        self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
+    ) -> None:
+        if embeddings is not None:
+            self._create_collection(embeddings, metadatas)
+        self._extract_fields()
+        self._create_index()
+        self._create_search_params()
+        self._load()
+
+    def _create_collection(
+        self, embeddings: list, metadatas: Optional[List[dict]] = None
+    ) -> None:
+        # Third Party
+        from pymilvus import (
+            Collection,
+            CollectionSchema,
+            DataType,
+            FieldSchema,
+            MilvusException,
+        )
+        from pymilvus.orm.types import infer_dtype_bydata
+
+        # Determine embedding dim
+        dim = len(embeddings[0])
+        fields = []
+        # Determine metadata schema
+        if metadatas:
+            # Create FieldSchema for each entry in metadata.
+            for key, value in metadatas[0].items():
+                # Infer the corresponding datatype of the metadata
+                dtype = infer_dtype_bydata(value)
+                # Datatype isn't compatible
+                if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
+                    logger.error(
+                        "Failure to create collection, unrecognized dtype for key: %s",
+                        key,
+                    )
+                    raise ValueError(f"Unrecognized datatype for {key}.")
+                # Datatype is a string/varchar equivalent
+                elif dtype == DataType.VARCHAR:
+                    fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
+                else:
+                    fields.append(FieldSchema(key, dtype))
+
+        # Create the text field
+        fields.append(
+            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
+        )
+        # Create the primary key field
+        fields.append(
+            FieldSchema(
+                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
+            )
+        )
+        # Create the vector field, supports binary or float vectors
+        fields.append(
+            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
+        )
+
+        # Create the schema for the collection
+        schema = CollectionSchema(fields)
+
+        # Create the collection
+        try:
+            self.col = Collection(
+                name=self.collection_name,
+                schema=schema,
+                consistency_level=self.consistency_level,
+                using=self.alias,
+            )
+        except MilvusException as e:
+            logger.error(
+                "Failed to create collection: %s error: %s", self.collection_name, e
+            )
+            raise e
+
+    def _extract_fields(self) -> None:
+        """Grab the existing fields from the Collection"""
+        # Third Party
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection):
+            schema = self.col.schema
+            for x in schema.fields:
+                self.fields.append(x.name)
+            # Since primary field is auto-id, no need to track it
+            self.fields.remove(self._primary_field)
+
+    def _get_index(self) -> Optional[Dict[str, Any]]:
+        """Return the vector index information if it exists"""
+        # Third Party
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection):
+            for x in self.col.indexes:
+                if x.field_name == self._vector_field:
+                    return x.to_dict()
+        return None
+
+    def _create_index(self) -> None:
+        """Create an index on the collection"""
+        # Third Party
+        from pymilvus import Collection, MilvusException
+
+        if isinstance(self.col, Collection) and self._get_index() is None:
+            try:
+                # If no index params, use a default HNSW based one
+                if self.index_params is None:
+                    self.index_params = {
+                        "metric_type": "L2",
+                        "index_type": "HNSW",
+                        "params": {"M": 8, "efConstruction": 64},
+                    }
+
+                try:
+                    self.col.create_index(
+                        self._vector_field,
+                        index_params=self.index_params,
+                        using=self.alias,
+                    )
+
+                # If the default did not work, we are most likely on Zilliz Cloud
+                except MilvusException:
+                    # Use an AUTOINDEX based index instead
+                    self.index_params = {
+                        "metric_type": "L2",
+                        "index_type": "AUTOINDEX",
+                        "params": {},
+                    }
+                    self.col.create_index(
+                        self._vector_field,
+                        index_params=self.index_params,
+                        using=self.alias,
+                    )
+                logger.debug(
+                    "Successfully created an index on collection: %s",
+                    self.collection_name,
+                )
+
+            except MilvusException as e:
+                logger.error(
+                    "Failed to create an index on collection: %s", self.collection_name
+                )
+                raise e
+
+    def _create_search_params(self) -> None:
+        """Generate search params based on the current index type"""
+        # Third Party
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection) and self.search_params is None:
+            index = self._get_index()
+            if index is not None:
+                index_type: str = index["index_param"]["index_type"]
+                metric_type: str = index["index_param"]["metric_type"]
+                self.search_params = self.default_search_params[index_type]
+                self.search_params["metric_type"] = metric_type
+
+    def _load(self) -> None:
+        """Load the collection if available."""
+        # Third Party
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection) and self._get_index() is not None:
+            self.col.load()
+
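When no `index_params` are given, the wrapper tries an HNSW index and falls back to AUTOINDEX (the path typically taken on Zilliz Cloud), and `search_params` are then derived from whichever index the collection ends up with. Both can be overridden at construction time, as `main.py` does; the IVF_FLAT/IP values below are illustrative, not a recommendation:

```python
# Sketch of explicit index/search parameters (values are illustrative).
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

store = Milvus(
    embedding_function=CustomEmbeddings(),
    collection_name="user123_assistant456",
    index_params={"metric_type": "IP", "index_type": "IVF_FLAT", "params": {"nlist": 1024}},
    search_params={"metric_type": "IP", "params": {"nprobe": 16}},
)
```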
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        timeout: Optional[int] = None,
+        batch_size: int = 1000,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Insert text data into Milvus.
+
+        Inserting data when the collection has not been made yet will result
+        in creating a new Collection. The data of the first entity decides
+        the schema of the new collection, the dim is extracted from the first
+        embedding and the columns are decided by the first metadata dict.
+        Metadata keys will need to be present for all inserted values. At
+        the moment there is no None equivalent in Milvus.
+
+        Args:
+            texts (Iterable[str]): The texts to embed, it is assumed
+                that they all fit in memory.
+            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
+                the texts. Defaults to None.
+            timeout (Optional[int]): Timeout for each batch insert. Defaults
+                to None.
+            batch_size (int, optional): Batch size to use for insertion.
+                Defaults to 1000.
+
+        Raises:
+            MilvusException: Failure to add texts
+
+        Returns:
+            List[str]: The resulting keys for each inserted element.
+        """
+        # Third Party
+        from pymilvus import Collection, MilvusException
+
+        texts = list(texts)
+
+        try:
+            embeddings = self.embedding_func.embed_documents(texts)
+        except NotImplementedError:
+            embeddings = [self.embedding_func.embed_query(x) for x in texts]
+
+        if len(embeddings) == 0:
+            logger.debug("Nothing to insert, skipping.")
+            return []
+        # If the collection hasn't been initialized yet, perform all steps to do so
+        if not isinstance(self.col, Collection):
+            self._init(embeddings, metadatas)
+
+        # Dict to hold all insert columns
+        insert_dict: Dict[str, list] = {
+            self._text_field: texts,
+            self._vector_field: embeddings,
+        }
+
+        # Collect the metadata into the insert dict.
+        if metadatas is not None:
+            for d in metadatas:
+                for key in self.fields:
+                    if key not in [self._text_field, self._vector_field]:
+                        # Use an empty string as the default if the key is not
+                        # present in the metadata dict.
+                        value = d.get(key, "")
+                        insert_dict.setdefault(key, []).append(value)
+
+        # Total insert count
+        vectors: list = insert_dict[self._vector_field]
+        total_count = len(vectors)
+
+        pks: List[str] = []
+
+        assert isinstance(self.col, Collection)
+        for i in range(0, total_count, batch_size):
+            # Grab end index
+            end = min(i + batch_size, total_count)
+            # Convert dict to list of lists batch for insertion
+            insert_list = [insert_dict[x][i:end] for x in self.fields]
+            # Insert into the collection.
+            try:
+                # Collection.insert returns a MutationResult holding the new primary keys
+                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
+                pks.extend(res.primary_keys)
+            except MilvusException as e:
+                logger.error(
+                    "Failed to insert batch starting at entity: %s/%s", i, total_count
+                )
+                raise e
+        return pks
+
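The first insert defines the collection schema: the vector dimension comes from the first embedding and one scalar column is created per key of the first metadata dict, so later inserts should carry the same keys (a missing key is filled with an empty string). A short sketch with hypothetical texts and metadata:

```python
# Hypothetical first insert; "source" and "chunk" become scalar columns in the schema.
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

store = Milvus(embedding_function=CustomEmbeddings(), collection_name="user123_assistant456")
pks = store.add_texts(
    texts=["Milvus stores vectors.", "FastAPI exposes the API."],
    metadatas=[{"source": "notes", "chunk": "a"}, {"source": "notes", "chunk": "b"}],
)
print(pks)  # auto-generated primary keys

# Later inserts should use the same metadata keys.
store.add_texts(texts=["Another chunk."], metadatas=[{"source": "web", "chunk": "c"}])
```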
+    def delete_entities(self, expr: str):
+        if self.col is None:
+            logger.debug("No existing collection to delete from.")
+            return
+        self.col.delete(expr=expr)
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 5,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a similarity search against the query string.
+
+        Args:
+            query (str): The text to search.
+            k (int, optional): How many results to return. Defaults to 5.
+            param (dict, optional): The search params for the index type.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+        res = self.similarity_search_with_score(
+            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return [doc for doc, _ in res]
+
+    def similarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a similarity search against an embedding vector.
+
+        Args:
+            embedding (List[float]): The embedding vector to search.
+            k (int, optional): How many results to return. Defaults to 4.
+            param (dict, optional): The search params for the index type.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+        res = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return [doc for doc, _ in res]
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Perform a search on a query string and return results with score.
+
+        For more information about the search parameters, take a look at the pymilvus
+        documentation found here:
+        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+        Args:
+            query (str): The text being searched.
+            k (int, optional): The amount of results to return. Defaults to 4.
+            param (dict): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Tuple[Document, float]]: Result doc and score.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        # Embed the query text.
+        embedding = self.embedding_func.embed_query(query)
+
+        res = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return res
+
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Perform a search on an embedding vector and return results with score.
+
+        For more information about the search parameters, take a look at the pymilvus
+        documentation found here:
+        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+        Args:
+            embedding (List[float]): The embedding vector being searched.
+            k (int, optional): The amount of results to return. Defaults to 4.
+            param (dict): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Tuple[Document, float]]: Result doc and score.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        if param is None:
+            param = self.search_params
+
+        # Determine result metadata fields.
+        output_fields = self.fields[:]
+        output_fields.remove(self._vector_field)
+
+        # Perform the search.
+        res = self.col.search(
+            data=[embedding],
+            anns_field=self._vector_field,
+            param=param,
+            limit=k,
+            expr=expr,
+            output_fields=output_fields,
+            timeout=timeout,
+            **kwargs,
+        )
+        # Organize results.
+        ret = []
+        for result in res[0]:
+            meta = {x: result.entity.get(x) for x in output_fields}
+            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+            pair = (doc, result.score)
+            ret.append(pair)
+
+        return ret
+
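Scores come back in the metric of the active index (IP in `main.py`'s configuration, L2 by default), and `expr` lets Milvus filter on scalar metadata columns before ranking. A sketch, assuming a collection that has a `source` metadata column as in the earlier example:

```python
# Hypothetical scored search with a metadata filter (assumes a "source" column exists).
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

store = Milvus(embedding_function=CustomEmbeddings(), collection_name="user123_assistant456")
for doc, score in store.similarity_search_with_score(
    "Where are vectors stored?", k=3, expr='source == "notes"'
):
    print(round(score, 4), doc.page_content, doc.metadata)
```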
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a search and return results that are reordered by MMR.
+
+        Args:
+            query (str): The text being searched.
+            k (int, optional): How many results to give. Defaults to 4.
+            fetch_k (int, optional): Total results to select k from.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            param (dict, optional): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        embedding = self.embedding_func.embed_query(query)
+
+        return self.max_marginal_relevance_search_by_vector(
+            embedding=embedding,
+            k=k,
+            fetch_k=fetch_k,
+            lambda_mult=lambda_mult,
+            param=param,
+            expr=expr,
+            timeout=timeout,
+            **kwargs,
+        )
+
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a search and return results that are reordered by MMR.
+
+        Args:
+            embedding (List[float]): The embedding vector being searched.
+            k (int, optional): How many results to give. Defaults to 4.
+            fetch_k (int, optional): Total results to select k from.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            param (dict, optional): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        if param is None:
+            param = self.search_params
+
+        # Determine result metadata fields.
+        output_fields = self.fields[:]
+        output_fields.remove(self._vector_field)
+
+        # Perform the search.
+        res = self.col.search(
+            data=[embedding],
+            anns_field=self._vector_field,
+            param=param,
+            limit=fetch_k,
+            expr=expr,
+            output_fields=output_fields,
+            timeout=timeout,
+            **kwargs,
+        )
+        # Organize results.
+        ids = []
+        documents = []
+        scores = []
+        for result in res[0]:
+            meta = {x: result.entity.get(x) for x in output_fields}
+            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+            documents.append(doc)
+            scores.append(result.score)
+            ids.append(result.id)
+
+        vectors = self.col.query(
+            expr=f"{self._primary_field} in {ids}",
+            output_fields=[self._primary_field, self._vector_field],
+            timeout=timeout,
+        )
+        # Reorganize the results from query to match search order.
+        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
+
+        ordered_result_embeddings = [vectors[x] for x in ids]
+
+        # Get the new order of results.
+        new_ordering = maximal_marginal_relevance(
+            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
+        )
+
+        # Reorder the values and return.
+        ret = []
+        for x in new_ordering:
+            # Function can return -1 index
+            if x == -1:
+                break
+            else:
+                ret.append(documents[x])
+        return ret
+
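MMR first pulls `fetch_k` nearest candidates, re-queries their stored vectors, and then keeps `k` of them while trading relevance against diversity through `lambda_mult`. A sketch with illustrative values:

```python
# Hypothetical MMR query: fetch 20 candidates, keep 4 diverse-but-relevant ones.
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

store = Milvus(embedding_function=CustomEmbeddings(), collection_name="user123_assistant456")
docs = store.max_marginal_relevance_search(
    "Where are vectors stored?",
    k=4,
    fetch_k=20,
    lambda_mult=0.5,  # 0 favors diversity, 1 favors raw relevance
)
for doc in docs:
    print(doc.page_content)
```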
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = "LangChainCollection",
+        connection_args: Dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
+        consistency_level: str = "Session",
+        index_params: Optional[dict] = None,
+        search_params: Optional[dict] = None,
+        drop_old: bool = False,
+        **kwargs: Any,
+    ):
+        """Create a Milvus collection, index it with HNSW, and insert data.
+
+        Args:
+            texts (List[str]): Text data.
+            embedding (Embeddings): Embedding function.
+            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
+                Defaults to None.
+            collection_name (str, optional): Collection name to use. Defaults to
+                "LangChainCollection".
+            connection_args (dict[str, Any], optional): Connection args to use. Defaults
+                to DEFAULT_MILVUS_CONNECTION.
+            consistency_level (str, optional): Which consistency level to use. Defaults
+                to "Session".
+            index_params (Optional[dict], optional): Which index_params to use. Defaults
+                to None.
+            search_params (Optional[dict], optional): Which search params to use.
+                Defaults to None.
+            drop_old (Optional[bool], optional): Whether to drop the collection with
+                that name if it exists. Defaults to False.
+
+        Returns:
+            Milvus: Milvus Vector Store
+        """
+        vector_db = cls(
+            embedding_function=embedding,
+            collection_name=collection_name,
+            connection_args=connection_args,
+            consistency_level=consistency_level,
+            index_params=index_params,
+            search_params=search_params,
+            drop_old=drop_old,
+            **kwargs,
+        )
+        vector_db.add_texts(texts=texts, metadatas=metadatas)
+        return vector_db
diff --git a/services/backend/vector_db_api/requirements.txt b/services/backend/vector_db_api/requirements.txt
new file mode 100644
index 0000000..eb46876
--- /dev/null
+++ b/services/backend/vector_db_api/requirements.txt
@@ -0,0 +1,8 @@
+fastapi==0.95.1
+langchain==0.0.187
+openai==0.27.7
+pymilvus==2.2.8
+python-dotenv==1.0.0
+tiktoken==0.4.0
+uvicorn==0.21.1
+python-multipart
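`from_texts` bundles construction and the first insert into a single call, which is the quickest way to stand up a scratch collection for experiments. The collection name and texts below are placeholders, and a running Milvus plus the embedding service from this patch are assumed:

```python
# Hypothetical one-shot setup: create (or replace) a collection and insert in one call.
from milvus.CustomEmbeddings import CustomEmbeddings
from milvus.query_milvus import Milvus

store = Milvus.from_texts(
    texts=["Milvus stores vectors.", "FastAPI exposes the API."],
    embedding=CustomEmbeddings(),
    metadatas=[{"source": "notes"}, {"source": "notes"}],
    collection_name="scratch_collection",
    drop_old=True,  # start fresh if a collection with this name already exists
)
print(store.similarity_search("Where are vectors stored?", k=1))
```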