Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for different MetadataStores in VectorStore #144

Merged
merged 16 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
71 changes: 71 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_store/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import abc
from typing import Any
from uuid import UUID


class MetadataStore(abc.ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs
"""

@abc.abstractmethod
async def store(self, key: str | UUID, metadata: dict) -> None:
"""
Store metadata under key in metadata store

Args:
key: unique key of the entry
metadata: dict with metadata
"""

@abc.abstractmethod
async def query(self, metadata_field_name: str, value: Any) -> dict: # noqa
"""
Queries metastore and returns dicts with key: metadata format that match

Args:
metadata_field_name: name of metadata field
value: value to match against

Returns:
dict with key: metadata entries that match query
"""

@abc.abstractmethod
async def get(self, key: str | UUID) -> dict:
"""
Returns metadata associated with a given key

Args:
key: key to use

Returns:
metadata dict associated with a given key
"""

@abc.abstractmethod
async def get_all(self) -> dict:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns all keys with associated metadata

Returns:
metadata dict for all entries
"""

@abc.abstractmethod
async def store_global(self, metadata: dict) -> None:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Store key value pairs of metadata that is shared across entries

Args:
metadata: common key value pairs for the whole collection
"""

@abc.abstractmethod
async def get_global(self) -> dict:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Get key value pairs of metadata that is shared across entries

Returns:
metadata for the whole collection
"""
78 changes: 78 additions & 0 deletions packages/ragbits-core/src/ragbits/core/metadata_store/in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Any
from uuid import UUID

from ragbits.core.metadata_store.base import MetadataStore


class InMemoryMetadataStore(MetadataStore):
    """
    Metadata store that keeps all metadata in process memory.
    """

    def __init__(self) -> None:
        # Per-entry metadata, indexed by the entry key.
        self._storage: dict[str | UUID, dict] = {}
        # Metadata shared by the whole collection.
        self._global_storage: dict[str | UUID, Any] = {}

    async def store(self, key: str | UUID, metadata: dict) -> None:
        """
        Stores metadata under the given key, replacing any metadata previously stored under it.

        Args:
            key: Unique key of the entry.
            metadata: Dict with the metadata to associate with the key.
        """
        self._storage[key] = metadata

    async def query(self, metadata_field_name: str, value: Any) -> dict:  # noqa
        """
        Queries the store for entries whose metadata field equals the given value.

        Only entries that actually contain the field are considered, so querying for `None`
        does not match entries in which the field is missing.

        Args:
            metadata_field_name: Name of the metadata field to compare.
            value: Value to match against.

        Returns:
            Dict in `key: metadata` format containing the entries that match the query.
        """
        return {
            key: metadata
            for key, metadata in self._storage.items()
            # Explicit membership test: `.get(...) == value` would treat a missing field as None.
            if metadata_field_name in metadata and metadata[metadata_field_name] == value
        }

    async def get(self, key: str | UUID) -> dict:
        """
        Returns the metadata associated with the given key.

        Args:
            key: Key of the entry to look up.

        Returns:
            Metadata dict associated with the given key, or an empty dict if the key is unknown.
        """
        return self._storage.get(key, {})

    async def get_all(self) -> dict:
        """
        Returns all keys together with their associated metadata.

        Returns:
            Dict in `key: metadata` format for all entries. This is the store's internal mapping,
            not a copy, so mutating it mutates the store.
        """
        return self._storage

    async def store_global(self, metadata: dict) -> None:
        """
        Stores key-value pairs of metadata that are shared across entries.

        The given pairs are merged into the existing global metadata; existing keys are overwritten.

        Args:
            metadata: Common key-value pairs for the whole collection.
        """
        self._global_storage.update(metadata)

    async def get_global(self) -> dict:
        """
        Returns the key-value pairs of metadata that are shared across entries.

        Returns:
            Metadata for the whole collection. This is the store's internal mapping, not a copy.
        """
        return self._global_storage
7 changes: 6 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel

from ragbits.core.metadata_store.base import MetadataStore

WhereQuery = dict[str, str | int | float | bool]


Expand Down Expand Up @@ -29,9 +31,12 @@ class VectorStore(ABC):
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
metadata_store: MetadataStore | None

def __init__(
    self, default_options: VectorStoreOptions | None = None, metadata_store: MetadataStore | None = None
) -> None:
    """
    Initializes the vector store.

    Args:
        default_options: Options used when none are given explicitly; a default-constructed
            VectorStoreOptions is used when this is None.
        metadata_store: Optional metadata store kept on the instance as `metadata_store`.
    """
    super().__init__()
    self._default_options = default_options or VectorStoreOptions()
    self.metadata_store = metadata_store

@abstractmethod
async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
108 changes: 72 additions & 36 deletions packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,33 @@
from chromadb import Collection
from chromadb.api import ClientAPI

from ragbits.core.metadata_store.base import MetadataStore
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

# Keys used to read fields out of the result dicts returned by the Chroma collection calls below.
CHROMA_IDS_KEY = "ids"
CHROMA_DOCUMENTS_KEY = "documents"
CHROMA_DISTANCES_KEY = "distances"
CHROMA_METADATA_KEY = "metadatas"
CHROMA_EMBEDDINGS_KEY = "embeddings"
# Fields requested via `include=` when listing entries.
# NOTE(review): "ids" is read from results but never requested here - presumably Chroma always returns ids; confirm.
CHROMA_LIST_INCLUDE_KEYS = [CHROMA_DOCUMENTS_KEY, CHROMA_METADATA_KEY, CHROMA_EMBEDDINGS_KEY]
# Similarity queries request the same fields plus the distances.
CHROMA_QUERY_INCLUDE_KEYS = CHROMA_LIST_INCLUDE_KEYS + [CHROMA_DISTANCES_KEY]


class ChromaVectorStore(VectorStore):
"""
Class that stores text embeddings using [Chroma](https://docs.trychroma.com/).
"""

METADATA_INNER_KEY = "__metadata"

def __init__(
self,
client: ClientAPI,
index_name: str,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
default_options: VectorStoreOptions | None = None,
metadata_store: MetadataStore | None = None,
):
"""
Initializes the ChromaVectorStore with the given parameters.
Expand All @@ -32,24 +44,36 @@ def __init__(
index_name: The name of the index.
distance_method: The distance method to use.
default_options: The default options for querying the vector store.
metadata_store: The metadata store to use.
"""
super().__init__(default_options)
super().__init__(default_options, metadata_store)
self._client = client
self._index_name = index_name
self._distance_method = distance_method
self._collection = self._get_chroma_collection()
self._collection: Collection | None = None
micpst marked this conversation as resolved.
Show resolved Hide resolved

def _get_chroma_collection(self) -> Collection:
async def _get_chroma_collection(self) -> Collection:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Gets or creates a collection with the given name and metadata.

Returns:
The collection.
"""
return self._client.get_or_create_collection(
if self._collection is not None:
return self._collection

global_metadata = {"hnsw:space": self._distance_method}
if self.metadata_store is not None:
await self.metadata_store.store_global(global_metadata)
metadata_to_store = None
micpst marked this conversation as resolved.
Show resolved Hide resolved
else:
metadata_to_store = global_metadata

self._collection = self._client.get_or_create_collection(
name=self._index_name,
metadata={"hnsw:space": self._distance_method},
metadata=metadata_to_store,
)
return self._collection

@classmethod
def from_config(cls, config: dict) -> ChromaVectorStore:
Expand Down Expand Up @@ -80,14 +104,20 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
# TODO: Think about better id components for hashing
ids = [sha256(entry.key.encode("utf-8")).hexdigest() for entry in entries]
embeddings = [entry.vector for entry in entries]
metadatas = [
{
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}
for entry in entries
]
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) # type: ignore

if self.metadata_store is not None:
for key, meta in zip(ids, [entry.metadata for entry in entries], strict=False):
await self.metadata_store.store(key, meta)
metadata_to_store = None
else:
metadata_to_store = [
{self.METADATA_INNER_KEY: json.dumps(entry.metadata, default=str)} for entry in entries
micpst marked this conversation as resolved.
Show resolved Hide resolved
]

contents = [entry.key for entry in entries]

collection = await self._get_chroma_collection()
collection.add(ids=ids, embeddings=embeddings, metadatas=metadata_to_store, documents=contents) # type: ignore

async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Expand All @@ -101,26 +131,33 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
The retrieved entries.
"""
options = self._default_options if options is None else options
results = self._collection.query(
query_embeddings=vector,
n_results=options.k,
include=["metadatas", "embeddings", "distances"],
)
metadatas = results.get("metadatas") or []
embeddings = results.get("embeddings") or []
distances = results.get("distances") or []
collection = await self._get_chroma_collection()
results = collection.query(query_embeddings=vector, n_results=options.k, include=CHROMA_QUERY_INCLUDE_KEYS) # type: ignore
metadatas = results.get(CHROMA_METADATA_KEY) or []
embeddings = results.get(CHROMA_EMBEDDINGS_KEY) or []
distances = results.get(CHROMA_DISTANCES_KEY) or []
ids = results.get(CHROMA_IDS_KEY) or []
documents = results.get(CHROMA_DOCUMENTS_KEY) or []

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embeddings),
metadata=json.loads(str(metadata["__metadata"])),
metadata=await self._load_sample_metadata(metadata, sample_id),
)
for batch in zip(metadatas, embeddings, distances, strict=False)
for metadata, embeddings, distance in zip(*batch, strict=False)
for batch in zip(metadatas, embeddings, distances, ids, documents, strict=False) # type: ignore
for metadata, embeddings, distance, sample_id, document in zip(*batch, strict=False)
if options.max_distance is None or distance <= options.max_distance
]

