Skip to content

Commit

Permalink
feat: support for different MetadataStores in VectorStore (#144)
Browse files Browse the repository at this point in the history
Co-authored-by: Michał Pstrąg <[email protected]>
  • Loading branch information
konrad-czarnota-ds and micpst authored Oct 30, 2024
1 parent a60cdfe commit c1c019f
Show file tree
Hide file tree
Showing 14 changed files with 255 additions and 40 deletions.
3 changes: 3 additions & 0 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
},
},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
Expand Down
30 changes: 30 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import MetadataStore
from .in_memory import InMemoryMetadataStore

__all__ = ["InMemoryMetadataStore", "MetadataStore"]

module = sys.modules[__name__]


def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None:
"""
Initializes and returns a MetadataStore object based on the provided configuration.
Args:
metadata_store_config: A dictionary containing configuration details for the MetadataStore.
Returns:
An instance of the specified MetadataStore class, initialized with the provided config
(if any) or default arguments.
"""
if metadata_store_config is None:
return None

metadata_store_class = get_cls_from_config(metadata_store_config["type"], module)
config = metadata_store_config.get("config", {})

return metadata_store_class(**config)
32 changes: 32 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod


class MetadataStore(ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs.
"""

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Store metadatas under ids in metadata store.
Args:
ids: list of unique ids of the entries
metadatas: list of dicts with metadata.
"""

@abstractmethod
async def get(self, ids: list[str]) -> list[dict]:
"""
Returns metadatas associated with a given ids.
Args:
ids: list of ids to use.
Returns:
List of metadata dicts associated with a given ids.
Raises:
MetadataNotFoundError: If the metadata is not found.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class MetadataNotFoundError(Exception):
"""
Raised when metadata is not found in the metadata store
"""

def __init__(self, id: str) -> None:
super().__init__(f"Metadata not found for {id} id.")
self.id = id
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.metadata_stores.exceptions import MetadataNotFoundError


class InMemoryMetadataStore(MetadataStore):
"""
Metadata Store implemented in memory
"""

def __init__(self) -> None:
"""
Constructs a new InMemoryMetadataStore instance.
"""
self._storage: dict[str, dict] = {}

async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Store metadatas under ids in metadata store.
Args:
ids: list of unique ids of the entries
metadatas: list of dicts with metadata.
"""
for _id, metadata in zip(ids, metadatas, strict=False):
self._storage[_id] = metadata

async def get(self, ids: list[str]) -> list[dict]:
"""
Returns metadatas associated with a given ids.
Args:
ids: list of ids to use.
Returns:
List of metadata dicts associated with a given ids.
Raises:
MetadataNotFoundError: If the metadata is not found.
"""
try:
return [self._storage[_id] for _id in ids]
except KeyError as exc:
raise MetadataNotFoundError(*exc.args) from exc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

from ..metadata_stores import get_metadata_store
from ..utils.config_handling import get_cls_from_config
from .base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from .in_memory import InMemoryVectorStore
Expand All @@ -26,4 +27,8 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
if vector_store_config["type"].endswith("ChromaVectorStore"):
return vector_store_cls.from_config(config)

return vector_store_cls(default_options=VectorStoreOptions(**config.get("default_options", {})))
metadata_store_config = vector_store_config.get("metadata_store_config")
return vector_store_cls(
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(metadata_store_config),
)
16 changes: 15 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel

from ragbits.core.metadata_stores.base import MetadataStore

WhereQuery = dict[str, str | int | float | bool]


Expand Down Expand Up @@ -29,9 +31,21 @@ class VectorStore(ABC):
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
def __init__(
self,
default_options: VectorStoreOptions | None = None,
metadata_store: MetadataStore | None = None,
) -> None:
"""
Constructs a new VectorStore instance.
Args:
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use.
"""
super().__init__()
self._default_options = default_options or VectorStoreOptions()
self._metadata_store = metadata_store

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
74 changes: 51 additions & 23 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from chromadb import Collection
from chromadb.api import ClientAPI

from ragbits.core.metadata_stores import get_metadata_store
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

Expand All @@ -23,17 +25,19 @@ def __init__(
index_name: str,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
default_options: VectorStoreOptions | None = None,
):
metadata_store: MetadataStore | None = None,
) -> None:
"""
Initializes the ChromaVectorStore with the given parameters.
Constructs a new ChromaVectorStore instance.
Args:
client: The ChromaDB client.
index_name: The name of the index.
distance_method: The distance method to use.
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use. If None, the metadata will be stored in ChromaDB.
"""
super().__init__(default_options)
super().__init__(default_options=default_options, metadata_store=metadata_store)
self._client = client
self._index_name = index_name
self._distance_method = distance_method
Expand Down Expand Up @@ -62,12 +66,13 @@ def from_config(cls, config: dict) -> ChromaVectorStore:
Returns:
An initialized instance of the ChromaVectorStore class.
"""
client = get_cls_from_config(config["client"]["type"], chromadb) # type: ignore
client_cls = get_cls_from_config(config["client"]["type"], chromadb)
return cls(
client=client(**config["client"].get("config", {})),
client=client_cls(**config["client"].get("config", {})),
index_name=config["index_name"],
distance_method=config.get("distance_method", "l2"),
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(config.get("metadata_store")),
)

async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand All @@ -77,17 +82,17 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Args:
entries: The entries to store.
"""
# TODO: Think about better id components for hashing
# TODO: Think about better id components for hashing and move hash computing to VectorStoreEntry
ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries]
documents = [entry.key for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [
{
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}
for entry in entries
]
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) # type: ignore
metadatas = [entry.metadata for entry in entries]
metadatas = (
[{"__metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
if self._metadata_store is None
else await self._metadata_store.store(ids, metadatas) # type: ignore
)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) # type: ignore

async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Expand All @@ -99,25 +104,37 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
Returns:
The retrieved entries.
Raises:
MetadataNotFoundError: If the metadata is not found.
"""
options = self._default_options if options is None else options

results = self._collection.query(
query_embeddings=vector,
n_results=options.k,
include=["metadatas", "embeddings", "distances"],
include=["metadatas", "embeddings", "distances", "documents"],
)
ids = results.get("ids") or []
metadatas = results.get("metadatas") or []
embeddings = results.get("embeddings") or []
distances = results.get("distances") or []
documents = results.get("documents") or []

metadatas = [
[json.loads(metadata["__metadata"]) for batch in metadatas for metadata in batch] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(*ids)
]

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embeddings),
metadata=json.loads(str(metadata["__metadata"])),
metadata=metadata, # type: ignore
)
for batch in zip(metadatas, embeddings, distances, strict=False)
for metadata, embeddings, distance in zip(*batch, strict=False)
for batch in zip(metadatas, embeddings, distances, documents, strict=False)
for metadata, embeddings, distance, document in zip(*batch, strict=False)
if options.max_distance is None or distance <= options.max_distance
]

Expand All @@ -135,6 +152,9 @@ async def list(
Returns:
The entries.
Raises:
MetadataNotFoundError: If the metadata is not found.
"""
# Cast `where` to chromadb's Where type
where_chroma: chromadb.Where | None = dict(where) if where else None
Expand All @@ -143,16 +163,24 @@ async def list(
where=where_chroma,
limit=limit,
offset=offset,
include=["metadatas", "embeddings"],
include=["metadatas", "embeddings", "documents"],
)
ids = get_results.get("ids") or []
metadatas = get_results.get("metadatas") or []
embeddings = get_results.get("embeddings") or []
documents = get_results.get("documents") or []

metadatas = (
[json.loads(metadata["__metadata"]) for metadata in metadatas] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(ids)
)

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embedding),
metadata=json.loads(str(metadata["__metadata"])),
metadata=metadata, # type: ignore
)
for metadata, embedding in zip(metadatas, embeddings, strict=False)
for metadata, embedding, document in zip(metadatas, embeddings, documents, strict=False)
]
16 changes: 14 additions & 2 deletions packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery


