Skip to content

Commit

Permalink
refactor: generalize creating objects from config
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Dec 4, 2024
1 parent 9ea3ace commit 9d831b3
Show file tree
Hide file tree
Showing 29 changed files with 257 additions and 413 deletions.
4 changes: 0 additions & 4 deletions docs/how-to/document_search/use_reranker.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,4 @@ class CustomReranker(Reranker):
options: RerankerOptions | None = None,
) -> Sequence[Element]:
pass

@classmethod
def from_config(cls, config: dict) -> "CustomReranker":
pass
```
83 changes: 43 additions & 40 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class to rephrase the query.
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search._main import DocumentSearchConfig
from ragbits.document_search.documents.document import DocumentMeta

documents = [
Expand All @@ -56,54 +57,56 @@ class to rephrase the query.
),
]

config = {
"embedder": {
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
},
"vector_store": {
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
"client": {
"type": "PersistentClient",
"config": {
"path": "chroma",
config = DocumentSearchConfig.model_validate(
{
"embedder": {
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
},
"vector_store": {
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
"client": {
"type": "PersistentClient",
"config": {
"path": "chroma",
},
},
"index_name": "jokes",
"distance_method": "l2",
"default_options": {
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
},
},
"index_name": "jokes",
"distance_method": "l2",
"default_options": {
"k": 3,
"max_distance": 1.2,
},
"metadata_store": {
"type": "InMemoryMetadataStore",
},
},
},
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {
"model": "cohere/rerank-english-v3.0",
"default_options": {
"top_n": 3,
"max_chunks_per_doc": None,
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {
"model": "cohere/rerank-english-v3.0",
"default_options": {
"top_n": 3,
"max_chunks_per_doc": None,
},
},
},
},
"providers": {"txt": {"type": "DummyProvider"}},
"rephraser": {
"type": "LLMQueryRephraser",
"config": {
"llm": {
"type": "ragbits.core.llms.litellm:LiteLLM",
"config": {
"model_name": "gpt-4-turbo",
"providers": {"txt": {"type": "DummyProvider"}},
"rephraser": {
"type": "LLMQueryRephraser",
"config": {
"llm": {
"type": "ragbits.core.llms.litellm:LiteLLM",
"config": {
"model_name": "gpt-4-turbo",
},
},
"prompt": "QueryRephraserPrompt",
},
"prompt": "QueryRephraserPrompt",
},
},
}
}
)


async def main() -> None:
Expand Down
24 changes: 0 additions & 24 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings, EmbeddingType
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "NoopEmbeddings"]

module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> Embeddings:
    """
    Build an Embeddings instance described by a configuration dictionary.

    Args:
        embedder_config: Mapping with a mandatory "type" entry naming the class to
            instantiate and an optional "config" entry holding its keyword arguments.

    Returns:
        An instance of the requested class, constructed with the "config" entries as
        keyword arguments (or with no arguments when "config" is absent).

    Raises:
        KeyError: If the configuration dictionary does not contain a "type" key.
    """
    type_path = embedder_config["type"]
    init_kwargs = embedder_config.get("config", {})

    # The class is resolved against this package's module object.
    return get_cls_from_config(type_path, module)(**init_kwargs)
8 changes: 7 additions & 1 deletion packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar

from ragbits.core import embeddings
from ragbits.core.utils.config_handling import WithConstructionConfig


class EmbeddingType(Enum):
Expand All @@ -17,11 +21,13 @@ class EmbeddingType(Enum):
IMAGE: str = "image"


class Embeddings(ABC):
class Embeddings(WithConstructionConfig, ABC):
"""
Abstract client for communication with embedding models.
"""

default_module: ClassVar = embeddings

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
"""
Expand Down
36 changes: 0 additions & 36 deletions packages/ragbits-core/src/ragbits/core/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,3 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import LLM

__all__ = ["LLM"]

module = sys.modules[__name__]


def get_llm(config: dict) -> LLM:
    """
    Initializes and returns an LLM object based on the provided configuration.

    The passed dictionary (including its nested "config" mapping) is not modified,
    so the same configuration can safely be reused to build several instances.

    Args:
        config: A dictionary containing configuration details for the LLM. It must
            provide a "type" key and may provide a "config" key with keyword
            arguments for the class, optionally including "default_options".

    Returns:
        An instance of the specified LLM class, initialized with the provided config
        (if any) or default arguments.

    Raises:
        KeyError: If the configuration dictionary does not contain a "type" key.
        ValueError: If the LLM class is not a subclass of LLM.
    """
    llm_type = config["type"]
    # Shallow-copy before popping: config.get() hands back the caller's own nested
    # dict, and removing "default_options" from it would corrupt it for later reuse.
    llm_config = dict(config.get("config", {}))
    default_options = llm_config.pop("default_options", None)
    llm_cls = get_cls_from_config(llm_type, module)

    if not issubclass(llm_cls, LLM):
        raise ValueError(f"Invalid LLM class: {llm_cls}")

    # We need to infer the options class from the LLM class.
    # pylint: disable=protected-access
    options = llm_cls._options_cls(**default_options) if default_options else None  # type: ignore

    return llm_cls(**llm_config, default_options=options)
26 changes: 24 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Generic, cast, overload
from typing import ClassVar, Generic, cast, overload

from typing_extensions import Self

from ragbits.core import llms
from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT
from ragbits.core.utils.config_handling import WithConstructionConfig

from .clients.base import LLMClient, LLMClientOptions, LLMOptions

Expand All @@ -20,12 +24,13 @@ class LLMType(enum.Enum):
STRUCTURED_OUTPUT = "structured_output"


class LLM(Generic[LLMClientOptions], ABC):
class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC):
"""
Abstract class for interaction with Large Language Model.
"""

_options_cls: type[LLMClientOptions]
default_module: ClassVar = llms

def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None:
"""
Expand Down Expand Up @@ -160,3 +165,20 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
if prompt.list_images():
wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
return prompt.chat

@classmethod
def from_config(cls, config: dict) -> Self:
    """
    Initializes the class with the provided configuration.

    The "default_options" entry, when present and non-empty, is converted into an
    instance of the class' `_options_cls`; every other entry is forwarded to the
    constructor as a keyword argument. The passed dictionary is not modified.

    Args:
        config: A dictionary containing configuration details for the class.

    Returns:
        An instance of the class initialized with the provided configuration.
    """
    # Pop from a shallow copy: the dictionary may be shared with the caller (for
    # example it can be the `config` field of a construction-config model), and
    # removing "default_options" in place would lose it on the next construction.
    init_kwargs = dict(config)
    default_options = init_kwargs.pop("default_options", None)

    options = cls._options_cls(**default_options) if default_options else None

    return cls(**init_kwargs, default_options=options)
26 changes: 0 additions & 26 deletions packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,4 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import MetadataStore
from .in_memory import InMemoryMetadataStore

__all__ = ["InMemoryMetadataStore", "MetadataStore"]

module = sys.modules[__name__]


def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None:
"""
Initializes and returns a MetadataStore object based on the provided configuration.
Args:
metadata_store_config: A dictionary containing configuration details for the MetadataStore.
Returns:
An instance of the specified MetadataStore class, initialized with the provided config
(if any) or default arguments.
"""
if metadata_store_config is None:
return None

metadata_store_class = get_cls_from_config(metadata_store_config["type"], module)
config = metadata_store_config.get("config", {})

return metadata_store_class(**config)
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from ragbits.core import metadata_stores
from ragbits.core.utils.config_handling import WithConstructionConfig

class MetadataStore(ABC):

class MetadataStore(WithConstructionConfig, ABC):
"""
An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs.
"""

default_module: ClassVar = metadata_stores

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
"""
Expand Down
58 changes: 57 additions & 1 deletion packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import abc
from importlib import import_module
from types import ModuleType
from typing import Any
from typing import Any, ClassVar

from pydantic import BaseModel
from typing_extensions import Self


class InvalidConfigError(Exception):
Expand Down Expand Up @@ -36,3 +40,55 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no
return getattr(default_module, cls_path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err


class ObjectContructionConfig(BaseModel):
    """
    A model for object construction configuration.

    Pairs a reference to a class (`type`) with the configuration dictionary
    (`config`) that is handed to that class' `from_config` method by
    `WithConstructionConfig.subclass_from_config`.
    """

    # NOTE(review): the class name is spelled "Contruction" (missing "s"). It is a
    # public name, so correcting it would need a backward-compatible alias.

    # Path to the class to be constructed. It is resolved with `get_cls_from_config`
    # against the default module of the class being constructed.
    type: str

    # Configuration details for the class; an empty dictionary when omitted.
    config: dict[str, Any] = {}


class WithConstructionConfig(abc.ABC):
    """
    A mixin class that provides methods for initializing classes from configuration.

    Classes using the mixin are expected to set `default_module` and may override
    `from_config` when their constructor needs more than plain keyword arguments.
    """

    # The default module to search for the subclass if no specific module is provided in the type string.
    default_module: ClassVar[ModuleType]

    @classmethod
    def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
        """
        Initializes the class with the provided configuration. May return a subclass of the class,
        if requested by the configuration.

        Args:
            config: A model containing configuration details for the class.

        Returns:
            An instance of the class initialized with the provided configuration.

        Raises:
            InvalidConfigError: If the configured type does not resolve to a class
                that is `cls` itself or one of its subclasses.
        """
        subclass = get_cls_from_config(config.type, cls.default_module)

        # issubclass() raises a bare TypeError when its first argument is not a class
        # (e.g. the path points at a function or a constant). Check for that first so
        # every misconfiguration surfaces as InvalidConfigError.
        if not (isinstance(subclass, type) and issubclass(subclass, cls)):
            raise InvalidConfigError(f"{subclass} is not a subclass of {cls}")

        return subclass.from_config(config.config)

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Initializes the class with the provided configuration.

        The dictionary entries are passed to the constructor as keyword arguments.

        Args:
            config: A dictionary containing configuration details for the class.

        Returns:
            An instance of the class initialized with the provided configuration.
        """
        return cls(**config)
Loading

0 comments on commit 9d831b3

Please sign in to comment.