Skip to content

Commit

Permalink
Add Swarmauri Persistent Qdrant Vector Store community package with t…
Browse files Browse the repository at this point in the history
…ests
  • Loading branch information
MichaelDecent committed Jan 13, 2025
1 parent 1edd4b1 commit 324e615
Show file tree
Hide file tree
Showing 8 changed files with 593 additions and 69 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "swarmauri_vectorstore_communitycloudqdrant"
name = "swarmauri_vectorstore_communityqdrant"
version = "0.6.0.dev1"
description = "example community package"
description = "Swarmauri Persistent Qdrant Vector Store"
authors = ["Jacob Stewart <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
Expand All @@ -22,7 +22,7 @@ swarmauri_base = { path = "../../base" }
swarmauri_vectorstore_doc2vec = { path = "../../standards" }

# Dependencies
cloudqdrant = "^1.12.0"
qdrant-client = "^1.12.0"


[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -55,4 +55,5 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S"
asyncio_default_fixture_loop_scope = "function"

[tool.poetry.plugins."swarmauri.vector_stores"]
CloudQdrantVectorStore = "swarmauri_vectorstore_communitycloudqdrant.CloudQdrantVectorStore:CloudQdrantVectorStore"
CloudQdrantVectorStore = "swarmauri_vectorstore_communityqdrant.CloudQdrantVectorStore:CloudQdrantVectorStore"
PersistentQdrantVectorStore = "swarmauri_vectorstore_communityqdrant.PersistentQdrantVectorStore:PersistentQdrantVectorStore"
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
from typing import List, Union, Literal
from pydantic import Field, PrivateAttr, ConfigDict

from qdrant_client import QdrantClient
from qdrant_client.models import (
PointStruct,
VectorParams,
Distance,
)

from swarmauri_standard.documents.Document import Document
from swarmauri_vectorstore_doc2vec.Doc2VecEmbedding import Doc2VecEmbedding
from swarmauri_standard.distances.CosineDistance import CosineDistance

from swarmauri_base.vector_stores.VectorStoreBase import VectorStoreBase
from swarmauri_base.vector_stores.VectorStoreRetrieveMixin import (
VectorStoreRetrieveMixin,
)
from swarmauri_base.vector_stores.VectorStoreSaveLoadMixin import (
VectorStoreSaveLoadMixin,
)
from swarmauri_base.vector_stores.VectorStorePersistentMixin import (
VectorStorePersistentMixin,
)


class PersistentQdrantVectorStore(
VectorStoreSaveLoadMixin,
VectorStoreRetrieveMixin,
VectorStorePersistentMixin,
VectorStoreBase,
):
"""
PersistentQdrantVectorStore is a concrete implementation that integrates functionality
for saving, loading, storing, and retrieving vector documents, leveraging a locally
hosted Qdrant instance as the backend.
"""

type: Literal["PersistentQdrantVectorStore"] = "PersistentQdrantVectorStore"

# allow arbitary types in the model config
model_config = ConfigDict(arbitrary_types_allowed=True)

# Use PrivateAttr to make _embedder and _distance private
_embedder: Doc2VecEmbedding = PrivateAttr()
_distance: CosineDistance = PrivateAttr()
client: Union[QdrantClient, None] = Field(default=None, init=False)

def __init__(self, **kwargs):
super().__init__(**kwargs)

self._embedder = Doc2VecEmbedding(vector_size=self.vector_size)
self._distance = CosineDistance()

def connect(self) -> None:
"""
Connects to the Qdrant vector store using the provided URL.
"""
if self.client is None:
self.client = QdrantClient(path=self.path)

# Check if the collection exists
existing_collections = self.client.get_collections().collections
collection_names = [collection.name for collection in existing_collections]

if self.collection_name not in collection_names:
# Ensure the collection exists with the desired configuration
self.client.recreate_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=self.vector_size, distance=Distance.COSINE
),
)

def disconnect(self) -> None:
"""
Disconnects from the Qdrant vector store.
"""
if self.client is not None:
self.client = None

