Skip to content

Commit

Permalink
Applying pre-commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Sep 20, 2024
1 parent 6c9c8bb commit 717052d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import os

import chromadb

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
from copy import deepcopy
from hashlib import sha256
import json
from typing import Literal, Optional, Union, List

from ragbits.document_search.documents.element import TextElement
from typing import List, Literal, Optional, Union

try:
import chromadb

HAS_CHROMADB = True
except ImportError:
HAS_CHROMADB = False
Expand Down Expand Up @@ -37,7 +36,7 @@ def __init__(
distance_method (Literal["l2", "ip", "cosine"], default="l2"): The distance method to use.
"""
if not HAS_CHROMADB:
raise ImportError("You need to install the 'ragbits-document-search[chromadb]' extra requirement of to use LiteLLM embeddings models")
raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models")

super().__init__()
self.index_name = index_name
Expand Down Expand Up @@ -78,9 +77,9 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]:
return retrieved["documents"][0][0]

return None
def _process_db_entry(self, entry: VectorDBEntry):
id = sha256(entry.key.encode("utf-8")).hexdigest()

def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]:
doc_id = sha256(entry.key.encode("utf-8")).hexdigest()
embedding = entry.vector
text = entry.metadata["content"]

Expand All @@ -89,23 +88,39 @@ def _process_db_entry(self, entry: VectorDBEntry):
metadata["key"] = entry.key
metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()}

return doc_id, embedding, text, metadata

return id, embedding, text, metadata
def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]:
"""
Processes the metadata dictionary by parsing JSON strings if applicable.
def _process_metadata(self, metadata):
return {key: json.loads(val) if self.is_json(val) else val
for key, val in metadata.items()}
Args:
metadata (dict): A dictionary containing metadata where values may be JSON strings.
def is_json(self, myjson) -> bool:
Returns:
dict[str, Union[str, int, float, bool]]: A dictionary with the same keys as the input,
where JSON strings are parsed into their respective Python data types.
"""
return {key: json.loads(val) if self.is_json(val) else val for key, val in metadata.items()}

def is_json(self, myjson: object) -> bool:
    """
    Check whether the provided value is a string containing valid JSON.

    Args:
        myjson (object): The value to check. Anything that is not a `str`
            is reported as not being JSON.

    Returns:
        bool: True if `myjson` is a string that `json.loads` can parse
            (bare scalars such as "123" count as valid JSON), False otherwise.
    """
    # Guard first: only strings are ever handed to the parser.
    if not isinstance(myjson, str):
        return False

    # Keep the try body down to the single call that can raise.
    try:
        json.loads(myjson)
    except ValueError:
        # json.JSONDecodeError is a subclass of ValueError.
        return False
    return True



async def store(self, entries: List[VectorDBEntry]) -> None:
"""
Stores entries in the ChromaDB collection.
Expand All @@ -114,9 +129,7 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
entries (List[VectorDBEntry]): The entries to store.
"""
collection = self._get_chroma_collection()



entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))

Expand All @@ -137,12 +150,12 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
query_result = collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for doc, meta in zip(query_result.get("documents"), query_result.get("metadatas")):
for meta in query_result.get("metadatas"):
db_entry = VectorDBEntry(
key=meta[0].get("key"),
vector=vector,
metadata=self._process_metadata(meta[0]),
)
)

db_entries.append(db_entry)

Expand Down Expand Up @@ -177,4 +190,4 @@ def __repr__(self) -> str:
Returns:
str: The string representation of the object.
"""
return f"{self.__class__.__name__}(index_name={self.index_name})"
return f"{self.__class__.__name__}(index_name={self.index_name})"
56 changes: 34 additions & 22 deletions packages/ragbits-document-search/tests/unit/test_chromadb_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from hashlib import sha256
import json
from unittest.mock import AsyncMock, MagicMock, patch
import uuid

import chromadb
import pytest

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.vector_store import chromadb_store
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore, VectorDBEntry


Expand All @@ -33,6 +29,7 @@ def mock_chromadb_store(mock_chroma_client, mock_embedding_function):
class MockEmbeddings(Embeddings):
async def embed_text(self, text):
return [[0.4, 0.5, 0.6]]

def __call__(self, input):
return self.embed_text(input)

Expand All @@ -56,26 +53,27 @@ def mock_vector_db_entry():
return VectorDBEntry(
key="test_key",
vector=[0.1, 0.2, 0.3],
metadata={"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}
metadata={
"content": "test content",
"document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"},
},
)


def test_chromadbstore_init_import_error():
with patch('ragbits.document_search.vector_store.chromadb_store.HAS_CHROMADB', False):
with patch("ragbits.document_search.vector_store.chromadb_store.HAS_CHROMADB", False):
with pytest.raises(ImportError):
ChromaDBStore(
index_name="test_index",
chroma_client=MagicMock(),
embedding_function=MagicMock()
)
ChromaDBStore(index_name="test_index", chroma_client=MagicMock(), embedding_function=MagicMock())


def test_get_chroma_collection(mock_chromadb_store):
    """Calling _get_chroma_collection must invoke the client's get_or_create_collection."""
    mock_chromadb_store._get_chroma_collection()

    get_or_create = mock_chromadb_store.chroma_client.get_or_create_collection
    assert get_or_create.called


def test_get_chroma_collection_with_custom_embedding_function(custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client):
def test_get_chroma_collection_with_custom_embedding_function(
custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client
):
_ = mock_chromadb_store_with_custom_embedding_function._get_chroma_collection()

mock_chroma_client.get_or_create_collection.assert_called_once_with(
Expand All @@ -90,7 +88,10 @@ async def test_stores_entries_correctly(mock_chromadb_store):
VectorDBEntry(
key="test_key",
vector=[0.1, 0.2, 0.3],
metadata={"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}},
metadata={
"content": "test content",
"document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"},
},
)
]
await mock_chromadb_store.store(data)
Expand All @@ -101,10 +102,13 @@ def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):
id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)
print(f"metadata: {metadata}, type: {type(metadata)}")

