Skip to content

Commit

Permalink
Linter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad-czarnota-ds committed Oct 25, 2024
1 parent ad3a5cc commit af9dc06
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 51 deletions.
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/metadata_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def store(self, key: str | UUID, metadata: dict) -> None:
"""

@abc.abstractmethod
async def query(self, metadata_field_name: str, value: Any) -> dict:
async def query(self, metadata_field_name: str, value: Any) -> dict: # noqa
"""
Queries metastore and returns dicts with key: metadata format that match
Expand Down Expand Up @@ -66,6 +66,6 @@ async def get_global(self) -> dict:
"""
Get key value pairs of metadata that is shared across entries
Returns
Returns:
metadata for the whole collection
"""
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def store(self, key: str | UUID, metadata: dict) -> None:
"""
self._storage[key] = metadata

async def query(self, metadata_field_name: str, value: Any) -> dict:
async def query(self, metadata_field_name: str, value: Any) -> dict: # noqa
"""
Queries metastore and returns dicts with key: metadata format that match
Expand Down
5 changes: 2 additions & 3 deletions packages/ragbits-core/src/ragbits/core/vector_store/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
from typing import Optional

from pydantic import BaseModel

Expand All @@ -24,9 +23,9 @@ class VectorStore(abc.ABC):
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

metadata_store: Optional[MetadataStore]
metadata_store: MetadataStore | None

def __init__(self, metadata_store: Optional[MetadataStore] = None):
def __init__(self, metadata_store: MetadataStore | None = None):
self.metadata_store = metadata_store

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from hashlib import sha256
from typing import List, Literal, Optional
from typing import Literal

try:
import chromadb
Expand All @@ -16,18 +16,19 @@
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_store import VectorDBEntry, VectorStore, WhereQuery

CHROMA_IDS_KEY = "ids"
CHROMA_DOCUMENTS_KEY = "documents"
CHROMA_DISTANCES_KEY = "distances"
CHROMA_METADATA_KEY = "metadatas"
CHROMA_EMBEDDINGS_KEY = "embeddings"
CHROMA_INCLUDE_KEYS = [CHROMA_DOCUMENTS_KEY, CHROMA_DISTANCES_KEY, CHROMA_METADATA_KEY, CHROMA_EMBEDDINGS_KEY]


class ChromaDBStore(VectorStore):
"""
Class that stores text embeddings using [Chroma](https://docs.trychroma.com/).
"""

CHROMA_IDS_KEY = "ids"
CHROMA_DOCUMENTS_KEY = "documents"
CHROMA_DISTANCES_KEY = "distances"
CHROMA_METADATA_KEY = "metadatas"
CHROMA_EMBEDDINGS_KEY = "embeddings"
CHROMA_INCLUDE_KEYS = [CHROMA_DOCUMENTS_KEY, CHROMA_DISTANCES_KEY, CHROMA_METADATA_KEY, CHROMA_EMBEDDINGS_KEY]
DEFAULT_DISTANCE_METHOD = "l2"
METADATA_INNER_KEY = "__metadata"