async def _load_sample_metadata(self, metadata: dict, sample_id: str) -> dict:
if self.metadata_store is not None:
metadata = await self.metadata_store.get(sample_id)
else:
metadata = json.loads(metadata[self.METADATA_INNER_KEY])

return metadata

async def list(
self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorStoreEntry]:
Expand All @@ -139,20 +176,19 @@ async def list(
# Cast `where` to chromadb's Where type
where_chroma: chromadb.Where | None = dict(where) if where else None

get_results = self._collection.get(
where=where_chroma,
limit=limit,
offset=offset,
include=["metadatas", "embeddings"],
)
metadatas = get_results.get("metadatas") or []
embeddings = get_results.get("embeddings") or []
collection = await self._get_chroma_collection()
get_results = collection.get(where=where_chroma, limit=limit, offset=offset, include=CHROMA_LIST_INCLUDE_KEYS) # type: ignore

metadatas = get_results.get(CHROMA_METADATA_KEY) or []
embeddings = get_results.get(CHROMA_EMBEDDINGS_KEY) or []
documents = get_results.get(CHROMA_DOCUMENTS_KEY) or []
ids = get_results.get(CHROMA_IDS_KEY) or []

return [
VectorStoreEntry(
key=str(metadata["__key"]),
key=document,
vector=list(embedding),
metadata=json.loads(str(metadata["__metadata"])),
metadata=await self._load_sample_metadata(metadata, sample_id),
)
for metadata, embedding in zip(metadatas, embeddings, strict=False)
for metadata, embedding, sample_id, document in zip(metadatas, embeddings, ids, documents, strict=False) # type: ignore
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from ragbits.core.metadata_store.base import MetadataStore
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery


Expand All @@ -10,8 +11,10 @@ class InMemoryVectorStore(VectorStore):
A simple in-memory implementation of Vector Store, storing vectors in memory.
"""

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
super().__init__(default_options)
def __init__(
    self, default_options: VectorStoreOptions | None = None, metadata_store: MetadataStore | None = None
) -> None:
    """
    Initializes the in-memory vector store.

    Args:
        default_options: Default options forwarded to the base VectorStore.
        metadata_store: Optional metadata store forwarded to the base VectorStore.
    """
    super().__init__(default_options, metadata_store)
    # Stored entries, held in a plain dict for the lifetime of the instance.
    self._storage: dict[str, VectorStoreEntry] = {}

async def store(self, entries: list[VectorStoreEntry]) -> None:
Expand Down
Loading
Loading