def add_document(self, document: Document) -> None:
"""
Add a single document to the document store.
Parameters:
document (Document): The document to be added to the store.
"""
embedding = None
if not document.embedding:
self._embedder.fit([document.content]) # Fit only once
embedding = (
self._embedder.transform([document.content])[0].to_numpy().tolist()
)
else:
embedding = document.embedding

payload = {
"content": document.content,
"metadata": document.metadata,
}

doc = PointStruct(id=document.id, vector=embedding, payload=payload)

self.client.upsert(
collection_name=self.collection_name,
points=[doc],
)

def add_documents(self, documents: List[Document]) -> None:
"""
Add multiple documents to the document store in a batch operation.
Parameters:
documents (List[Document]): A list of documents to be added to the store.
"""
points = [
PointStruct(
id=doc.id,
vector=doc.embedding
or self._embedder.fit_transform([doc.content])[0].to_numpy().tolist(),
payload={"content": doc.content, "metadata": doc.metadata},
)
for doc in documents
]
self.client.upsert(self.collection_name, points=points)

def get_document(self, id: str) -> Union[Document, None]:
"""
Retrieve a single document by its identifier.
Parameters:
id (str): The unique identifier of the document to retrieve.
Returns:
Union[Document, None]: The requested document if found; otherwise, None.
"""
response = self.client.retrieve(
collection_name=self.collection_name,
ids=[id],
)
if response:
payload = response[0].payload
return Document(
id=id, content=payload["content"], metadata=payload["metadata"]
)
return None

def get_all_documents(self) -> List[Document]:
"""
Retrieve all documents stored in the document store.
Returns:
List[Document]: A list of all documents in the store.
"""
response = self.client.scroll(
collection_name=self.collection_name,
)

return [
Document(
id=doc.id,
content=doc.payload["content"],
metadata=doc.payload["metadata"],
)
for doc in response[0]
]

def delete_document(self, id: str) -> None:
"""
Delete a document from the document store by its identifier.
Parameters:
id (str): The unique identifier of the document to delete.
"""
self.client.delete(self.collection_name, points_selector=[id])

def update_document(self, id: str, updated_document: Document) -> None:
"""
Update a document in the document store.
Parameters:
id (str): The unique identifier of the document to update.
updated_document (Document): The updated document instance.
"""
# Precompute the embedding outside the update process
if not updated_document.embedding:
# Transform without refitting to avoid vocabulary issues
document_vector = self._embedder.transform([updated_document.content])[0]
else:
document_vector = updated_document.embedding

document_vector = document_vector.to_numpy().tolist()

self.client.upsert(
self.collection_name,
points=[
PointStruct(
id=id,
vector=document_vector,
payload={
"content": updated_document.content,
"metadata": updated_document.metadata,
},
)
],
)

def clear_documents(self) -> None:
"""
Deletes all documents from the vector store.
"""
self.client.delete_collection(self.collection_name)

def document_count(self) -> int:
"""
Returns the number of documents in the store.
"""
response = self.client.scroll(
collection_name=self.collection_name,
)
return len(response)

def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
"""
Retrieve the top_k most relevant documents based on the given query.
For the purpose of this example, this method performs a basic search.
Args:
query (str): The query string used for document retrieval.
top_k (int): The number of top relevant documents to retrieve.
Returns:
List[Document]: A list of the top_k most relevant documents.
"""
query_vector = self._embedder.infer_vector(query).value
results = self.client.search(
collection_name=self.collection_name, query_vector=query_vector, limit=top_k
)

return [
Document(
id=res.id,
content=res.payload["content"],
metadata=res.payload["metadata"],
)
for res in results
]

# Override the model_dump_json method
def model_dump_json(self, *args, **kwargs) -> str:
# Call the disconnect method before serialization
self.disconnect()

# Now proceed with the usual JSON serialization
return super().model_dump_json(*args, **kwargs)
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from .PersistentQdrantVectorStore import PersistentQdrantVectorStore
from .CloudQdrantVectorStore import CloudQdrantVectorStore

__version__ = "0.6.0.dev26"
__long_desc__ = """
# Swarmauri CloudQdrant VectorStore Plugin
# Swarmauri Qdrant Based Components
Components Included:
- PersistentQdrantVectorStore
- CloudQdrantVectorStore
Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Expand Down
Loading

0 comments on commit 324e615

Please sign in to comment.