Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move vector search capabilities to core package #39

Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import chromadb

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
Expand Down
3 changes: 3 additions & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ dependencies = [
]

[project.optional-dependencies]
chromadb = [
"chromadb~=0.4.24",
]
litellm = [
"litellm~=1.46.0",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
HAS_CHROMADB = False

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.vector_store.base import VectorStore
from ragbits.document_search.vector_store.in_memory import VectorDBEntry
from ragbits.core.vector_store.base import VectorStore
from ragbits.core.vector_store.in_memory import VectorDBEntry


class ChromaDBStore(VectorStore):
Expand Down Expand Up @@ -102,6 +102,7 @@ def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float,
A dictionary with the same keys as the input, where JSON strings are parsed
into their respective Python data types.
"""
metadata["document_meta"] = metadata.pop("document")
return {key: json.loads(val) if self._is_json(val) else val for key, val in metadata.items()}

def _is_json(self, myjson: str) -> bool:
Expand Down Expand Up @@ -139,12 +140,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
Args:
entries: The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
self._collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Expand All @@ -157,8 +156,7 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
Expand All @@ -183,14 +181,11 @@ async def find_similar(self, text: str) -> Optional[str]:
Returns:
The most similar text or None if no similar text is found.
"""

collection = self._get_chroma_collection()

if isinstance(self._embedding_function, Embeddings):
embedding = await self._embedding_function.embed_text([text])
retrieved = collection.query(query_embeddings=embedding, n_results=1)
retrieved = self._collection.query(query_embeddings=embedding, n_results=1)
else:
retrieved = collection.query(query_texts=[text], n_results=1)
retrieved = self._collection.query(query_texts=[text], n_results=1)

return self._return_best_match(retrieved)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from ragbits.document_search.vector_store.base import VectorDBEntry, VectorStore
from ragbits.core.vector_store.base import VectorDBEntry, VectorStore


class InMemoryVectorStore(VectorStore):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore, VectorDBEntry
from ragbits.core.vector_store.chromadb_store import ChromaDBStore, VectorDBEntry


@pytest.fixture
Expand Down Expand Up @@ -61,13 +61,14 @@ def mock_vector_db_entry():


def test_chromadbstore_init_import_error():
with patch("ragbits.document_search.vector_store.chromadb_store.HAS_CHROMADB", False):
with patch("ragbits.core.vector_store.chromadb_store.HAS_CHROMADB", False):
with pytest.raises(ImportError):
ChromaDBStore(index_name="test_index", chroma_client=MagicMock(), embedding_function=MagicMock())


def test_get_chroma_collection(mock_chromadb_store):
_ = mock_chromadb_store._get_chroma_collection()

assert mock_chromadb_store._chroma_client.get_or_create_collection.called


Expand All @@ -91,13 +92,14 @@ async def test_stores_entries_correctly(mock_chromadb_store):
},
)
]

await mock_chromadb_store.store(data)

mock_chromadb_store._chroma_client.get_or_create_collection().add.assert_called_once()


def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):
id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)
print(f"metadata: {metadata}, type: {type(metadata)}")

assert id == sha256(b"test_key").hexdigest()
assert embedding == [0.1, 0.2, 0.3]
Expand All @@ -111,6 +113,7 @@ def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):

async def test_store(mock_chromadb_store, mock_vector_db_entry):
await mock_chromadb_store.store([mock_vector_db_entry])

assert mock_chromadb_store._chroma_client.get_or_create_collection().add.called


Expand All @@ -129,17 +132,21 @@ async def test_retrieves_entries_correctly(mock_chromadb_store):
]
],
}

entries = await mock_chromadb_store.retrieve(vector)

assert len(entries) == 1
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].metadata["document_meta"]["title"] == "test title"


async def test_handles_empty_retrieve(mock_chromadb_store):
vector = [0.1, 0.2, 0.3]
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.query.return_value = {"documents": [], "metadatas": []}

entries = await mock_chromadb_store.retrieve(vector)

assert len(entries) == 0


Expand All @@ -150,7 +157,9 @@ async def test_find_similar(mock_chromadb_store, mock_embedding_function):
"documents": [["test content"]],
"distances": [[0.1]],
}

result = await mock_chromadb_store.find_similar("test text")

assert result == "test content"


Expand All @@ -160,7 +169,9 @@ async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_e
"documents": [["test content"]],
"distances": [[0.1]],
}

result = await mock_chromadb_store.find_similar("test text")

assert result == "test content"


Expand All @@ -178,16 +189,15 @@ def test_repr(mock_chromadb_store):
)
def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expected):
mock_chromadb_store._max_distance = max_distance

result = mock_chromadb_store._return_best_match(retrieved)

assert result == expected


def test_is_json_valid_string(mock_chromadb_store):
# Arrange
valid_json_string = '{"key": "value"}'

# Act
result = mock_chromadb_store._is_json(valid_json_string)

# Assert
assert result is True
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path

from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore


async def test_simple_vector_store():
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-document-search/examples/simple_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import asyncio

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
Expand Down
3 changes: 0 additions & 3 deletions packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ dependencies = [
]

[project.optional-dependencies]
chromadb = [
"chromadb~=0.4.24",
]
gcs = [
"gcloud-aio-storage~=9.3.0"
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.vector_store.base import VectorStore
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
Expand All @@ -7,7 +8,6 @@
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker
from ragbits.document_search.vector_store.base import VectorStore


class DocumentSearch:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from pydantic import BaseModel

from ragbits.core.vector_store.base import VectorDBEntry
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.base import VectorDBEntry


class Element(BaseModel, ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest.mock import AsyncMock

from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore


async def test_document_search():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ragbits.core.vector_store.base import VectorDBEntry
from ragbits.document_search.documents.document import DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.vector_store.base import VectorDBEntry


def test_resolving_element_type():
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ description = "Ragbits development workspace"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ragbits[litellm,local]",
"ragbits[litellm,local,chromadb]",
"ragbits-dev-kit",
"ragbits-document-search[chromadb,gcs]",
"ragbits-document-search[gcs]",
"ragbits-cli"
]

Expand Down
16 changes: 8 additions & 8 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading