Skip to content

Commit

Permalink
element fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Nov 15, 2024
1 parent 6b83467 commit 1c1ecd8
Show file tree
Hide file tree
Showing 17 changed files with 130 additions and 105 deletions.
2 changes: 1 addition & 1 deletion examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def _handle_message(
if not self._documents_ingested:
yield self.NO_DOCUMENTS_INGESTED_MESSAGE
results = await self.document_search.search(message[-1])
prompt = RAGPrompt(QueryWithContext(query=message, context=[i.get_text_representation() for i in results]))
prompt = RAGPrompt(QueryWithContext(query=message, context=[i.text_representation for i in results]))
response = await self._llm.generate(prompt)
yield response.answer

Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_text_representation() for element in results])
print([element.text_representation for element in results])


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/chroma_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_text_representation() for element in results])
print([element.text_representation for element in results])


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def main() -> None:
print("Results for 'Fluffy teady bear toy':")
for result in results:
document = await result.document_meta.fetch()
print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.get_text_representation()}")
print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.text_representation}")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def main() -> None:

print()
print(f"Documents similar to: {query}")
print([element.get_text_representation() for element in results])
print([element.text_representation for element in results])


if __name__ == "__main__":
Expand Down
33 changes: 13 additions & 20 deletions packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,27 @@
import sys

from ..metadata_stores import get_metadata_store
from ..utils.config_handling import get_cls_from_config
from .base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from .in_memory import InMemoryVectorStore
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "WhereQuery"]
__all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "VectorStoreOptions", "WhereQuery"]

module = sys.modules[__name__]


def get_vector_store(vector_store_config: dict) -> VectorStore:
def get_vector_store(config: dict) -> VectorStore:
"""
Initializes and returns a VectorStore object based on the provided configuration.
Args:
vector_store_config: A dictionary containing configuration details for the VectorStore.
config: A dictionary containing configuration details for the VectorStore.
Returns:
An instance of the specified VectorStore class, initialized with the provided config
(if any) or default arguments.
"""
vector_store_cls = get_cls_from_config(vector_store_config["type"], module)
config = vector_store_config.get("config", {})
if vector_store_config["type"].endswith(("ChromaVectorStore", "QdrantVectorStore")):
return vector_store_cls.from_config(config)

metadata_store_config = vector_store_config.get("metadata_store_config")
return vector_store_cls(
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(metadata_store_config),
)
Raises:
KeyError: If the provided configuration does not contain a valid "type" key.
InvalidConfigurationError: If the provided configuration is invalid.
NotImplementedError: If the specified VectorStore class cannot be created from the provided configuration.
"""
vector_store_cls = get_cls_from_config(config["type"], sys.modules[__name__])
return vector_store_cls.from_config(config.get("config", {}))
18 changes: 17 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class VectorStoreEntry(BaseModel):
"""

id: str
key: str
vector: list[float]
content: str
metadata: dict


