Skip to content

Commit

Permalink
renames + small fixes from main
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Oct 28, 2024
1 parent 9f767fc commit 1c8407d
Show file tree
Hide file tree
Showing 23 changed files with 90 additions and 98 deletions.
11 changes: 11 additions & 0 deletions docs/api_reference/core/vector-stores.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Vector Stores

::: ragbits.core.vector_stores.VectorStoreEntry

::: ragbits.core.vector_stores.VectorStoreOptions

::: ragbits.core.vector_stores.VectorStore

::: ragbits.core.vector_stores.InMemoryVectorStore

::: ragbits.core.vector_stores.chroma.ChromaVectorStore
9 changes: 0 additions & 9 deletions docs/api_reference/core/vector_store.md

This file was deleted.

6 changes: 3 additions & 3 deletions examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# dependencies = [
# "gradio",
# "ragbits-document-search",
# "ragbits-core[chromadb, litellm]",
# "ragbits-core[chroma,litellm]",
# ]
# ///
from collections.abc import AsyncIterator
Expand All @@ -16,7 +16,7 @@
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import Prompt
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.core.vector_stores.chroma import ChromaVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(

def _prepare_document_search(self, database_path: str, index_name: str) -> None:
chroma_client = chromadb.PersistentClient(path=database_path)
vector_store = ChromaDBStore(
vector_store = ChromaVectorStore(
client=chroma_client,
index_name=index_name,
)
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import asyncio

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store import InMemoryVectorStore
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

Expand Down
6 changes: 3 additions & 3 deletions examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits-core[litellm]",
# "ragbits-core[chroma,litellm]",
# ]
# ///
import asyncio

from chromadb import PersistentClient

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.core.vector_stores.chroma import ChromaVectorStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta

Expand All @@ -27,7 +27,7 @@ async def main() -> None:
"""
Run the example.
"""
vector_store = ChromaDBStore(
vector_store = ChromaVectorStore(
client=PersistentClient("./chroma"),
index_name="jokes",
)
Expand Down
4 changes: 2 additions & 2 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits[litellm]",
# "ragbits-core[chroma,litellm]",
# ]
# ///
import asyncio
Expand All @@ -23,7 +23,7 @@
config = {
"embedder": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
"vector_store": {
"type": "ragbits.core.vector_store.chromadb_store:ChromaDBStore",
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
"client": {
"type": "PersistentClient",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
type: LiteLLMEmbeddings
type: ragbits.core.embeddings.litellm:LiteLLMEmbeddings
config:
model: "text-embedding-3-small"
options:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
txt:
type: UnstructuredDefaultProvider
type: ragbits.document_search.ingestion.providers.unstructured:UnstructuredDefaultProvider
config:
use_api: false
partition_kwargs:
Expand All @@ -12,7 +12,7 @@ txt:
overlap_all: 0

md:
type: UnstructuredDefaultProvider
type: ragbits.document_search.ingestion.providers.unstructured:UnstructuredDefaultProvider
config:
use_api: false
partition_kwargs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
type: ChromaDBStore
type: ragbits.core.vector_stores.chroma:ChromaVectorStore
config:
chroma_client:
client:
type: PersistentClient
config:
path: chroma
embedding_function:
type: ragbits.core.embeddings.litellm:LiteLLMEmbeddings
index_name: default
distance_method: l2
default_options:
k: 3
max_distance: 1.2
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ nav:
- api_reference/core/prompt.md
- api_reference/core/llms.md
- api_reference/core/embeddings.md
- api_reference/core/vector_store.md
- api_reference/core/vector-stores.md
- Document Search:
- api_reference/document_search/index.md
- api_reference/document_search/documents.md
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies = [
]

[project.optional-dependencies]
chromadb = [
chroma = [
"chromadb~=0.4.24",
]
litellm = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sys

from ..utils.config_handling import get_cls_from_config
from .base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery
from .base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery
from .in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "WhereQuery"]
__all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "WhereQuery"]

module = sys.modules[__name__]

Expand All @@ -23,7 +23,7 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
vector_store_cls = get_cls_from_config(vector_store_config["type"], module)
config = vector_store_config.get("config", {})

if vector_store_config["type"].endswith("ChromaDBStore"):
if vector_store_config["type"].endswith("ChromaVectorStore"):
return vector_store_cls.from_config(config)

return vector_store_cls(default_options=VectorStoreOptions(**config.get("default_options", {})))
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
WhereQuery = dict[str, str | int | float | bool]


class VectorDBEntry(BaseModel):
class VectorStoreEntry(BaseModel):
"""
An object representing a vector database entry.
"""
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
self._default_options = default_options or VectorStoreOptions()

@abstractmethod
async def store(self, entries: list[VectorDBEntry]) -> None:
async def store(self, entries: list[VectorStoreEntry]) -> None:
"""
Store entries in the vector store.
Expand All @@ -43,7 +43,7 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
"""

@abstractmethod
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]:
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Retrieve entries from the vector store.
Expand All @@ -58,7 +58,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
@abstractmethod
async def list(
self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorDBEntry]:
) -> list[VectorStoreEntry]:
"""
List entries from the vector store. The entries can be filtered, limited and offset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,15 @@
from hashlib import sha256
from typing import Literal

try:
import chromadb
from chromadb import Collection
from chromadb.api import ClientAPI
except ImportError:
HAS_CHROMADB = False
else:
HAS_CHROMADB = True
import chromadb
from chromadb import Collection
from chromadb.api import ClientAPI

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery


class ChromaDBStore(VectorStore):
class ChromaVectorStore(VectorStore):
"""
Class that stores text embeddings using [Chroma](https://docs.trychroma.com/).
"""
Expand All @@ -30,17 +25,14 @@ def __init__(
default_options: VectorStoreOptions | None = None,
):
"""
Initializes the ChromaDBStore with the given parameters.
Initializes the ChromaVectorStore with the given parameters.
Args:
client: The ChromaDB client.
index_name: The name of the index.
distance_method: The distance method to use.
default_options: The default options for querying the vector store.
"""
if not HAS_CHROMADB:
raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models")

super().__init__(default_options)
self._client = client
self._index_name = index_name
Expand All @@ -60,15 +52,15 @@ def _get_chroma_collection(self) -> Collection:
)

@classmethod
def from_config(cls, config: dict) -> ChromaDBStore:
def from_config(cls, config: dict) -> ChromaVectorStore:
"""
Creates and returns an instance of the ChromaDBStore class from the given configuration.
Creates and returns an instance of the ChromaVectorStore class from the given configuration.
Args:
config: A dictionary containing the configuration for initializing the ChromaDBStore instance.
config: A dictionary containing the configuration for initializing the ChromaVectorStore instance.
Returns:
An initialized instance of the ChromaDBStore class.
An initialized instance of the ChromaVectorStore class.
"""
client = get_cls_from_config(config["client"]["type"], chromadb) # type: ignore
return cls(
Expand All @@ -78,7 +70,7 @@ def from_config(cls, config: dict) -> ChromaDBStore:
default_options=VectorStoreOptions(**config.get("default_options", {})),
)

async def store(self, entries: list[VectorDBEntry]) -> None:
async def store(self, entries: list[VectorStoreEntry]) -> None:
"""
Stores entries in the ChromaDB collection.
Expand All @@ -97,7 +89,7 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
]
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) # type: ignore

async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]:
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Retrieves entries from the ChromaDB collection.
Expand All @@ -119,8 +111,10 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None
distances = results.get("distances") or []

return [
VectorDBEntry(
key=str(metadata["__key"]), vector=list(embeddings), metadata=json.loads(str(metadata["__metadata"]))
VectorStoreEntry(
key=str(metadata["__key"]),
vector=list(embeddings),
metadata=json.loads(str(metadata["__metadata"])),
)
for batch in zip(metadatas, embeddings, distances, strict=False)
for metadata, embeddings, distance in zip(*batch, strict=False)
Expand All @@ -129,7 +123,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

async def list(
self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorDBEntry]:
) -> list[VectorStoreEntry]:
"""
List entries from the vector store. The entries can be filtered, limited and offset.
Expand All @@ -155,7 +149,7 @@ async def list(
embeddings = get_results.get("embeddings") or []

return [
VectorDBEntry(
VectorStoreEntry(
key=str(metadata["__key"]),
vector=list(embedding),
metadata=json.loads(str(metadata["__metadata"])),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, VectorStoreOptions, WhereQuery
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery


class InMemoryVectorStore(VectorStore):
Expand All @@ -12,9 +12,9 @@ class InMemoryVectorStore(VectorStore):

def __init__(self, default_options: VectorStoreOptions | None = None) -> None:
super().__init__(default_options)
self._storage: dict[str, VectorDBEntry] = {}
self._storage: dict[str, VectorStoreEntry] = {}

async def store(self, entries: list[VectorDBEntry]) -> None:
async def store(self, entries: list[VectorStoreEntry]) -> None:
"""
Store entries in the vector store.
Expand All @@ -24,7 +24,7 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
for entry in entries:
self._storage[entry.key] = entry

async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorDBEntry]:
async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]:
"""
Retrieve entries from the vector store.
Expand All @@ -51,7 +51,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None

async def list(
self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorDBEntry]:
) -> list[VectorStoreEntry]:
"""
List entries from the vector store. The entries can be filtered, limited and offset.
Expand Down
Loading

0 comments on commit 1c8407d

Please sign in to comment.