Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Oct 28, 2024
1 parent 8487460 commit fc6d636
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 58 deletions.
11 changes: 6 additions & 5 deletions examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,15 @@ def __init__(

def _prepare_document_search(self, database_path: str, index_name: str) -> None:
chroma_client = chromadb.PersistentClient(path=database_path)
embedding_client = LiteLLMEmbeddings()

vector_store = ChromaDBStore(
client=chroma_client,
index_name=index_name,
chroma_client=chroma_client,
embedding_function=embedding_client,
)
self.document_search = DocumentSearch(embedder=embedding_client, vector_store=vector_store)
embedder = LiteLLMEmbeddings()
self.document_search = DocumentSearch(
embedder=embedder,
vector_store=vector_store,
)

async def _create_database(self, document_paths: list[str]) -> str:
for path in document_paths:
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"distance_method": "l2",
"default_options": {
"k": 3,
"max_distance": 1.15,
"max_distance": 1.2,
},
},
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys

from ..utils.config_handling import get_cls_from_config
from .base import VectorDBEntry, VectorStore, WhereQuery
from .base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

Expand All @@ -27,4 +27,4 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
if vector_store_config["type"] == "ChromaDBStore":
return vector_store_cls.from_config(config)

return vector_store_cls(**config)
return vector_store_cls(default_options=VectorStoreOptions(**config.get("default_options", {})))
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def from_config(cls, config: dict) -> ChromaDBStore:
client=client(**config["client"].get("config", {})),
index_name=config["index_name"],
distance_method=config.get("distance_method", "l2"),
default_options=VectorStoreOptions(**config.get("options", {})),
default_options=VectorStoreOptions(**config.get("default_options", {})),
)

async def store(self, entries: list[VectorDBEntry]) -> None:
Expand Down
35 changes: 18 additions & 17 deletions packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

import numpy as np

from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, WhereQuery
from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery


class InMemoryVectorStore(VectorStore):
"""
A simple in-memory implementation of Vector Store, storing vectors in memory.
"""

def __init__(self) -> None:
def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
super().__init__(default_options)
self._storage: dict[str, VectorDBEntry] = {}

async def store(self, entries: list[VectorDBEntry]) -> None:
Expand All @@ -23,30 +24,30 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
for entry in entries:
self._storage[entry.key] = entry

async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]:
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]:
"""
Retrieve entries from the vector store.
Args:
vector: The vector to search for.
k: The number of entries to retrieve.
options: The options for querying the vector store.
Returns:
The entries.
"""
knn = []

for entry in self._storage.values():
entry_distance = self._calculate_squared_euclidean(entry.vector, vector)
knn.append((entry, entry_distance))

knn.sort(key=lambda x: x[1])

return [entry for entry, _ in knn[:k]]

@staticmethod
def _calculate_squared_euclidean(vector_x: list[float], vector_b: list[float]) -> float:
return float(np.linalg.norm(np.array(vector_x) - np.array(vector_b)))
options = self._default_options if options is None else options
entries = sorted(
(
(entry, float(np.linalg.norm(np.array(entry.vector) - np.array(vector))))
for entry in self._storage.values()
),
key=lambda x: x[1],
)
return [
entry
for entry, distance in entries[: options.k]
if options.max_distance is None or distance <= options.max_distance
]

async def list(
self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@

import pytest

from ragbits.core.vector_store.base import VectorDBEntry
from ragbits.core.vector_store.base import VectorDBEntry, VectorStoreOptions
from ragbits.core.vector_store.chromadb_store import ChromaDBStore


@pytest.fixture
def mock_chromadb_store():
def mock_chromadb_store() -> ChromaDBStore:
return ChromaDBStore(
client=MagicMock(),
index_name="test_index",
)


def test_chromadbstore_init_import_error():
def test_init_import_error() -> None:
with patch("ragbits.core.vector_store.chromadb_store.HAS_CHROMADB", False), pytest.raises(ImportError):
ChromaDBStore(
client=MagicMock(),
index_name="test_index",
)


def test_get_chroma_collection(mock_chromadb_store: ChromaDBStore):
def test_get_chroma_collection(mock_chromadb_store: ChromaDBStore) -> None:
_ = mock_chromadb_store._get_chroma_collection()
assert mock_chromadb_store._client.get_or_create_collection.call_count == 2 # type: ignore


async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore):
async def test_store(mock_chromadb_store: ChromaDBStore) -> None:
data = [
VectorDBEntry(
key="test_key",
Expand Down Expand Up @@ -59,32 +59,52 @@ async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore):
)


async def test_retrieves_entries_correctly(mock_chromadb_store: ChromaDBStore):
@pytest.mark.parametrize(
("max_distance", "results"),
[
(
None,
[
{"content": "test content 1", "title": "test title 1", "vector": [0.12, 0.25, 0.29]},
{"content": "test content 2", "title": "test title 2", "vector": [0.13, 0.26, 0.30]},
],
),
(0.1, [{"content": "test content 1", "title": "test title 1", "vector": [0.12, 0.25, 0.29]}]),
(0.09, []),
],
)
async def test_retrieve(mock_chromadb_store: ChromaDBStore, max_distance: float | None, results: list[dict]) -> None:
vector = [0.1, 0.2, 0.3]
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.query.return_value = { # type: ignore
"metadatas": [
[
{
"__key": "test_key",
"__metadata": '{"content": "test content", "document": {"title": "test title", "source":'
"__key": "test_key_1",
"__metadata": '{"content": "test content 1", "document": {"title": "test title 1", "source":'
' {"path": "/test/path-1"}, "document_type": "txt"}}',
},
{
"__key": "test_key_2",
"__metadata": '{"content": "test content 2", "document": {"title": "test title 2", "source":'
' {"path": "/test/path-2"}, "document_type": "txt"}}',
},
]
],
"embeddings": [[[0.12, 0.25, 0.29]]],
"distances": [[0.1]],
"embeddings": [[[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]]],
"distances": [[0.1, 0.2]],
}