Expand Down Expand Up @@ -48,6 +48,22 @@ def __init__(
self._default_options = default_options or VectorStoreOptions()
self._metadata_store = metadata_store

@classmethod
def from_config(cls, config: dict) -> "VectorStore":
"""
Creates and returns an instance of the Reranker class from the given configuration.
Args:
config: A dictionary containing the configuration for initializing the Reranker instance.
Returns:
An initialized instance of the Reranker class.
Raises:
NotImplementedError: If the class cannot be created from the provided configuration.
"""
raise NotImplementedError(f"Cannot create class {cls.__name__} from config.")

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
"""
Expand Down
29 changes: 11 additions & 18 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Literal

import chromadb
from chromadb import Collection
from chromadb.api import ClientAPI

from ragbits.core.audit import traceable
Expand Down Expand Up @@ -39,16 +38,7 @@ def __init__(
self._client = client
self._index_name = index_name
self._distance_method = distance_method
self._collection = self._get_chroma_collection()

def _get_chroma_collection(self) -> Collection:
"""
Gets or creates a collection with the given name and metadata.
Returns:
The collection.
"""
return self._client.get_or_create_collection(
self._collection = self._client.get_or_create_collection(
name=self._index_name,
metadata={"hnsw:space": self._distance_method},
)
Expand All @@ -68,7 +58,7 @@ def from_config(cls, config: dict) -> "ChromaVectorStore":
return cls(
client=client_cls(**config["client"].get("config", {})),
index_name=config["index_name"],
distance_method=config.get("distance_method", "l2"),
distance_method=config.get("distance_method", "cosine"),
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(config.get("metadata_store")),
)
Expand All @@ -81,8 +71,11 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Args:
entries: The entries to store.
"""
if not entries:
return

ids = [entry.id for entry in entries]
documents = [entry.content for entry in entries]
documents = [entry.key for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [entry.metadata for entry in entries]

Expand Down Expand Up @@ -130,12 +123,12 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
return [
VectorStoreEntry(
id=id,
content=document,
key=document,
vector=list(embeddings),
metadata=metadata, # type: ignore
)
for batch in zip(ids, metadatas, embeddings, distances, documents, strict=False)
for id, metadata, embeddings, distance, document in zip(*batch, strict=False)
for batch in zip(ids, metadatas, embeddings, distances, documents, strict=True)
for id, metadata, embeddings, distance, document in zip(*batch, strict=True)
if options.max_distance is None or distance <= options.max_distance
]

Expand Down Expand Up @@ -180,9 +173,9 @@ async def list(
return [
VectorStoreEntry(
id=id,
content=document,
key=document,
vector=list(embedding),
metadata=metadata, # type: ignore
)
for id, metadata, embedding, document in zip(ids, metadatas, embeddings, documents, strict=False)
for id, metadata, embedding, document in zip(ids, metadatas, embeddings, documents, strict=True)
]
17 changes: 17 additions & 0 deletions packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from ragbits.core.audit import traceable
from ragbits.core.metadata_stores import get_metadata_store
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

Expand All @@ -27,6 +28,22 @@ def __init__(
super().__init__(default_options=default_options, metadata_store=metadata_store)
self._storage: dict[str, VectorStoreEntry] = {}

@classmethod
def from_config(cls, config: dict) -> "InMemoryVectorStore":
"""
Creates and returns an instance of the InMemoryVectorStore class from the given configuration.
Args:
config: A dictionary containing the configuration for initializing the InMemoryVectorStore instance.
Returns:
An initialized instance of the InMemoryVectorStore class.
"""
return cls(
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(config.get("metadata_store")),
)

@traceable
async def store(self, entries: list[VectorStoreEntry]) -> None:
"""
Expand Down
21 changes: 12 additions & 9 deletions packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Raises:
QdrantException: If upload to collection fails.
"""
if not entries:
return

if not await self._client.collection_exists(self._index_name):
await self._client.create_collection(
collection_name=self._index_name,
Expand All @@ -78,16 +81,16 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:

ids = [entry.id for entry in entries]
embeddings = [entry.vector for entry in entries]
payloads = [{"__document": entry.content} for entry in entries]
payloads = [{"document": entry.key} for entry in entries]
metadatas = [entry.metadata for entry in entries]

metadatas = (
[{"__metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
[{"metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
if self._metadata_store is None
else await self._metadata_store.store(ids, metadatas) # type: ignore
)
if metadatas is not None:
payloads = [{**payload, **metadata} for (payload, metadata) in zip(payloads, metadatas, strict=False)]
payloads = [{**payload, **metadata} for (payload, metadata) in zip(payloads, metadatas, strict=True)]

self._client.upload_collection(
collection_name=self._index_name,
Expand Down Expand Up @@ -126,17 +129,17 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

ids = [point.id for point in results.points]
vectors = [point.vector for point in results.points]
documents = [point.payload["__document"] for point in results.points] # type: ignore
documents = [point.payload["document"] for point in results.points] # type: ignore
metadatas = (
[json.loads(point.payload["__metadata"]) for point in results.points] # type: ignore
[json.loads(point.payload["metadata"]) for point in results.points] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(ids) # type: ignore
)

return [
VectorStoreEntry(
id=str(id),
content=document,
key=document,
vector=vector, # type: ignore
metadata=metadata,
)
Expand Down Expand Up @@ -176,17 +179,17 @@ async def list( # type: ignore

ids = [point.id for point in results.points]
vectors = [point.vector for point in results.points]
documents = [point.payload["__document"] for point in results.points] # type: ignore
documents = [point.payload["document"] for point in results.points] # type: ignore
metadatas = (
[json.loads(point.payload["__metadata"]) for point in results.points] # type: ignore
[json.loads(point.payload["metadata"]) for point in results.points] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(ids) # type: ignore
)

return [
VectorStoreEntry(
id=str(id),
content=document,
key=document,
vector=vector, # type: ignore
metadata=metadata,
)
Expand Down
19 changes: 6 additions & 13 deletions packages/ragbits-core/tests/unit/vector_stores/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@ def mock_chromadb_store() -> ChromaVectorStore:
)


async def test_get_chroma_collection(mock_chromadb_store: ChromaVectorStore) -> None:
_ = mock_chromadb_store._get_chroma_collection()
assert mock_chromadb_store._client.get_or_create_collection.call_count == 2 # type: ignore


async def test_store(mock_chromadb_store: ChromaVectorStore) -> None:
data = [
VectorStoreEntry(
id="test_key",
content="test content",
key="test content",
vector=[0.1, 0.2, 0.3],
metadata={
"content": "test content",
Expand Down Expand Up @@ -70,8 +65,7 @@ async def test_retrieve(
mock_chromadb_store: ChromaVectorStore, max_distance: float | None, results: list[dict]
) -> None:
vector = [0.1, 0.2, 0.3]
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.query.return_value = { # type: ignore
mock_chromadb_store._collection.query.return_value = { # type: ignore
"metadatas": [
[
{
Expand All @@ -98,12 +92,11 @@ async def test_retrieve(
assert entry.metadata["document"]["title"] == result["title"]
assert entry.vector == result["vector"]
assert entry.id == f"test_id_{results.index(result) + 1}"
assert entry.content == result["content"]
assert entry.key == result["content"]


async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.get.return_value = { # type: ignore
mock_chromadb_store._collection.get.return_value = { # type: ignore
"metadatas": [
{
"__metadata": '{"content": "test content", "document": {"title": "test title", "source":'
Expand All @@ -125,10 +118,10 @@ async def test_list(mock_chromadb_store: ChromaVectorStore) -> None:
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].vector == [0.12, 0.25, 0.29]
assert entries[0].content == "test content 1"
assert entries[0].key == "test content 1"
assert entries[0].id == "test_id_1"
assert entries[1].metadata["content"] == "test content 2"
assert entries[1].metadata["document"]["title"] == "test title 2"
assert entries[1].vector == [0.13, 0.26, 0.30]
assert entries[1].content == "test content2"
assert entries[1].key == "test content2"
assert entries[1].id == "test_id_2"
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import pytest
from pydantic import computed_field

from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
Expand All @@ -20,12 +21,14 @@ class AnimalElement(Element):
type: str
age: int

def get_text_representation(self) -> str:
@computed_field # type: ignore[prop-decorator]
@property
def text_representation(self) -> str:
"""
Get the text representation of the element.
Returns:
The key.
The text representation.
"""
return self.name

Expand Down
Loading

0 comments on commit 1c1ecd8

Please sign in to comment.