assert id == sha256("test_key".encode("utf-8")).hexdigest()
assert id == sha256(b"test_key").hexdigest()
assert embedding == [0.1, 0.2, 0.3]
assert text == "test content"
assert metadata["document"] == '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}'
assert (
metadata["document"]
== '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}'
)
assert metadata["key"] == "test_key"


Expand All @@ -120,7 +124,15 @@ async def test_retrieves_entries_correctly(mock_chromadb_store):
mock_collection = mock_chromadb_store._get_chroma_collection()
mock_collection.query.return_value = {
"documents": [["test content"]],
"metadatas": [[{"key": "test_key", "content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}]],
"metadatas": [
[
{
"key": "test_key",
"content": "test content",
"document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"},
}
]
],
}
entries = await mock_chromadb_store.retrieve(vector)
assert len(entries) == 1
Expand All @@ -143,7 +155,7 @@ async def test_find_similar(mock_chromadb_store, mock_embedding_function):
mock_chromadb_store.embedding_function = mock_embedding_function
mock_chromadb_store.chroma_client.get_or_create_collection().query.return_value = {
"documents": [["test content"]],
"distances": [[0.1]]
"distances": [[0.1]],
}
result = await mock_chromadb_store.find_similar("test text")
assert result == "test content"
Expand All @@ -154,7 +166,7 @@ async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_e
mock_chromadb_store.embedding_function = custom_embedding_function
mock_chromadb_store.chroma_client.get_or_create_collection().query.return_value = {
"documents": [["test content"]],
"distances": [[0.1]]
"distances": [[0.1]],
}
result = await mock_chromadb_store.find_similar("test text")
assert result == "test content"
Expand All @@ -163,6 +175,7 @@ async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_e
def test_repr(mock_chromadb_store):
    """repr() of the store renders the class name followed by its index name."""
    expected = "ChromaDBStore(index_name=test_index)"
    assert repr(mock_chromadb_store) == expected


@pytest.mark.parametrize(
"retrieved, max_distance, expected",
[
Expand All @@ -180,10 +193,9 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte
def test_is_json_valid_string(mock_chromadb_store):
    """is_json returns True for a string holding a well-formed JSON object."""
    assert mock_chromadb_store.is_json('{"key": "value"}') is True

0 comments on commit 717052d

Please sign in to comment.