Expand All @@ -38,7 +39,7 @@ def __init__(
embedding_function: Embeddings | chromadb.EmbeddingFunction,
max_distance: float | None = None,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
metadata_store: Optional[MetadataStore] = None,
metadata_store: MetadataStore | None = None,
):
"""
Initializes the ChromaDBStore with the given parameters.
Expand All @@ -49,6 +50,7 @@ def __init__(
embedding_function: The embedding function.
max_distance: The maximum distance for similarity.
distance_method: The distance method to use.
metadata_store: The metadata store to use.
"""
if not HAS_CHROMADB:
raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models")
Expand All @@ -59,7 +61,7 @@ def __init__(
self._embedding_function = embedding_function
self._max_distance = max_distance
self._metadata = {"hnsw:space": distance_method}
self._collection = None
self._collection: chromadb.Collection | None = None

@classmethod
def from_config(cls, config: dict) -> ChromaDBStore:
Expand Down Expand Up @@ -123,8 +125,8 @@ def _return_best_match(self, retrieved: dict) -> str | None:
Returns:
The best match or None if no match is found.
"""
if self._max_distance is None or retrieved[self.CHROMA_DISTANCES_KEY][0][0] <= self._max_distance:
return retrieved[self.CHROMA_DOCUMENTS_KEY][0][0]
if self._max_distance is None or retrieved[CHROMA_DISTANCES_KEY][0][0] <= self._max_distance:
return retrieved[CHROMA_DOCUMENTS_KEY][0][0]

return None

Expand Down Expand Up @@ -156,16 +158,39 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
ids, embeddings, contents, metadatas = map(list, zip(*entries_processed, strict=False))

if self.metadata_store is not None:
for key, meta in zip(ids, metadatas):
for key, meta in zip(ids, metadatas, strict=False):
await self.metadata_store.store(key, meta)
metadata_to_store = None
else:
metadata_to_store = [{self.METADATA_INNER_KEY: json.dumps(m, default=str)} for m in metadatas]

collection = await self._get_chroma_collection()
collection.add(ids=ids, embeddings=embeddings, metadatas=metadata_to_store, documents=contents)
collection.add(ids=ids, embeddings=embeddings, metadatas=metadata_to_store, documents=contents) # type: ignore

async def _extract_entries_from_query(
self, query_results: chromadb.api.types.QueryResult | chromadb.api.types.GetResult
) -> list[VectorDBEntry]:
db_entries: list[VectorDBEntry] = []

if len(query_results[CHROMA_DOCUMENTS_KEY]) < 1: # type: ignore
return db_entries
for i in range(len(query_results[CHROMA_DOCUMENTS_KEY][0])): # type: ignore
key = query_results[CHROMA_DOCUMENTS_KEY][0][i] # type: ignore
if self.metadata_store is not None:
metadata = await self.metadata_store.get(query_results[CHROMA_IDS_KEY][0][i]) # type: ignore
else:
metadata = json.loads(query_results[CHROMA_METADATA_KEY][0][i][self.METADATA_INNER_KEY]) # type: ignore

db_entry = VectorDBEntry(
key=key,
vector=query_results[CHROMA_EMBEDDINGS_KEY][0][i], # type: ignore
metadata=metadata,
)
db_entries.append(db_entry)

return db_entries

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]:
"""
Retrieves entries from the ChromaDB collection.
Expand All @@ -177,7 +202,7 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
The retrieved entries.
"""
collection = await self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k, include=self.CHROMA_INCLUDE_KEYS)
query_result = collection.query(query_embeddings=[vector], n_results=k, include=CHROMA_INCLUDE_KEYS) # type: ignore
return await self._extract_entries_from_query(query_result)

async def list(
Expand All @@ -199,30 +224,9 @@ async def list(
where_chroma: chromadb.Where | None = dict(where) if where else None

collection = await self._get_chroma_collection()
get_results = collection.get(where=where_chroma, limit=limit, offset=offset, include=self.CHROMA_INCLUDE_KEYS)
get_results = collection.get(where=where_chroma, limit=limit, offset=offset, include=CHROMA_INCLUDE_KEYS) # type: ignore
return await self._extract_entries_from_query(get_results)

async def _extract_entries_from_query(self, query_results: chromadb.api.types.QueryResult | chromadb.api.types.GetResult) -> List[VectorDBEntry]:
db_entries: list[VectorDBEntry] = []

if len(query_results[self.CHROMA_DOCUMENTS_KEY]) < 1:
return db_entries
for i in range(len(query_results[self.CHROMA_DOCUMENTS_KEY][0])):
key = query_results[self.CHROMA_DOCUMENTS_KEY][0][i]
if self.metadata_store is not None:
metadata = await self.metadata_store.get(query_results[self.CHROMA_IDS_KEY][0][i])
else:
metadata = json.loads(query_results[self.CHROMA_METADATA_KEY][0][i][self.METADATA_INNER_KEY])

db_entry = VectorDBEntry(
key=key,
vector=query_results[self.CHROMA_EMBEDDINGS_KEY][0][i],
metadata=metadata,
)
db_entries.append(db_entry)

return db_entries

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from itertools import islice
from typing import Optional

import numpy as np

Expand All @@ -12,7 +11,7 @@ class InMemoryVectorStore(VectorStore):
A simple in-memory implementation of Vector Store, storing vectors in memory.
"""

def __init__(self, metadata_store: Optional[MetadataStore] = None) -> None:
def __init__(self, metadata_store: MetadataStore | None = None) -> None:
super().__init__(metadata_store)
self._storage: dict[str, VectorDBEntry] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_chromadbstore_init_import_error():
embedding_function=MagicMock(),
)


async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore):
data = [
VectorDBEntry(
Expand All @@ -96,7 +97,7 @@ async def test_stores_entries_correctly(mock_chromadb_store: ChromaDBStore):
mock_chromadb_store._chroma_client.get_or_create_collection().add.assert_called_once() # type: ignore


def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry):
def test_process_db_entry(mock_chromadb_store: ChromaDBStore, mock_vector_db_entry: VectorDBEntry):
id, embedding, key, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry)

assert id == sha256(b"test_key").hexdigest()
Expand All @@ -122,7 +123,8 @@ async def test_retrieves_entries_correctly(mock_chromadb_store: ChromaDBStore):
"metadatas": [
[
{
"__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}'
"__metadata": '{"content": "test content", "document": {"title": "test title", '
'"source": {"path": "/test/path"}, "document_type": "test_type"}}'
}
]
],
Expand All @@ -137,17 +139,19 @@ async def test_retrieves_entries_correctly(mock_chromadb_store: ChromaDBStore):
assert entries[0].vector == [0.12, 0.25, 0.29]


async def test_lists_entries_correctly(mock_chromadb_store):
async def test_lists_entries_correctly(mock_chromadb_store: ChromaDBStore):
mock_collection = await mock_chromadb_store._get_chroma_collection()
mock_collection.get.return_value = {
mock_collection.get.return_value = { # type: ignore
"documents": [["test content", "test content 2"]],
"metadatas": [
[
{
"__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
"__metadata": '{"content": "test content", "document": {"title": "test title", '
'"source": {"path": "/test/path"}, "document_type": "test_type"}}',
},
{
"__metadata": '{"content": "test content 2", "document": {"title": "test title 2", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
"__metadata": '{"content": "test content 2", "document": {"title": "test title 2", '
'"source": {"path": "/test/path"}, "document_type": "test_type"}}',
},
]
],
Expand Down

0 comments on commit af9dc06

Please sign in to comment.