From 589bda702168c3fd40f0152391044c52689c7a91 Mon Sep 17 00:00:00 2001 From: Radek Dymacz Date: Mon, 9 Dec 2024 23:50:04 +0100 Subject: [PATCH] singlestore support --- .env.examples | 19 +- README.md | 76 ++++- examples/singlestore_wrapper_example.py | 60 ++++ pyproject.toml | 1 + requirements.txt | 3 +- swarms_memory/memory/singlestore.md | 208 ++++++++++++ swarms_memory/vector_dbs/__init__.py | 2 + .../vector_dbs/singlestore_wrapper.py | 271 ++++++++++++++++ tests/vector_dbs/test_singlestore.py | 299 ++++++++++++++++++ 9 files changed, 933 insertions(+), 6 deletions(-) create mode 100644 examples/singlestore_wrapper_example.py create mode 100644 swarms_memory/memory/singlestore.md create mode 100644 swarms_memory/vector_dbs/singlestore_wrapper.py create mode 100644 tests/vector_dbs/test_singlestore.py diff --git a/.env.examples b/.env.examples index 9d8ad7c..967d9b5 100644 --- a/.env.examples +++ b/.env.examples @@ -1,2 +1,19 @@ +# Pinecone Configuration PINECONE_API_KEYS="your_pinecone_api_key" -BASE_SWARMS_MEMORY_URL="http:" \ No newline at end of file + +# Base URL Configuration +BASE_SWARMS_MEMORY_URL="http:" + +# SingleStore Configuration +# Host can be localhost for local development or your SingleStore deployment URL +SINGLESTORE_HOST="your_singlestore_host" # e.g., "localhost" or "svc-123-xyz.aws.singlestore.com" + +# Default port is 3306, but might be different for your deployment +SINGLESTORE_PORT="3306" + +# Your SingleStore user credentials +SINGLESTORE_USER="your_singlestore_username" # e.g., "admin" +SINGLESTORE_PASSWORD="your_singlestore_password" + +# Database name where vector tables will be created +SINGLESTORE_DATABASE="your_database_name" # e.g., "vector_store" \ No newline at end of file diff --git a/README.md b/README.md index 9ea76d1..6d2d548 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ -

Swarms Memory

@@ -39,8 +38,9 @@ Here's a more detailed and larger table with descriptions and website links for
 | **ChromaDB** | Available | A high-performance, distributed database optimized for handling large-scale AI tasks. | [ChromaDB Documentation](swarms_memory/memory/chromadb.md) | [ChromaDB](https://chromadb.com) |
 | **Pinecone** | Available | A fully managed vector database that makes it easy to add vector search to your applications. | [Pinecone Documentation](swarms_memory/memory/pinecone.md) | [Pinecone](https://pinecone.io) |
 | **Redis** | Coming Soon | An open-source, in-memory data structure store, used as a database, cache, and message broker. | [Redis Documentation](swarms_memory/memory/redis.md) | [Redis](https://redis.io) |
-| **Faiss** | Coming Soon | A library for efficient similarity search and clustering of dense vectors, developed by Facebook AI. | [Faiss Documentation](swarms_memory/memory/faiss.md) | [Faiss](https://faiss.ai) |
-| **HNSW** | Coming Soon | A graph-based algorithm for approximate nearest neighbor search, known for its speed and accuracy. | [HNSW Documentation](swarms_memory/memory/hnsw.md) | [HNSW](https://github.com/nmslib/hnswlib) |
+| **Faiss** | Available | A library for efficient similarity search and clustering of dense vectors, developed by Facebook AI. | [Faiss Documentation](swarms_memory/memory/faiss.md) | [Faiss](https://faiss.ai) |
+| **SingleStore** | Available | A distributed SQL database that provides high-performance vector similarity search. | [SingleStore Documentation](swarms_memory/memory/singlestore.md) | [SingleStore](https://www.singlestore.com) |
+| **HNSW** | Coming Soon | A graph-based algorithm for approximate nearest neighbor search, known for its speed and accuracy. | [HNSW Documentation](swarms_memory/memory/hnsw.md) | [HNSW](https://github.com/nmslib/hnswlib) |
 
 This table includes a brief description of each system, their current status, links to their documentation, and their respective websites for further information.
@@ -259,6 +259,75 @@ for result in results:
 ```
+
+### SingleStore
+```python
+from typing import List
+
+from swarms_memory.vector_dbs.singlestore_wrapper import SingleStoreDB
+
+# Initialize SingleStore with environment variables
+db = SingleStoreDB(
+    host="your_host",
+    port=3306,
+    user="your_user",
+    password="your_password",
+    database="your_database",
+    table_name="example_vectors",
+    dimension=384,  # the default model, all-MiniLM-L6-v2, outputs 384-dim vectors
+    namespace="example"
+)
+
+# Custom embedding function example (optional)
+def custom_embedding_function(text: str) -> List[float]:
+    # Your custom embedding logic here; return exactly `dimension` floats,
+    # e.g. embeddings = my_model.encode(text).tolist()
+    return embeddings
+
+# Initialize with custom functions
+db = SingleStoreDB(
+    host="your_host",
+    port=3306,
+    user="your_user",
+    password="your_password",
+    database="your_database",
+    table_name="example_vectors",
+    dimension=768,  # must match the output size of custom_embedding_function
+    namespace="example",
+    embedding_function=custom_embedding_function,
+    preprocess_function=lambda x: x.lower(),  # Simple preprocessing
+    postprocess_function=lambda x: sorted(x, key=lambda k: k['similarity'], reverse=True)  # Sort results by similarity
+)
+
+# Add documents with metadata
+doc_id = db.add(
+    document="SingleStore is a distributed SQL database that combines horizontal scalability with ACID guarantees.",
+    metadata={"source": "docs", "category": "database"}
+)
+
+# Query similar documents
+results = db.query(
+    query="How does SingleStore scale?",
+    top_k=3,
+    metadata_filter={"source": "docs"}
+)
+
+# Process results
+for result in results:
+    print(f"Document: {result['document']}")
+    print(f"Similarity: {result['similarity']:.4f}")
+    print(f"Metadata: {result['metadata']}\n")
+
+# Delete a document
+db.delete(doc_id)
+```
+
+Key features:
+- Built on SingleStore's native vector similarity search
+- Supports custom embedding models and functions
+- Automatic table creation with optimized vector indexing
+- Metadata filtering for refined searches
+- Document preprocessing and postprocessing
+- Namespace support for document organization
+- SSL support for secure connections
+
+For more examples, see the [SingleStore example](examples/singlestore_wrapper_example.py).
 
 # License
 MIT
@@ -275,4 +344,3 @@ Please cite Swarms in your paper or your project if you found it beneficial in any way! Appreciate you.
   a note = {Accessed: Date}
 }
 ```
-
diff --git a/examples/singlestore_wrapper_example.py b/examples/singlestore_wrapper_example.py
new file mode 100644
index 0000000..8f59e54
--- /dev/null
+++ b/examples/singlestore_wrapper_example.py
@@ -0,0 +1,60 @@
+import os
+from dotenv import load_dotenv
+from swarms_memory.vector_dbs.singlestore_wrapper import SingleStoreDB
+
+# Load environment variables
+load_dotenv()
+
+def main():
+    # Initialize SingleStore with environment variables
+    db = SingleStoreDB(
+        host=os.getenv("SINGLESTORE_HOST"),
+        port=int(os.getenv("SINGLESTORE_PORT", "3306")),
+        user=os.getenv("SINGLESTORE_USER"),
+        password=os.getenv("SINGLESTORE_PASSWORD"),
+        database=os.getenv("SINGLESTORE_DATABASE"),
+        table_name="example_vectors",
+        dimension=384,  # the default model, all-MiniLM-L6-v2, outputs 384-dim vectors
+        namespace="example"
+    )
+
+    # Example documents
+    documents = [
+        "SingleStore is a distributed SQL database that combines the horizontal scalability of NoSQL systems with the ACID guarantees of traditional RDBMSs.",
+        "Vector similarity search in SingleStore uses DOT_PRODUCT distance type for efficient nearest neighbor queries.",
+        "SingleStore supports both row and column store formats, making it suitable for both transactional and analytical workloads."
+    ]
+
+    # Add documents to the database
+    doc_ids = []
+    for doc in documents:
+        doc_id = db.add(
+            document=doc,
+            metadata={"source": "example", "type": "documentation"}
+        )
+        doc_ids.append(doc_id)
+        print(f"Added document with ID: {doc_id}")
+
+    # Query similar documents
+    query = "How does SingleStore handle vector similarity search?"
+    results = db.query(
+        query=query,
+        top_k=2,
+        metadata_filter={"source": "example"}
+    )
+
+    print("\nQuery:", query)
+    print("\nResults:")
+    for result in results:
+        print(f"\nDocument: {result['document']}")
+        print(f"Similarity: {result['similarity']:.4f}")
+        print(f"Metadata: {result['metadata']}")
+
+    # Clean up - delete documents
+    print("\nCleaning up...")
+    for doc_id in doc_ids:
+        db.delete(doc_id)
+        print(f"Deleted document with ID: {doc_id}")
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index bba2587..26dd83c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ pinecone = "*"
 faiss-cpu = "*"
 pydantic = "*"
 sqlalchemy = "*"
+singlestoredb = "*"
diff --git a/requirements.txt b/requirements.txt
index d17103d..4a1572f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ pinecone
 faiss-cpu
 torch
 pydantic
-sqlalchemy
\ No newline at end of file
+sqlalchemy
+singlestoredb
\ No newline at end of file
diff --git a/swarms_memory/memory/singlestore.md b/swarms_memory/memory/singlestore.md
new file mode 100644
index 0000000..7ace32c
--- /dev/null
+++ b/swarms_memory/memory/singlestore.md
@@ -0,0 +1,208 @@
+# SingleStore Vector Database
+
+SingleStore is a distributed SQL database that provides high-performance vector similarity search capabilities. This implementation uses the official SingleStore Python library to provide efficient vector storage and similarity search for your RAG (Retrieval-Augmented Generation) system.
+
+## Features
+
+- High-performance vector similarity search using SingleStore's native vector operations
+- Automatic vector indexing for fast similarity search
+- Support for custom embedding models and functions
+- Document preprocessing and postprocessing capabilities
+- Namespace support for document organization
+- Operation-level logging via `loguru`
+- Built-in connection management using the official SingleStore Python client
+
+## Installation
+
+```bash
+pip install singlestoredb sentence-transformers numpy
+```
+
+## Configuration
+
+Set up your SingleStore credentials in your environment variables:
+
+```bash
+SINGLESTORE_HOST=your_host
+SINGLESTORE_PORT=3306
+SINGLESTORE_USER=your_user
+SINGLESTORE_PASSWORD=your_password
+SINGLESTORE_DATABASE=your_database
+```
+
+## Usage
+
+### Basic Usage
+
+```python
+import os
+
+from swarms_memory.vector_dbs import SingleStoreDB
+
+# Initialize the database (host, user, password, database, and table_name are required)
+db = SingleStoreDB(
+    host=os.getenv("SINGLESTORE_HOST"),
+    user=os.getenv("SINGLESTORE_USER"),
+    password=os.getenv("SINGLESTORE_PASSWORD"),
+    database="vectordb",
+    table_name="embeddings",
+    dimension=384  # must match the embedding model; the default all-MiniLM-L6-v2 outputs 384-dim vectors
+)
+
+# Add a document
+doc_id = db.add("This is a sample document")
+
+# Query similar documents
+results = db.query("Find similar documents", top_k=3)
+
+# Delete a document
+db.delete(doc_id)
+```
+
+### Advanced Usage
+
+For brevity, the connection arguments (`host`, `user`, `password`, `table_name`) shown above are omitted from the following snippets.
+
+#### Custom Embedding Model
+
+```python
+from sentence_transformers import SentenceTransformer
+
+# Use a different embedding model
+db = SingleStoreDB(
+    database="vectordb",
+    embedding_model=SentenceTransformer("all-mpnet-base-v2"),
+    dimension=768
+)
+```
+
+#### Custom Embedding Function
+
+```python
+from typing import List
+
+def custom_embedding_function(text: str) -> List[float]:
+    # Your custom embedding logic here
+    return [0.1, 0.2, ...]  # Must match dimension
+
+db = SingleStoreDB(
+    database="vectordb",
+    embedding_function=custom_embedding_function,
+    dimension=768
+)
+```
+
+#### Document Preprocessing
+
+```python
+def preprocess_text(text: str) -> str:
+    # Custom preprocessing logic
+    return text.lower().strip()
+
+db = SingleStoreDB(
+    database="vectordb",
+    preprocess_function=preprocess_text
+)
+```
+
+#### Result Postprocessing
+
+```python
+from typing import Dict, List
+
+def postprocess_results(results: List[Dict]) -> List[Dict]:
+    # Custom postprocessing logic, applied to the full result list
+    return sorted(results, key=lambda x: x["similarity"], reverse=True)
+
+db = SingleStoreDB(
+    database="vectordb",
+    postprocess_function=postprocess_results
+)
+```
+
+#### Using Namespaces
+
+A namespace is fixed per `SingleStoreDB` instance; `add`, `query`, and `delete` all operate within it. Use a separate instance for each namespace:
+
+```python
+# Initialize with a default namespace
+db = SingleStoreDB(
+    database="vectordb",
+    namespace="project1"
+)
+
+# Add document to the namespace
+db.add("Document in project1")
+
+# Query within the namespace
+results = db.query("Query in project1")
+
+# Query a different namespace via a second instance
+db2 = SingleStoreDB(
+    database="vectordb",
+    namespace="project2"
+)
+results = db2.query("Query in project2")
+```
+
+## API Reference
+
+### SingleStoreDB Class
+
+```python
+class SingleStoreDB:
+    def __init__(
+        self,
+        host: str,
+        user: str,
+        password: str,
+        database: str,
+        table_name: str,
+        dimension: int = 768,
+        port: int = 3306,
+        ssl: bool = True,
+        ssl_verify: bool = True,
+        embedding_model: Optional[Any] = None,
+        embedding_function: Optional[Callable] = None,
+        preprocess_function: Optional[Callable] = None,
+        postprocess_function: Optional[Callable] = None,
+        namespace: str = "",
+        logger_config: Optional[Dict] = None
+    )
+```
+
+#### Methods
+
+- `add(document: str, metadata: Optional[Dict] = None) -> str`
+  Add a document to the database.
+
+- `query(query: str, top_k: int = 5, metadata_filter: Optional[Dict] = None) -> List[Dict]`
+  Query similar documents, optionally filtered by metadata key/value pairs.
+
+- `delete(doc_id: str) -> bool`
+  Delete a document from the database (within the instance's namespace).
+
+## Performance Optimization
+
+### Vector Indexing
+
+The implementation automatically creates a vector index on the embedding column:
+
+```sql
+VECTOR INDEX vec_idx (embedding) DIMENSION = {dimension} DISTANCE_TYPE = DOT_PRODUCT
+```
+
+This index significantly improves the performance of similarity search queries.
+
+### Connection Management
+
+The implementation uses the official SingleStore Python client with proper connection management:
+- A connection is opened per operation and closed automatically via context managers
+- Context managers for cursor operations
+- Proper cleanup of resources
+
+### Query Optimization
+
+- Uses SingleStore's native `DOT_PRODUCT` operation for similarity search
+- Stores embeddings in a BLOB column covered by the vector index
+- Secondary index on `namespace` for fast lookups and filtering
+
+## Error Handling and Logging
+
+Driver and SQL errors are propagated to the caller rather than silently swallowed, so connection issues, query failures, and authentication problems surface immediately. All operations are logged using the `loguru` logger for easy debugging, including:
+- Initialization and table creation
+- Document inserts and deletions
+- Query execution
+
+## Contributing
+
+We welcome contributions! Please feel free to submit a Pull Request.
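+
+## Appendix: Normalizing Embeddings for DOT_PRODUCT
+
+Because similarity is computed with `DOT_PRODUCT`, raw scores are unbounded and depend on vector magnitude. Below is a minimal sketch (assuming the bundled `sentence-transformers` dependency; `normalized_embedding` and the connection values are illustrative, not part of the wrapper's API) that normalizes vectors to unit length, which makes `DOT_PRODUCT` equivalent to cosine similarity and keeps scores in `[-1, 1]`:
+
+```python
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from swarms_memory.vector_dbs import SingleStoreDB
+
+model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim vectors
+
+def normalized_embedding(text: str) -> np.ndarray:
+    # Unit-length vectors make DOT_PRODUCT equal to cosine similarity.
+    vec = model.encode(text)
+    return vec / np.linalg.norm(vec)
+
+db = SingleStoreDB(
+    host="localhost",          # placeholder connection settings
+    user="admin",
+    password="password",
+    database="vectordb",
+    table_name="embeddings",
+    dimension=384,
+    embedding_function=normalized_embedding,
+)
+```
+
+With normalized vectors, a `similarity` of 1.0 means an exact directional match, so score thresholds stay comparable across queries.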
diff --git a/swarms_memory/vector_dbs/__init__.py b/swarms_memory/vector_dbs/__init__.py
index ea24111..6eaee28 100644
--- a/swarms_memory/vector_dbs/__init__.py
+++ b/swarms_memory/vector_dbs/__init__.py
@@ -2,10 +2,12 @@
 from swarms_memory.vector_dbs.pinecone_wrapper import PineconeMemory
 from swarms_memory.vector_dbs.faiss_wrapper import FAISSDB
 from swarms_memory.vector_dbs.base_vectordb import BaseVectorDatabase
+from swarms_memory.vector_dbs.singlestore_wrapper import SingleStoreDB
 
 __all__ = [
     "ChromaDB",
     "PineconeMemory",
     "FAISSDB",
     "BaseVectorDatabase",
+    "SingleStoreDB",
 ]
diff --git a/swarms_memory/vector_dbs/singlestore_wrapper.py b/swarms_memory/vector_dbs/singlestore_wrapper.py
new file mode 100644
index 0000000..3b45a6b
--- /dev/null
+++ b/swarms_memory/vector_dbs/singlestore_wrapper.py
@@ -0,0 +1,271 @@
+from typing import Any, Callable, Dict, List, Optional
+import uuid
+import json
+import numpy as np
+from loguru import logger
+import singlestoredb as s2
+from sentence_transformers import SentenceTransformer
+from swarms_memory.vector_dbs.base_vectordb import BaseVectorDatabase
+
+
+class SingleStoreDB(BaseVectorDatabase):
+    """
+    A highly customizable wrapper class for a SingleStore-based Retrieval-Augmented Generation (RAG) system.
+
+    This class provides methods to add documents to SingleStore and query for similar documents
+    using vector similarity search. It supports custom embedding models, preprocessing functions,
+    and other customizations.
+    """
+
+    def __init__(
+        self,
+        host: str,
+        user: str,
+        password: str,
+        database: str,
+        table_name: str,
+        dimension: int = 768,
+        port: int = 3306,
+        ssl: bool = True,
+        ssl_verify: bool = True,
+        embedding_model: Optional[Any] = None,
+        embedding_function: Optional[Callable[[str], List[float]]] = None,
+        preprocess_function: Optional[Callable[[str], str]] = None,
+        postprocess_function: Optional[Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]] = None,
+        namespace: str = "",
+        logger_config: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Initialize the SingleStoreDB wrapper.
+
+        Args:
+            host (str): SingleStore host address
+            user (str): SingleStore username
+            password (str): SingleStore password
+            database (str): Database name
+            table_name (str): Table name for vector storage
+            dimension (int): Dimension of vectors to store. Defaults to 768
+            port (int): SingleStore port number. Defaults to 3306
+            ssl (bool): Whether to use SSL for connection. Defaults to True
+            ssl_verify (bool): Whether to verify SSL certificate. Defaults to True
+            embedding_model (Optional[Any]): Model for generating embeddings. Defaults to None
+            embedding_function (Optional[Callable]): Custom function for generating embeddings. Defaults to None
+            preprocess_function (Optional[Callable]): Custom function for preprocessing documents. Defaults to None
+            postprocess_function (Optional[Callable]): Custom function for postprocessing query results. Defaults to None
+            namespace (str): Namespace for document organization. Defaults to ""
+            logger_config (Optional[Dict]): Configuration for the logger.
Defaults to None + """ + super().__init__() + self._setup_logger(logger_config) + logger.info("Initializing SingleStoreDB") + + # Store connection parameters + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self.ssl = ssl + self.ssl_verify = ssl_verify + + self.table_name = table_name + self.dimension = dimension + self.namespace = namespace + + # Set up embedding model and functions + self.embedding_model = embedding_model or SentenceTransformer("all-MiniLM-L6-v2") + self.embedding_function = embedding_function or self._default_embedding_function + self.preprocess_function = preprocess_function or self._default_preprocess_function + self.postprocess_function = postprocess_function or self._default_postprocess_function + + # Initialize database and create table if needed + self._initialize_database() + logger.info("SingleStoreDB initialized successfully") + + def _setup_logger(self, config: Optional[Dict[str, Any]] = None): + """Set up the logger with the given configuration.""" + default_config = { + "handlers": [ + {"sink": "singlestore_wrapper.log", "rotation": "500 MB"}, + {"sink": lambda msg: print(msg, end="")}, + ], + } + logger.configure(**(config or default_config)) + + def _default_embedding_function(self, text: str) -> np.ndarray: + """Default embedding function using the SentenceTransformer model.""" + return self.embedding_model.encode(text) + + def _default_preprocess_function(self, text: str) -> str: + """Default preprocessing function.""" + return text.strip() + + def _default_postprocess_function( + self, results: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Default postprocessing function.""" + return results + + def _initialize_database(self): + """Initialize the database and create the vector table if it doesn't exist.""" + # Build connection string with SSL options + ssl_params = [] + if self.ssl: + ssl_params.append("ssl=true") + if not self.ssl_verify: + ssl_params.append("ssl_verify=false") + + ssl_string = "&".join(ssl_params) + if ssl_string: + ssl_string = "?" + ssl_string + + # Use standard connection URL format as per documentation + self.connection_string = f"{self.user}:{self.password}@{self.host}:{self.port}/{self.database}{ssl_string}" + + with s2.connect(self.connection_string) as conn: + with conn.cursor() as cursor: + # Create table with optimized settings for vector operations + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id VARCHAR(255) PRIMARY KEY, + document TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci, + embedding BLOB, + metadata JSON, + namespace VARCHAR(255), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + KEY idx_namespace (namespace), + VECTOR INDEX vec_idx (embedding) DIMENSION = {self.dimension} DISTANCE_TYPE = DOT_PRODUCT + ) ENGINE = columnstore; + """) + logger.info(f"Table {self.table_name} initialized") + + def add( + self, + document: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Add a document to the vector database. 
+
+        Args:
+            document (str): The document text to add
+            metadata (Optional[Dict[str, Any]]): Additional metadata for the document
+
+        Returns:
+            str: Document ID of the added document
+        """
+        logger.info(f"Adding document: {document[:50]}...")
+
+        # Process document and generate embedding
+        processed_doc = self.preprocess_function(document)
+        embedding = self.embedding_function(processed_doc)
+
+        # Prepare metadata
+        doc_id = str(uuid.uuid4())
+        metadata = metadata or {}
+        metadata_json = json.dumps(metadata)
+
+        # Insert into database
+        with s2.connect(self.connection_string) as conn:
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""
+                    INSERT INTO {self.table_name}
+                    (id, document, embedding, metadata, namespace)
+                    VALUES (%s, %s, %s, %s, %s)
+                    """,
+                    (doc_id, processed_doc, embedding, metadata_json, self.namespace)
+                )
+
+        logger.success(f"Document added successfully with ID: {doc_id}")
+        return doc_id
+
+    def query(
+        self,
+        query: str,
+        top_k: int = 5,
+        metadata_filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
+        """
+        Query the vector database for similar documents.
+
+        Args:
+            query (str): Query text
+            top_k (int): Number of results to return. Defaults to 5
+            metadata_filter (Optional[Dict[str, Any]]): Filter results by metadata
+
+        Returns:
+            List[Dict[str, Any]]: List of similar documents with their metadata and similarity scores
+        """
+        logger.info(f"Querying with: {query}")
+
+        # Process query and generate embedding
+        processed_query = self.preprocess_function(query)
+        query_embedding = self.embedding_function(processed_query)
+
+        # Construct metadata filter if provided. Keys and values are bound as
+        # query parameters rather than interpolated into the SQL string, which
+        # guards against SQL injection. String-valued metadata is assumed.
+        filter_clause = ""
+        filter_params: List[Any] = []
+        if metadata_filter:
+            filter_conditions = []
+            for key, value in metadata_filter.items():
+                # JSON_EXTRACT_STRING is SingleStore's string-typed JSON accessor
+                filter_conditions.append("JSON_EXTRACT_STRING(metadata, %s) = %s")
+                filter_params.extend([key, str(value)])
+            if filter_conditions:
+                filter_clause = "AND " + " AND ".join(filter_conditions)
+
+        # Query database
+        with s2.connect(self.connection_string) as conn:
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""
+                    SELECT id, document, metadata, DOT_PRODUCT(embedding, %s) as similarity
+                    FROM {self.table_name}
+                    WHERE namespace = %s {filter_clause}
+                    ORDER BY similarity DESC
+                    LIMIT %s
+                    """,
+                    (query_embedding, self.namespace, *filter_params, top_k)
+                )
+                results = cursor.fetchall()
+
+        # Format results, then apply the postprocess function to the full
+        # result list, matching its declared List[Dict] -> List[Dict] signature
+        formatted_results = []
+        for doc_id, document, metadata_json, similarity in results:
+            metadata = json.loads(metadata_json) if metadata_json else {}
+            formatted_results.append({
+                "id": doc_id,
+                "document": document,
+                "metadata": metadata,
+                "similarity": float(similarity)
+            })
+        formatted_results = self.postprocess_function(formatted_results)
+
+        logger.success(f"Query completed. Found {len(formatted_results)} results.")
+        return formatted_results
+
+    def delete(self, doc_id: str) -> bool:
+        """
+        Delete a document from the database.
+ + Args: + doc_id (str): ID of the document to delete + + Returns: + bool: True if deletion was successful + """ + logger.info(f"Deleting document with ID: {doc_id}") + + with s2.connect(self.connection_string) as conn: + with conn.cursor() as cursor: + cursor.execute( + f"DELETE FROM {self.table_name} WHERE id = %s AND namespace = %s", + (doc_id, self.namespace) + ) + deleted = cursor.rowcount > 0 + + if deleted: + logger.success(f"Document {doc_id} deleted successfully") + else: + logger.warning(f"Document {doc_id} not found") + + return deleted diff --git a/tests/vector_dbs/test_singlestore.py b/tests/vector_dbs/test_singlestore.py new file mode 100644 index 0000000..c69d858 --- /dev/null +++ b/tests/vector_dbs/test_singlestore.py @@ -0,0 +1,299 @@ +import os +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +import singlestoredb as s2 +from swarms_memory.vector_dbs.singlestore_wrapper import SingleStoreDB + + +@pytest.fixture +def mock_singlestore(): + with patch('singlestoredb.connect') as mock_connect: + # Create mock connection and cursor + mock_cursor = MagicMock() + mock_connection = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + mock_connect.return_value.__enter__.return_value = mock_connection + + # Initialize DB with test configuration + db = SingleStoreDB( + host="localhost", + port=3306, + user="test_user", + password="test_password", + database="test_db", + table_name="test_embeddings", + dimension=4, + namespace="test", + ssl=True, + ssl_verify=True + ) + yield db, mock_cursor, mock_connect + + +def test_initialization(mock_singlestore): + db, mock_cursor, mock_connect = mock_singlestore + + # Verify connection string format with SSL + expected_conn_string = "test_user:test_password@localhost:3306/test_db?ssl=true" + assert mock_connect.call_args[0][0] == expected_conn_string + + # Verify table creation with new schema + create_table_call = mock_cursor.execute.call_args_list[0] + create_table_sql = create_table_call[0][0] + + # Check for new schema elements + assert "CREATE TABLE IF NOT EXISTS" in create_table_sql + assert "test_embeddings" in create_table_sql + assert "document TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" in create_table_sql + assert "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP" in create_table_sql + assert "updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" in create_table_sql + assert "KEY idx_namespace (namespace)" in create_table_sql + assert "VECTOR INDEX vec_idx (embedding) DIMENSION = 4 DISTANCE_TYPE = DOT_PRODUCT" in create_table_sql + assert "ENGINE = columnstore" in create_table_sql + + +@patch('singlestoredb.connect') +def test_ssl_configuration(mock_connect): + # Setup mock + mock_cursor = MagicMock() + mock_connection = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + mock_connect.return_value.__enter__.return_value = mock_connection + + # Test with SSL disabled + db_no_ssl = SingleStoreDB( + host="localhost", + port=3306, + user="test_user", + password="test_password", + database="test_db", + table_name="test_embeddings", + dimension=4, + ssl=False + ) + + # Verify connection string without SSL + expected_conn_string = "test_user:test_password@localhost:3306/test_db" + assert mock_connect.call_args[0][0] == expected_conn_string + + # Reset mock + mock_connect.reset_mock() + + # Test with SSL enabled but no verification + db_ssl_no_verify = SingleStoreDB( + host="localhost", + port=3306, + 
user="test_user", + password="test_password", + database="test_db", + table_name="test_embeddings", + dimension=4, + ssl=True, + ssl_verify=False + ) + + # Verify connection string with SSL no verify + expected_conn_string = "test_user:test_password@localhost:3306/test_db?ssl=true&ssl_verify=false" + assert mock_connect.call_args[0][0] == expected_conn_string + + +def test_add_document(mock_singlestore): + db, mock_cursor, _ = mock_singlestore + + # Mock embedding function + test_embedding = np.array([0.1, 0.2, 0.3, 0.4]) + db.embedding_function = MagicMock(return_value=test_embedding) + + # Add document with metadata + metadata = {"key": "value"} + doc_id = db.add("test document", metadata=metadata) + + # Verify the insert query + insert_call = mock_cursor.execute.call_args_list[-1] + insert_sql = insert_call[0][0] + + # Verify SQL includes all columns + assert "INSERT INTO test_embeddings" in insert_sql + assert "id" in insert_sql + assert "document" in insert_sql + assert "embedding" in insert_sql + assert "metadata" in insert_sql + assert "namespace" in insert_sql + + # Verify parameters + params = insert_call[0][1] + assert len(params) == 5 # id, document, embedding, metadata, namespace + assert params[1] == "test document" # document + np.testing.assert_array_equal(params[2], test_embedding) # embedding + assert '"key": "value"' in params[3] # metadata + assert params[4] == "test" # namespace + + +def test_query(mock_singlestore): + db, mock_cursor, _ = mock_singlestore + + # Mock embedding function and query results + test_embedding = np.array([0.1, 0.2, 0.3, 0.4]) + db.embedding_function = MagicMock(return_value=test_embedding) + + mock_cursor.fetchall.return_value = [ + ("doc1", "content1", '{"key": "value1"}', 0.9), + ("doc2", "content2", '{"key": "value2"}', 0.8) + ] + + # Query with metadata filter + results = db.query( + "test query", + top_k=2, + metadata_filter={"category": "test"} + ) + + # Verify query execution + query_call = mock_cursor.execute.call_args_list[-1] + assert "SELECT" in query_call[0][0] + assert "DOT_PRODUCT" in query_call[0][0] + assert "JSON_EXTRACT" in query_call[0][0] + + # Verify results + assert len(results) == 2 + assert results[0]["id"] == "doc1" + assert results[0]["similarity"] == 0.9 + assert results[0]["metadata"]["key"] == "value1" + + +def test_delete(mock_singlestore): + db, mock_cursor, _ = mock_singlestore + + # Mock successful deletion + mock_cursor.rowcount = 1 + + # Delete document + success = db.delete("test_id") + + # Verify deletion query + delete_call = mock_cursor.execute.call_args_list[-1] + assert "DELETE FROM test_embeddings" in delete_call[0][0] + assert delete_call[0][1][0] == "test_id" + assert success is True + + # Test unsuccessful deletion + mock_cursor.rowcount = 0 + success = db.delete("nonexistent_id") + assert success is False + + +def test_preprocessing(mock_singlestore): + db, mock_cursor, _ = mock_singlestore + + # Define custom preprocessing function + def custom_preprocess(text: str) -> str: + return text.strip().lower() + + db.preprocess_function = custom_preprocess + + # Mock embedding function + test_embedding = np.array([0.1, 0.2, 0.3, 0.4]) + db.embedding_function = MagicMock(return_value=test_embedding) + + # Add document with preprocessing + db.add(" TEST DOCUMENT ") + + # Verify preprocessed document + insert_call = mock_cursor.execute.call_args_list[-1] + assert insert_call[0][1][1] == "test document" + + +def test_postprocessing(mock_singlestore): + db, mock_cursor, _ = mock_singlestore + + # Define 
custom list-level postprocessing function; the wrapper
+    # applies it to the full results list (List[Dict] -> List[Dict])
+    def custom_postprocess(results):
+        for result in results:
+            result["document"] = result["document"].upper()
+        return results
+
+    db.postprocess_function = custom_postprocess
+
+    # Mock embedding function and query results
+    test_embedding = np.array([0.1, 0.2, 0.3, 0.4])
+    db.embedding_function = MagicMock(return_value=test_embedding)
+
+    mock_cursor.fetchall.return_value = [
+        ("doc1", "test document", "{}", 0.9)
+    ]
+
+    # Query with postprocessing
+    results = db.query("test")
+
+    # Verify postprocessed results
+    assert results[0]["document"] == "TEST DOCUMENT"
+
+
+def test_logger_setup(mock_singlestore):
+    db, _, _ = mock_singlestore
+
+    # Test with custom logger config
+    custom_config = {
+        "handlers": [
+            {"sink": "custom.log", "rotation": "1 MB"},
+        ],
+    }
+
+    with patch('loguru.logger.configure') as mock_configure:
+        db._setup_logger(custom_config)
+        mock_configure.assert_called_once_with(**custom_config)
+
+    # Test with default config
+    with patch('loguru.logger.configure') as mock_configure:
+        db._setup_logger(None)
+        called_config = mock_configure.call_args[1]
+        assert "handlers" in called_config
+        assert len(called_config["handlers"]) == 2
+        assert called_config["handlers"][0]["sink"] == "singlestore_wrapper.log"
+
+
+def test_default_embedding_function(mock_singlestore):
+    db, _, _ = mock_singlestore
+
+    # Mock the embedding model
+    test_embedding = np.array([0.1, 0.2, 0.3, 0.4])
+    db.embedding_model = MagicMock()
+    db.embedding_model.encode.return_value = test_embedding
+
+    # Test the default embedding function
+    result = db._default_embedding_function("test text")
+
+    # Verify the embedding model was called correctly
+    db.embedding_model.encode.assert_called_once_with("test text")
+    np.testing.assert_array_equal(result, test_embedding)
+    assert isinstance(result, np.ndarray)
+
+
+def test_default_preprocess_function(mock_singlestore):
+    db, _, _ = mock_singlestore
+
+    # Test with spaces to trim
+    result = db._default_preprocess_function("  test text  ")
+    assert result == "test text"
+
+    # Test with no spaces to trim
+    result = db._default_preprocess_function("test text")
+    assert result == "test text"
+
+    # Test with empty string
+    result = db._default_preprocess_function("")
+    assert result == ""
+
+
+def test_default_postprocess_function(mock_singlestore):
+    db, _, _ = mock_singlestore
+
+    # Test with empty list
+    test_results = []
+    result = db._default_postprocess_function(test_results)
+    assert result == []
+
+    # Test with list of results
+    test_results = [{"id": "1", "text": "test"}]
+    result = db._default_postprocess_function(test_results)
+    assert result == test_results
+    assert result[0]["id"] == "1"
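+
+
+# A minimal sketch exercising the parameterized metadata filter in query():
+# it checks that filter values travel as bound parameters rather than being
+# interpolated into the SQL string (guarding against SQL injection).
+def test_metadata_filter_uses_bound_parameters(mock_singlestore):
+    db, mock_cursor, _ = mock_singlestore
+
+    db.embedding_function = MagicMock(return_value=np.array([0.1, 0.2, 0.3, 0.4]))
+    mock_cursor.fetchall.return_value = []
+
+    db.query("test", metadata_filter={"source": "docs"})
+
+    sql, params = mock_cursor.execute.call_args_list[-1][0]
+    assert "= %s" in sql           # value appears as a placeholder...
+    assert "'docs'" not in sql     # ...never as an inlined literal
+    assert "docs" in [str(p) for p in params]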