Expand All @@ -10,8 +11,19 @@ class InMemoryVectorStore(VectorStore):
A simple in-memory implementation of Vector Store, storing vectors in memory.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
super().__init__(default_options)
def __init__(
self,
default_options: VectorStoreOptions | None = None,
metadata_store: MetadataStore | None = None,
) -> None:
"""
Constructs a new InMemoryVectorStore instance.
Args:
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use.
"""
super().__init__(default_options=default_options, metadata_store=metadata_store)
self._storage: dict[str, VectorStoreEntry] = {}

async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions packages/ragbits-core/tests/unit/metadata_stores/test_in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from ragbits.core.metadata_stores.exceptions import MetadataNotFoundError
from ragbits.core.metadata_stores.in_memory import InMemoryMetadataStore


@pytest.fixture
def metadata_store() -> InMemoryMetadataStore:
return InMemoryMetadataStore()


async def test_store(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1", "id2"]
metadatas = [{"key1": "value1"}, {"key2": "value2"}]
await metadata_store.store(ids, metadatas)
assert metadata_store._storage["id1"] == {"key1": "value1"}
assert metadata_store._storage["id2"] == {"key2": "value2"}


async def test_get(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1", "id2"]
metadatas = [{"key1": "value1"}, {"key2": "value2"}]
await metadata_store.store(ids, metadatas)
result = await metadata_store.get(ids)
assert result == [{"key1": "value1"}, {"key2": "value2"}]


async def test_get_metadata_not_found(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1"]
with pytest.raises(MetadataNotFoundError):
await metadata_store.get(ids)
Empty file.
Loading

0 comments on commit c1c019f

Please sign in to comment.