Skip to content

Commit

Permalink
unified interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Nov 14, 2024
1 parent 87879c6 commit 6b83467
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/document-search/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_key() for element in results])
print([element.get_text_representation() for element in results])


if __name__ == "__main__":
Expand Down
22 changes: 10 additions & 12 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import json
from typing import Literal

Expand Down Expand Up @@ -56,7 +54,7 @@ def _get_chroma_collection(self) -> Collection:
)

@classmethod
def from_config(cls, config: dict) -> ChromaVectorStore:
def from_config(cls, config: dict) -> "ChromaVectorStore":
"""
Creates and returns an instance of the ChromaVectorStore class from the given configuration.
Expand Down Expand Up @@ -87,11 +85,13 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
documents = [entry.content for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [entry.metadata for entry in entries]

metadatas = (
[{"__metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
if self._metadata_store is None
else await self._metadata_store.store(ids, metadatas) # type: ignore
)

self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) # type: ignore

@traceable
Expand All @@ -116,14 +116,13 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
n_results=options.k,
include=["metadatas", "embeddings", "distances", "documents"],
)

ids = results.get("ids") or []
metadatas = results.get("metadatas") or []
embeddings = results.get("embeddings") or []
distances = results.get("distances") or []
documents = results.get("documents") or []

metadatas = [
[json.loads(metadata["__metadata"]) for batch in metadatas for metadata in batch] # type: ignore
[json.loads(metadata["__metadata"]) for batch in results.get("metadatas", []) for metadata in batch] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(*ids)
]
Expand Down Expand Up @@ -162,19 +161,18 @@ async def list(
# Cast `where` to chromadb's Where type
where_chroma: chromadb.Where | None = dict(where) if where else None

get_results = self._collection.get(
results = self._collection.get(
where=where_chroma,
limit=limit,
offset=offset,
include=["metadatas", "embeddings", "documents"],
)
ids = get_results.get("ids") or []
metadatas = get_results.get("metadatas") or []
embeddings = get_results.get("embeddings") or []
documents = get_results.get("documents") or []

ids = results.get("ids") or []
embeddings = results.get("embeddings") or []
documents = results.get("documents") or []
metadatas = (
[json.loads(metadata["__metadata"]) for metadata in metadatas] # type: ignore
[json.loads(metadata["__metadata"]) for metadata in results.get("metadatas", [])] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(ids)
)
Expand Down
21 changes: 10 additions & 11 deletions packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from __future__ import annotations

import json
import uuid

import qdrant_client
from qdrant_client import AsyncQdrantClient
Expand Down Expand Up @@ -43,7 +40,7 @@ def __init__(
self._distance_method = distance_method

@classmethod
def from_config(cls, config: dict) -> QdrantVectorStore:
def from_config(cls, config: dict) -> "QdrantVectorStore":
"""
Creates and returns an instance of the QdrantVectorStore class from the given configuration.
Expand Down Expand Up @@ -79,11 +76,11 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
vectors_config=VectorParams(size=len(entries[0].vector), distance=self._distance_method),
)

ids = [str(uuid.uuid5(uuid.NAMESPACE_DNS, str(entry))) for entry in entries]
ids = [entry.id for entry in entries]
embeddings = [entry.vector for entry in entries]
payloads = [{"__document": entry.key} for entry in entries]

payloads = [{"__document": entry.content} for entry in entries]
metadatas = [entry.metadata for entry in entries]

metadatas = (
[{"__metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
if self._metadata_store is None
Expand Down Expand Up @@ -138,11 +135,12 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

return [
VectorStoreEntry(
key=document,
id=str(id),
content=document,
vector=vector, # type: ignore
metadata=metadata,
)
for document, vector, metadata in zip(documents, vectors, metadatas, strict=True)
for id, document, vector, metadata in zip(ids, documents, vectors, metadatas, strict=True)
]

@traceable
Expand Down Expand Up @@ -187,9 +185,10 @@ async def list( # type: ignore

return [
VectorStoreEntry(
key=document,
id=str(id),
content=document,
vector=vector, # type: ignore
metadata=metadata,
)
for document, vector, metadata in zip(documents, vectors, metadatas, strict=True)
for id, document, vector, metadata in zip(ids, documents, vectors, metadatas, strict=True)
]
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def mock_qdrant_store() -> QdrantVectorStore:
async def test_store(mock_qdrant_store: QdrantVectorStore) -> None:
data = [
VectorStoreEntry(
key="test_key",
id="1c7d6b27-4ef1-537c-ad7c-676edb8bc8a8",
content="test_key",
vector=[0.1, 0.2, 0.3],
metadata={
"content": "test content",
Expand Down

0 comments on commit 6b83467

Please sign in to comment.