Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for different MetadataStores in VectorStore #144

Merged
merged 16 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
},
},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
Expand Down
30 changes: 30 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import MetadataStore
from .in_memory import InMemoryMetadataStore

__all__ = ["InMemoryMetadataStore", "MetadataStore"]

module = sys.modules[__name__]


def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None:
"""
Initializes and returns a MetadataStore object based on the provided configuration.

Args:
metadata_store_config: A dictionary containing configuration details for the MetadataStore.

Returns:
An instance of the specified MetadataStore class, initialized with the provided config
(if any) or default arguments.
"""
if metadata_store_config is None:
return None

metadata_store_class = get_cls_from_config(metadata_store_config["type"], module)
config = metadata_store_config.get("config", {})

return metadata_store_class(**config)
32 changes: 32 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod


class MetadataStore(ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs.
"""

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Store metadatas under ids in metadata store.

Args:
ids: list of unique ids of the entries
metadatas: list of dicts with metadata.
"""

@abstractmethod
async def get(self, ids: list[str]) -> list[dict]:
"""
Returns metadatas associated with a given ids.

Args:
ids: list of ids to use.

Returns:
List of metadata dicts associated with a given ids.

Raises:
MetadataNotFoundError: If the metadata is not found.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class MetadataNotFoundError(Exception):
"""
Raised when metadata is not found in the metadata store
"""

def __init__(self, id: str) -> None:
super().__init__(f"Metadata not found for {id} id.")
self.id = id
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.metadata_stores.exceptions import MetadataNotFoundError


class InMemoryMetadataStore(MetadataStore):
"""
Metadata Store implemented in memory
"""

def __init__(self) -> None:
"""
Constructs a new InMemoryMetadataStore instance.
"""
self._storage: dict[str, dict] = {}

async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Store metadatas under ids in metadata store.

Args:
ids: list of unique ids of the entries
metadatas: list of dicts with metadata.
"""
for _id, metadata in zip(ids, metadatas, strict=False):
self._storage[_id] = metadata

async def get(self, ids: list[str]) -> list[dict]:
"""
Returns metadatas associated with a given ids.

Args:
ids: list of ids to use.

Returns:
List of metadata dicts associated with a given ids.

Raises:
MetadataNotFoundError: If the metadata is not found.
"""
try:
return [self._storage[_id] for _id in ids]
except KeyError as exc:
raise MetadataNotFoundError(*exc.args) from exc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

from ..metadata_stores import get_metadata_store
from ..utils.config_handling import get_cls_from_config
from .base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from .in_memory import InMemoryVectorStore
Expand All @@ -26,4 +27,8 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
if vector_store_config["type"].endswith("ChromaVectorStore"):
return vector_store_cls.from_config(config)

return vector_store_cls(default_options=VectorStoreOptions(**config.get("default_options", {})))
metadata_store_config = vector_store_config.get("metadata_store_config")
return vector_store_cls(
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(metadata_store_config),
)
16 changes: 15 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel

from ragbits.core.metadata_stores.base import MetadataStore

WhereQuery = dict[str, str | int | float | bool]


Expand Down Expand Up @@ -29,9 +31,21 @@ class VectorStore(ABC):
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
def __init__(
self,
default_options: VectorStoreOptions | None = None,
metadata_store: MetadataStore | None = None,
) -> None:
"""
Constructs a new VectorStore instance.

Args:
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use.
"""
super().__init__()
self._default_options = default_options or VectorStoreOptions()
self._metadata_store = metadata_store

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
74 changes: 51 additions & 23 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from chromadb import Collection
from chromadb.api import ClientAPI

from ragbits.core.metadata_stores import get_metadata_store
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

Expand All @@ -23,17 +25,19 @@ def __init__(
index_name: str,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
default_options: VectorStoreOptions | None = None,
):
metadata_store: MetadataStore | None = None,
) -> None:
"""
Initializes the ChromaVectorStore with the given parameters.
Constructs a new ChromaVectorStore instance.

Args:
client: The ChromaDB client.
index_name: The name of the index.
distance_method: The distance method to use.
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use. If None, the metadata will be stored in ChromaDB.
"""
super().__init__(default_options)
super().__init__(default_options=default_options, metadata_store=metadata_store)
self._client = client
self._index_name = index_name
self._distance_method = distance_method
Expand Down Expand Up @@ -62,12 +66,13 @@ def from_config(cls, config: dict) -> ChromaVectorStore:
Returns:
An initialized instance of the ChromaVectorStore class.
"""
client = get_cls_from_config(config["client"]["type"], chromadb) # type: ignore
client_cls = get_cls_from_config(config["client"]["type"], chromadb)
return cls(
client=client(**config["client"].get("config", {})),
client=client_cls(**config["client"].get("config", {})),
index_name=config["index_name"],
distance_method=config.get("distance_method", "l2"),
default_options=VectorStoreOptions(**config.get("default_options", {})),
metadata_store=get_metadata_store(config.get("metadata_store")),
)

async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand All @@ -77,17 +82,17 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
Args:
entries: The entries to store.
"""
# TODO: Think about better id components for hashing
# TODO: Think about better id components for hashing and move hash computing to VectorStoreEntry
ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries]
documents = [entry.key for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [
{
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}
for entry in entries
]
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) # type: ignore
metadatas = [entry.metadata for entry in entries]
metadatas = (
[{"__metadata": json.dumps(metadata, default=str)} for metadata in metadatas]
if self._metadata_store is None
else await self._metadata_store.store(ids, metadatas) # type: ignore
)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) # type: ignore