entries = await mock_chromadb_store.retrieve(vector)
entries = await mock_chromadb_store.retrieve(vector, options=VectorStoreOptions(max_distance=max_distance))

assert len(entries) == 1
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].vector == [0.12, 0.25, 0.29]
assert len(entries) == len(results)
for entry, result in zip(entries, results, strict=False):
assert entry.metadata["content"] == result["content"]
assert entry.metadata["document"]["title"] == result["title"]
assert entry.vector == result["vector"]


async def test_lists_entries_correctly(mock_chromadb_store: ChromaDBStore):
async def test_list(mock_chromadb_store: ChromaDBStore) -> None:
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.get.return_value = { # type: ignore
"metadatas": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from ragbits.core.vector_store.base import VectorStoreOptions
from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
Expand Down Expand Up @@ -30,7 +31,7 @@ def get_key(self) -> str:


@pytest.fixture(name="store")
async def store_fixture():
async def store_fixture() -> InMemoryVectorStore:
document_meta = DocumentMeta(document_type=DocumentType.TXT, source=LocalFileSource(path=Path("test.txt")))
elements = [
(AnimalElement(name="spikey", species="dog", type="mammal", age=5, document_meta=document_meta), [0.5, 0.5]),
Expand All @@ -51,84 +52,92 @@ async def store_fixture():
return store


async def test_simple_vector_store(store: InMemoryVectorStore):
@pytest.mark.parametrize(
("k", "max_distance", "results"),
[
(5, None, ["spikey", "fluffy", "slimy", "spotty", "scaly"]),
(2, None, ["spikey", "fluffy"]),
(5, 0.3, ["spikey", "fluffy"]),
],
)
async def test_retrieve(store: InMemoryVectorStore, k: int, max_distance: float | None, results: list[str]) -> None:
search_vector = [0.4, 0.4]

results = await store.retrieve(search_vector, 2)
entries = await store.retrieve(search_vector, options=VectorStoreOptions(k=k, max_distance=max_distance))

assert len(results) == 2
assert results[0].metadata["name"] == "spikey"
assert results[1].metadata["name"] == "fluffy"
assert len(entries) == len(results)
for entry, result in zip(entries, results, strict=False):
assert entry.metadata["name"] == result


async def test_list_all(store: InMemoryVectorStore):
async def test_list_all(store: InMemoryVectorStore) -> None:
results = await store.list()

assert len(results) == 6
names = [result.metadata["name"] for result in results]
assert names == ["spikey", "fluffy", "slimy", "scaly", "hairy", "spotty"]


async def test_list_limit(store: InMemoryVectorStore):
async def test_list_limit(store: InMemoryVectorStore) -> None:
results = await store.list(limit=3)

assert len(results) == 3
names = {result.metadata["name"] for result in results}
assert names == {"spikey", "fluffy", "slimy"}


async def test_list_offset(store: InMemoryVectorStore):
async def test_list_offset(store: InMemoryVectorStore) -> None:
results = await store.list(offset=3)

assert len(results) == 3
names = {result.metadata["name"] for result in results}
assert names == {"scaly", "hairy", "spotty"}


async def test_limit_with_offset(store: InMemoryVectorStore):
async def test_limit_with_offset(store: InMemoryVectorStore) -> None:
results = await store.list(limit=2, offset=3)

assert len(results) == 2
names = {result.metadata["name"] for result in results}
assert names == {"scaly", "hairy"}


async def test_where(store: InMemoryVectorStore):
async def test_where(store: InMemoryVectorStore) -> None:
results = await store.list(where={"type": "insect"})

assert len(results) == 2
names = {result.metadata["name"] for result in results}
assert names == {"hairy", "spotty"}


async def test_multiple_where(store: InMemoryVectorStore):
async def test_multiple_where(store: InMemoryVectorStore) -> None:
results = await store.list(where={"type": "insect", "age": 1})

assert len(results) == 1
assert results[0].metadata["name"] == "spotty"


async def test_empty_where(store: InMemoryVectorStore):
async def test_empty_where(store: InMemoryVectorStore) -> None:
results = await store.list(where={})

assert len(results) == 6
names = {result.metadata["name"] for result in results}
assert names == {"spikey", "fluffy", "slimy", "scaly", "hairy", "spotty"}


async def test_empty_results(store: InMemoryVectorStore):
async def test_empty_results(store: InMemoryVectorStore) -> None:
results = await store.list(where={"type": "bird"})

assert len(results) == 0


async def test_empty_results_with_limit(store: InMemoryVectorStore):
async def test_empty_results_with_limit(store: InMemoryVectorStore) -> None:
results = await store.list(where={"type": "bird"}, limit=2)

assert len(results) == 0


async def test_where_limit(store: InMemoryVectorStore):
async def test_where_limit(store: InMemoryVectorStore) -> None:
results = await store.list(where={"type": "insect"}, limit=1)

assert len(results) == 1
Expand Down

0 comments on commit fc6d636

Please sign in to comment.