From 8eb3908e870108affb7d78872a0e64d481deb0ef Mon Sep 17 00:00:00 2001
From: Kye Gomez
Date: Mon, 8 Jul 2024 16:23:37 -0700
Subject: [PATCH] [FEAT][Faiss]

---
 README.md                             |  84 ++++++
 chroma_db.py => examples/chroma_db.py |   0
 examples/faiss_example.py             |  78 ++++++
 .../pinecome_wrapper_example.py       |   0
 swarms_memory/__init__.py             |   3 +-
 swarms_memory/chroma_db_wrapper.py    |   1 +
 swarms_memory/faiss_wrapper.py        | 262 ++++++++++++++++++
 swarms_memory/pinecone_wrapper.py     |   1 +
 tests/test_chromadb.py                |   4 +-
 9 files changed, 430 insertions(+), 3 deletions(-)
 rename chroma_db.py => examples/chroma_db.py (100%)
 create mode 100644 examples/faiss_example.py
 rename pinecome_wrapper_example.py => examples/pinecome_wrapper_example.py (100%)
 create mode 100644 swarms_memory/faiss_wrapper.py

diff --git a/README.md b/README.md
index 6abc459..f632af2 100644
--- a/README.md
+++ b/README.md
@@ -177,6 +177,90 @@ print(result)
 ```
 
+### Faiss
+
+```python
+from typing import List, Dict, Any
+from swarms_memory.faiss_wrapper import FAISSDB
+
+from transformers import AutoTokenizer, AutoModel
+import torch
+
+# Load the HuggingFace tokenizer and model once, rather than on every
+# embedding call.
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+model = AutoModel.from_pretrained("bert-base-uncased")
+
+
+# Custom embedding function using the HuggingFace model
+def custom_embedding_function(text: str) -> List[float]:
+    inputs = tokenizer(
+        text,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+    )
+    with torch.no_grad():
+        outputs = model(**inputs)
+    embeddings = (
+        outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
+    )
+    return embeddings
+
+
+# Custom preprocessing function
+def custom_preprocess(text: str) -> str:
+    return text.lower().strip()
+
+
+# Custom postprocessing function
+def custom_postprocess(
+    results: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+    for result in results:
+        result["custom_score"] = (
+            result["score"] * 2
+        )  # Example modification
+    return results
+
+
+# Initialize the wrapper with custom functions
+wrapper = FAISSDB(
+    dimension=768,
+    index_type="Flat",
+    embedding_function=custom_embedding_function,
+    preprocess_function=custom_preprocess,
+    postprocess_function=custom_postprocess,
+    metric="cosine",
+    logger_config={
+        "handlers": [
+            {
+                "sink": "custom_faiss_rag_wrapper.log",
+                "rotation": "1 GB",
+            },
+            {"sink": lambda msg: print(f"Custom log: {msg}", end="")},
+        ],
+    },
+)
+
+# Adding documents
+wrapper.add(
+    "This is a sample document about artificial intelligence.",
+    {"category": "AI"},
+)
+wrapper.add(
+    "Python is a popular programming language for data science.",
+    {"category": "Programming"},
+)
+
+# Querying
+results = wrapper.query("What is AI?")
+for result in results:
+    print(
+        f"Score: {result['score']}, Custom Score: {result['custom_score']}, Text: {result['metadata']['text']}"
+    )
+```
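+
+The defaults fall back to a local SentenceTransformer model
+(`all-MiniLM-L6-v2`), which produces 384-dimensional embeddings, so pass
+`dimension=384` when you rely on the default embedding function.
+
+`FAISSDB` keeps the index and metadata in memory. If you need
+persistence, the raw FAISS index and the metadata list are exposed as
+`wrapper.index` and `wrapper.documents`, so you can save and restore
+them yourself. A minimal sketch, assuming you keep the two files in
+sync (the file names here are arbitrary):
+
+```python
+import json
+
+import faiss
+
+# Persist the FAISS index and the document metadata side by side.
+faiss.write_index(wrapper.index, "faiss_rag.index")
+with open("faiss_rag_docs.json", "w") as f:
+    json.dump(wrapper.documents, f)
+
+# Later: reload both and reattach them to a fresh wrapper instance.
+wrapper.index = faiss.read_index("faiss_rag.index")
+with open("faiss_rag_docs.json") as f:
+    wrapper.documents = json.load(f)
+```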
+
+
 # License
 
 MIT
diff --git a/chroma_db.py b/examples/chroma_db.py
similarity index 100%
rename from chroma_db.py
rename to examples/chroma_db.py
diff --git a/examples/faiss_example.py b/examples/faiss_example.py
new file mode 100644
index 0000000..7eb4d3d
--- /dev/null
+++ b/examples/faiss_example.py
@@ -0,0 +1,78 @@
+from typing import List, Dict, Any
+from swarms_memory.faiss_wrapper import FAISSDB
+
+from transformers import AutoTokenizer, AutoModel
+import torch
+
+# Load the HuggingFace tokenizer and model once at module import,
+# rather than on every embedding call.
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+model = AutoModel.from_pretrained("bert-base-uncased")
+
+
+# Custom embedding function using the HuggingFace model
+def custom_embedding_function(text: str) -> List[float]:
+    inputs = tokenizer(
+        text,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+    )
+    with torch.no_grad():
+        outputs = model(**inputs)
+    embeddings = (
+        outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
+    )
+    return embeddings
+
+
+# Custom preprocessing function
+def custom_preprocess(text: str) -> str:
+    return text.lower().strip()
+
+
+# Custom postprocessing function
+def custom_postprocess(
+    results: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+    for result in results:
+        result["custom_score"] = (
+            result["score"] * 2
+        )  # Example modification
+    return results
+
+
+# Initialize the wrapper with custom functions
+wrapper = FAISSDB(
+    dimension=768,
+    index_type="Flat",
+    embedding_function=custom_embedding_function,
+    preprocess_function=custom_preprocess,
+    postprocess_function=custom_postprocess,
+    metric="cosine",
+    logger_config={
+        "handlers": [
+            {
+                "sink": "custom_faiss_rag_wrapper.log",
+                "rotation": "1 GB",
+            },
+            {"sink": lambda msg: print(f"Custom log: {msg}", end="")},
+        ],
+    },
+)
+
+# Adding documents
+wrapper.add(
+    "This is a sample document about artificial intelligence.",
+    {"category": "AI"},
+)
+wrapper.add(
+    "Python is a popular programming language for data science.",
+    {"category": "Programming"},
+)
+
+# Querying
+results = wrapper.query("What is AI?")
+for result in results:
+    print(
+        f"Score: {result['score']}, Custom Score: {result['custom_score']}, Text: {result['metadata']['text']}"
+    )
diff --git a/pinecome_wrapper_example.py b/examples/pinecome_wrapper_example.py
similarity index 100%
rename from pinecome_wrapper_example.py
rename to examples/pinecome_wrapper_example.py
diff --git a/swarms_memory/__init__.py b/swarms_memory/__init__.py
index a018703..6e917b2 100644
--- a/swarms_memory/__init__.py
+++ b/swarms_memory/__init__.py
@@ -1,4 +1,5 @@
 from swarms_memory.chroma_db_wrapper import ChromaDB
 from swarms_memory.pinecone_wrapper import PineconeMemory
+from swarms_memory.faiss_wrapper import FAISSDB
 
-__all__ = ["ChromaDB", "PineconeMemory"]
+__all__ = ["ChromaDB", "PineconeMemory", "FAISSDB"]
diff --git a/swarms_memory/chroma_db_wrapper.py b/swarms_memory/chroma_db_wrapper.py
index bfd9d68..c5b85bf 100644
--- a/swarms_memory/chroma_db_wrapper.py
+++ b/swarms_memory/chroma_db_wrapper.py
@@ -51,6 +51,7 @@ def __init__(
         *args,
         **kwargs,
     ):
+        super().__init__(*args, **kwargs)
         self.metric = metric
         self.output_dir = output_dir
         self.limit_tokens = limit_tokens
+ """ + + def __init__( + self, + dimension: int = 768, + index_type: str = "Flat", + embedding_model: Optional[Any] = None, + embedding_function: Optional[ + Callable[[str], List[float]] + ] = None, + preprocess_function: Optional[Callable[[str], str]] = None, + postprocess_function: Optional[ + Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]] + ] = None, + metric: str = "cosine", + logger_config: Optional[Dict[str, Any]] = None, + *args, + **kwargs, + ): + """ + Initialize the FAISSDB. + + Args: + dimension (int): Dimension of the document embeddings. Defaults to 768. + index_type (str): Type of FAISS index to use. Defaults to 'Flat'. + embedding_model (Optional[Any]): Custom embedding model. Defaults to None. + embedding_function (Optional[Callable]): Custom embedding function. Defaults to None. + preprocess_function (Optional[Callable]): Custom preprocessing function. Defaults to None. + postprocess_function (Optional[Callable]): Custom postprocessing function. Defaults to None. + metric (str): Distance metric for FAISS index. Defaults to 'cosine'. + logger_config (Optional[Dict]): Configuration for the logger. Defaults to None. + """ + super().__init__(*args, **kwargs) + self._setup_logger(logger_config) + logger.info("Initializing FAISSDB") + + self.dimension = dimension + self.index = self._create_index(index_type, metric) + self.documents = [] + + self.embedding_model = embedding_model or SentenceTransformer( + "all-MiniLM-L6-v2" + ) + self.embedding_function = ( + embedding_function or self._default_embedding_function + ) + self.preprocess_function = ( + preprocess_function or self._default_preprocess_function + ) + self.postprocess_function = ( + postprocess_function or self._default_postprocess_function + ) + + logger.info("FAISSDB initialized successfully") + + def _setup_logger(self, config: Optional[Dict[str, Any]] = None): + """Set up the logger with the given configuration.""" + default_config = { + "handlers": [ + { + "sink": "faiss_rag_wrapper.log", + "rotation": "500 MB", + }, + {"sink": lambda msg: print(msg, end="")}, + ], + } + logger.configure(**(config or default_config)) + + def _create_index(self, index_type: str, metric: str): + """Create and return a FAISS index based on the specified type and metric.""" + if metric == "cosine": + index = faiss.IndexFlatIP(self.dimension) + elif metric == "l2": + index = faiss.IndexFlatL2(self.dimension) + else: + raise ValueError(f"Unsupported metric: {metric}") + + if index_type == "Flat": + return index + elif index_type == "IVF": + nlist = 100 # number of clusters + quantizer = faiss.IndexFlatL2(self.dimension) + index = faiss.IndexIVFFlat( + quantizer, self.dimension, nlist + ) + else: + raise ValueError(f"Unsupported index type: {index_type}") + + return index + + def _default_embedding_function(self, text: str) -> List[float]: + """Default embedding function using the SentenceTransformer model.""" + return self.embedding_model.encode(text).tolist() + + def _default_preprocess_function(self, text: str) -> str: + """Default preprocessing function.""" + return text.strip() + + def _default_postprocess_function( + self, results: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Default postprocessing function.""" + return results + + def add( + self, doc: str, metadata: Optional[Dict[str, Any]] = None + ) -> None: + """ + Add a document to the FAISS index. + + Args: + doc (str): The document to be added. + metadata (Optional[Dict[str, Any]]): Additional metadata for the document. 
+
+    def add(
+        self, doc: str, metadata: Optional[Dict[str, Any]] = None
+    ) -> None:
+        """
+        Add a document to the FAISS index.
+
+        Args:
+            doc (str): The document to be added.
+            metadata (Optional[Dict[str, Any]]): Additional metadata for the document.
+
+        Returns:
+            None
+        """
+        logger.info(f"Adding document: {doc[:50]}...")
+        processed_doc = self.preprocess_function(doc)
+        embedding = self.embedding_function(processed_doc)
+
+        vector = np.array([embedding], dtype=np.float32)
+        if self.metric == "cosine":
+            # Normalize in place so the inner product equals cosine similarity.
+            faiss.normalize_L2(vector)
+        self.index.add(vector)
+        metadata = metadata or {}
+        metadata["text"] = processed_doc
+        self.documents.append(metadata)
+
+        logger.success(
+            f"Document added successfully. Total documents: {len(self.documents)}"
+        )
+
+    def query(
+        self, query: str, top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        """
+        Query the FAISS index for similar documents.
+
+        Args:
+            query (str): The query string.
+            top_k (int): The number of top results to return. Defaults to 5.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the top_k most similar documents.
+        """
+        logger.info(f"Querying with: {query}")
+        processed_query = self.preprocess_function(query)
+        query_embedding = self.embedding_function(processed_query)
+
+        query_vector = np.array([query_embedding], dtype=np.float32)
+        if self.metric == "cosine":
+            faiss.normalize_L2(query_vector)
+        distances, indices = self.index.search(query_vector, top_k)
+
+        results = []
+        for distance, idx in zip(distances[0], indices[0]):
+            if idx == -1:  # FAISS uses -1 to pad missing results
+                continue
+            if self.metric == "cosine":
+                score = float(distance)  # already a similarity
+            else:
+                score = 1.0 / (1.0 + float(distance))  # distance -> similarity
+            results.append(
+                {
+                    "id": int(idx),
+                    "score": score,
+                    "metadata": self.documents[idx],
+                }
+            )
+
+        processed_results = self.postprocess_function(results)
+        logger.success(
+            f"Query completed. Found {len(processed_results)} results."
+        )
+        return processed_results
+
+
+# For a runnable example with custom embedding, preprocessing, and
+# postprocessing functions, see examples/faiss_example.py.
diff --git a/swarms_memory/pinecone_wrapper.py b/swarms_memory/pinecone_wrapper.py
index a0f2dbb..e2f49c3
100644
--- a/swarms_memory/pinecone_wrapper.py
+++ b/swarms_memory/pinecone_wrapper.py
@@ -51,6 +51,7 @@ def __init__(
             namespace (str): Pinecone namespace. Defaults to ''.
             logger_config (Optional[Dict]): Configuration for the logger. Defaults to None.
         """
+        super().__init__()
         self._setup_logger(logger_config)
         logger.info("Initializing PineconeMemory")
 
diff --git a/tests/test_chromadb.py b/tests/test_chromadb.py
index 4e43559..7504305 100644
--- a/tests/test_chromadb.py
+++ b/tests/test_chromadb.py
@@ -17,8 +17,8 @@ def test_init(mock_client, mock_persistent_client):
     assert chroma_db.output_dir == "swarms"
     assert chroma_db.limit_tokens == 1000
     assert chroma_db.n_results == 1
-    assert chroma_db.docs_folder == None
-    assert chroma_db.verbose == False
+    assert chroma_db.docs_folder is None
+    assert chroma_db.verbose is False
     mock_persistent_client.assert_called_once()
     mock_client.assert_called_once()