async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Expand All @@ -99,25 +104,37 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

Returns:
The retrieved entries.

Raises:
MetadataNotFoundError: If the metadata is not found.
"""
options = self._default_options if options is None else options

results = self._collection.query(
query_embeddings=vector,
n_results=options.k,
include=["metadatas", "embeddings", "distances"],
include=["metadatas", "embeddings", "distances", "documents"],
)
ids = results.get("ids") or []
metadatas = results.get("metadatas") or []
embeddings = results.get("embeddings") or []
distances = results.get("distances") or []
documents = results.get("documents") or []

metadatas = [
[json.loads(metadata["__metadata"]) for batch in metadatas for metadata in batch] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(*ids)
]

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embeddings),
metadata=json.loads(str(metadata["__metadata"])),
metadata=metadata, # type: ignore
)
for batch in zip(metadatas, embeddings, distances, strict=False)
for metadata, embeddings, distance in zip(*batch, strict=False)
for batch in zip(metadatas, embeddings, distances, documents, strict=False)
for metadata, embeddings, distance, document in zip(*batch, strict=False)
if options.max_distance is None or distance <= options.max_distance
]

Expand All @@ -135,6 +152,9 @@ async def list(

Returns:
The entries.

Raises:
MetadataNotFoundError: If the metadata is not found.
"""
# Cast `where` to chromadb's Where type
where_chroma: chromadb.Where | None = dict(where) if where else None
Expand All @@ -143,16 +163,24 @@ async def list(
where=where_chroma,
limit=limit,
offset=offset,
include=["metadatas", "embeddings"],
include=["metadatas", "embeddings", "documents"],
)
ids = get_results.get("ids") or []
metadatas = get_results.get("metadatas") or []
embeddings = get_results.get("embeddings") or []
documents = get_results.get("documents") or []

metadatas = (
[json.loads(metadata["__metadata"]) for metadata in metadatas] # type: ignore
if self._metadata_store is None
else await self._metadata_store.get(ids)
)

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embedding),
metadata=json.loads(str(metadata["__metadata"])),
metadata=metadata, # type: ignore
)
for metadata, embedding in zip(metadatas, embeddings, strict=False)
for metadata, embedding, document in zip(metadatas, embeddings, documents, strict=False)
]
16 changes: 14 additions & 2 deletions packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery


Expand All @@ -10,8 +11,19 @@ class InMemoryVectorStore(VectorStore):
A simple in-memory implementation of Vector Store, storing vectors in memory.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
super().__init__(default_options)
def __init__(
self,
default_options: VectorStoreOptions | None = None,
metadata_store: MetadataStore | None = None,
) -> None:
"""
Constructs a new InMemoryVectorStore instance.

Args:
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use.
"""
super().__init__(default_options=default_options, metadata_store=metadata_store)
self._storage: dict[str, VectorStoreEntry] = {}

async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions packages/ragbits-core/tests/unit/metadata_stores/test_in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from ragbits.core.metadata_stores.exceptions import MetadataNotFoundError
from ragbits.core.metadata_stores.in_memory import InMemoryMetadataStore


@pytest.fixture
def metadata_store() -> InMemoryMetadataStore:
return InMemoryMetadataStore()


async def test_store(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1", "id2"]
metadatas = [{"key1": "value1"}, {"key2": "value2"}]
await metadata_store.store(ids, metadatas)
assert metadata_store._storage["id1"] == {"key1": "value1"}
assert metadata_store._storage["id2"] == {"key2": "value2"}


async def test_get(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1", "id2"]
metadatas = [{"key1": "value1"}, {"key2": "value2"}]
await metadata_store.store(ids, metadatas)
result = await metadata_store.get(ids)
assert result == [{"key1": "value1"}, {"key2": "value2"}]


async def test_get_metadata_not_found(metadata_store: InMemoryMetadataStore) -> None:
ids = ["id1"]
with pytest.raises(MetadataNotFoundError):
await metadata_store.get(ids)
Empty file.
Loading
Loading