diff --git a/docs/how-to/document_search/use_reranker.md b/docs/how-to/document_search/use_reranker.md index b494ed66..bc65c059 100644 --- a/docs/how-to/document_search/use_reranker.md +++ b/docs/how-to/document_search/use_reranker.md @@ -83,8 +83,4 @@ class CustomReranker(Reranker): options: RerankerOptions | None = None, ) -> Sequence[Element]: pass - - @classmethod - def from_config(cls, config: dict) -> "CustomReranker": - pass ``` \ No newline at end of file diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py index 0af979f8..825b0062 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py @@ -1,29 +1,5 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config - from .base import Embeddings, EmbeddingType from .litellm import LiteLLMEmbeddings from .noop import NoopEmbeddings __all__ = ["EmbeddingType", "Embeddings", "LiteLLMEmbeddings", "NoopEmbeddings"] - -module = sys.modules[__name__] - - -def get_embeddings(embedder_config: dict) -> Embeddings: - """ - Initializes and returns an Embeddings object based on the provided embedder configuration. - - Args: - embedder_config : A dictionary containing configuration details for the embedder. - - Returns: - An instance of the specified Embeddings class, initialized with the provided config - (if any) or default arguments. - """ - embeddings_type = embedder_config["type"] - config = embedder_config.get("config", {}) - - embbedings = get_cls_from_config(embeddings_type, module) - return embbedings(**config) diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/base.py b/packages/ragbits-core/src/ragbits/core/embeddings/base.py index e03087b6..460476c9 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/base.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/base.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod from enum import Enum +from typing import ClassVar + +from ragbits.core import embeddings +from ragbits.core.utils.config_handling import WithConstructionConfig class EmbeddingType(Enum): @@ -17,11 +21,13 @@ class EmbeddingType(Enum): IMAGE: str = "image" -class Embeddings(ABC): +class Embeddings(WithConstructionConfig, ABC): """ Abstract client for communication with embedding models. """ + default_module: ClassVar = embeddings + @abstractmethod async def embed_text(self, data: list[str]) -> list[list[float]]: """ diff --git a/packages/ragbits-core/src/ragbits/core/llms/__init__.py b/packages/ragbits-core/src/ragbits/core/llms/__init__.py index bf2ed301..111892eb 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/llms/__init__.py @@ -1,40 +1,4 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config - from .base import LLM from .litellm import LiteLLM __all__ = ["LLM", "LiteLLM"] - -module = sys.modules[__name__] - - -def get_llm(config: dict) -> LLM: - """ - Initializes and returns an LLM object based on the provided configuration. - - Args: - config : A dictionary containing configuration details for the LLM. - - Returns: - An instance of the specified LLM class, initialized with the provided config - (if any) or default arguments. - - Raises: - KeyError: If the configuration dictionary does not contain a "type" key. - ValueError: If the LLM class is not a subclass of LLM. - """ - llm_type = config["type"] - llm_config = config.get("config", {}) - default_options = llm_config.pop("default_options", None) - llm_cls = get_cls_from_config(llm_type, module) - - if not issubclass(llm_cls, LLM): - raise ValueError(f"Invalid LLM class: {llm_cls}") - - # We need to infer the options class from the LLM class. - # pylint: disable=protected-access - options = llm_cls._options_cls(**default_options) if default_options else None # type: ignore - - return llm_cls(**llm_config, default_options=options) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 3e68c2de..8ff01821 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -3,9 +3,13 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from functools import cached_property -from typing import Generic, cast, overload +from typing import ClassVar, Generic, cast, overload +from typing_extensions import Self + +from ragbits.core import llms from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT +from ragbits.core.utils.config_handling import WithConstructionConfig from .clients.base import LLMClient, LLMClientOptions, LLMOptions @@ -20,12 +24,13 @@ class LLMType(enum.Enum): STRUCTURED_OUTPUT = "structured_output" -class LLM(Generic[LLMClientOptions], ABC): +class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC): """ Abstract class for interaction with Large Language Model. """ _options_cls: type[LLMClientOptions] + default_module: ClassVar = llms def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None: """ @@ -160,3 +165,20 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat: if prompt.list_images(): wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}") return prompt.chat + + @classmethod + def from_config(cls, config: dict) -> Self: + """ + Initializes the class with the provided configuration. + + Args: + config: A dictionary containing configuration details for the class. + + Returns: + An instance of the class initialized with the provided configuration. + """ + default_options = config.pop("default_options", None) + + options = cls._options_cls(**default_options) if default_options else None + + return cls(**config, default_options=options) diff --git a/packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py b/packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py index 722a4570..ea330009 100644 --- a/packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/metadata_stores/__init__.py @@ -1,30 +1,4 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config - from .base import MetadataStore from .in_memory import InMemoryMetadataStore __all__ = ["InMemoryMetadataStore", "MetadataStore"] - -module = sys.modules[__name__] - - -def get_metadata_store(metadata_store_config: dict | None) -> MetadataStore | None: - """ - Initializes and returns a MetadataStore object based on the provided configuration. - - Args: - metadata_store_config: A dictionary containing configuration details for the MetadataStore. - - Returns: - An instance of the specified MetadataStore class, initialized with the provided config - (if any) or default arguments. - """ - if metadata_store_config is None: - return None - - metadata_store_class = get_cls_from_config(metadata_store_config["type"], module) - config = metadata_store_config.get("config", {}) - - return metadata_store_class(**config) diff --git a/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py b/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py index 5a6ad81d..f5a5767b 100644 --- a/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/metadata_stores/base.py @@ -1,11 +1,17 @@ from abc import ABC, abstractmethod +from typing import ClassVar +from ragbits.core import metadata_stores +from ragbits.core.utils.config_handling import WithConstructionConfig -class MetadataStore(ABC): + +class MetadataStore(WithConstructionConfig, ABC): """ An abstract class for metadata storage. Allows to store, query and retrieve metadata in form of key value pairs. """ + default_module: ClassVar = metadata_stores + @abstractmethod async def store(self, ids: list[str], metadatas: list[dict]) -> None: """ diff --git a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py index c8b69df2..398b1f6b 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py +++ b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py @@ -1,6 +1,10 @@ +import abc from importlib import import_module from types import ModuleType -from typing import Any +from typing import Any, ClassVar + +from pydantic import BaseModel +from typing_extensions import Self class InvalidConfigError(Exception): @@ -9,7 +13,7 @@ class InvalidConfigError(Exception): """ -def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # noqa: ANN401 +def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401 """ Retrieves and returns a class based on the given type string. The class can be either in the default module or a specified module if provided in the type string. @@ -23,6 +27,9 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no Returns: Any: The object retrieved from the specified or default module. + + Raises: + InvalidConfigError: The requested class is not found under the specified module """ if ":" in cls_path: try: @@ -32,7 +39,65 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no except AttributeError as err: raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err + if default_module is None: + raise InvalidConfigError("Given type string does not contain a module and no default module provided") + try: return getattr(default_module, cls_path) except AttributeError as err: raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err + + +class ObjectContructionConfig(BaseModel): + """ + A model for object construction configuration. + """ + + # Path to the class to be constructed + type: str + + # Configuration details for the class + config: dict[str, Any] = {} + + +class WithConstructionConfig(abc.ABC): + """ + A mixin class that provides methods for initializing classes from configuration. + """ + + # The default module to search for the subclass if no specific module is provided in the type string. + default_module: ClassVar[ModuleType | None] = None + + @classmethod + def subclass_from_config(cls, config: ObjectContructionConfig) -> Self: + """ + Initializes the class with the provided configuration. May return a subclass of the class, + if requested by the configuration. + + Args: + config: A model containing configuration details for the class. + + Returns: + An instance of the class initialized with the provided configuration. + + Raises: + InvalidConfigError: The class can't be found or is not a subclass of the current class. + """ + subclass = get_cls_from_config(config.type, cls.default_module) + if not issubclass(subclass, cls): + raise InvalidConfigError(f"{subclass} is not a subclass of {cls}") + + return subclass.from_config(config.config) + + @classmethod + def from_config(cls, config: dict) -> Self: + """ + Initializes the class with the provided configuration. + + Args: + config: A dictionary containing configuration details for the class. + + Returns: + An instance of the class initialized with the provided configuration. + """ + return cls(**config) diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py b/packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py index 7fb02751..7a85def1 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/__init__.py @@ -1,27 +1,4 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery from ragbits.core.vector_stores.in_memory import InMemoryVectorStore __all__ = ["InMemoryVectorStore", "VectorStore", "VectorStoreEntry", "VectorStoreOptions", "WhereQuery"] - - -def get_vector_store(config: dict) -> VectorStore: - """ - Initializes and returns a VectorStore object based on the provided configuration. - - Args: - config: A dictionary containing configuration details for the VectorStore. - - Returns: - An instance of the specified VectorStore class, initialized with the provided config - (if any) or default arguments. - - Raises: - KeyError: If the provided configuration does not contain a valid "type" key. - InvalidConfigurationError: If the provided configuration is invalid. - NotImplementedError: If the specified VectorStore class cannot be created from the provided configuration. - """ - vector_store_cls = get_cls_from_config(config["type"], sys.modules[__name__]) - return vector_store_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py index 4512c659..1b74b52b 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod +from typing import ClassVar from pydantic import BaseModel +from typing_extensions import Self +from ragbits.core import vector_stores from ragbits.core.metadata_stores.base import MetadataStore +from ragbits.core.utils.config_handling import ObjectContructionConfig, WithConstructionConfig WhereQuery = dict[str, str | int | float | bool] @@ -27,11 +31,13 @@ class VectorStoreOptions(BaseModel, ABC): max_distance: float | None = None -class VectorStore(ABC): +class VectorStore(WithConstructionConfig, ABC): """ A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function. """ + default_module: ClassVar = vector_stores + def __init__( self, default_options: VectorStoreOptions | None = None, @@ -49,20 +55,31 @@ def __init__( self._metadata_store = metadata_store @classmethod - def from_config(cls, config: dict) -> "VectorStore": + def from_config(cls, config: dict) -> Self: """ - Creates and returns an instance of the Reranker class from the given configuration. + Initializes the class with the provided configuration. Args: - config: A dictionary containing the configuration for initializing the Reranker instance. + config: A dictionary containing configuration details for the class. Returns: - An initialized instance of the Reranker class. + An instance of the class initialized with the provided configuration. Raises: - NotImplementedError: If the class cannot be created from the provided configuration. + ValidationError: The metadata_store configuration doesn't follow the expected format. + InvalidConfigError: The metadata_store class can't be found or is not the correct type. """ - raise NotImplementedError(f"Cannot create class {cls.__name__} from config.") + default_options = config.pop("default_options", None) + options = VectorStoreOptions(**default_options) if default_options else None + + store_config = config.pop("metadata_store", None) + store = ( + MetadataStore.subclass_from_config(ObjectContructionConfig.model_validate(store_config)) + if store_config + else None + ) + + return cls(**config, default_options=options, metadata_store=store) @abstractmethod async def store(self, entries: list[VectorStoreEntry]) -> None: diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py index fb31d9a9..5a1a71b5 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py @@ -2,11 +2,11 @@ import chromadb from chromadb.api import ClientAPI +from typing_extensions import Self from ragbits.core.audit import traceable -from ragbits.core.metadata_stores import get_metadata_store from ragbits.core.metadata_stores.base import MetadataStore -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config from ragbits.core.utils.dict_transformations import flatten_dict, unflatten_dict from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery @@ -44,24 +44,24 @@ def __init__( ) @classmethod - def from_config(cls, config: dict) -> "ChromaVectorStore": + def from_config(cls, config: dict) -> Self: """ - Creates and returns an instance of the ChromaVectorStore class from the given configuration. + Initializes the class with the provided configuration. Args: - config: A dictionary containing the configuration for initializing the ChromaVectorStore instance. + config: A dictionary containing configuration details for the class. Returns: - An initialized instance of the ChromaVectorStore class. + An instance of the class initialized with the provided configuration. + + Raises: + ValidationError: The client or metadata_store configuration doesn't follow the expected format. + InvalidConfigError: The client or metadata_store class can't be found or is not the correct type. """ - client_cls = get_cls_from_config(config["client"]["type"], chromadb) - return cls( - client=client_cls(**config["client"].get("config", {})), - index_name=config["index_name"], - distance_method=config.get("distance_method", "cosine"), - default_options=VectorStoreOptions(**config.get("default_options", {})), - metadata_store=get_metadata_store(config.get("metadata_store")), - ) + client_options = ObjectContructionConfig.model_validate(config["client"]) + client_cls = get_cls_from_config(client_options.type, chromadb) + config["client"] = client_cls(**client_options.config) + return super().from_config(config) @traceable async def store(self, entries: list[VectorStoreEntry]) -> None: diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py index 3ae96c2a..63227442 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py @@ -3,8 +3,8 @@ import numpy as np from ragbits.core.audit import traceable -from ragbits.core.metadata_stores import get_metadata_store from ragbits.core.metadata_stores.base import MetadataStore +from ragbits.core.utils.config_handling import ObjectContructionConfig from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery @@ -38,10 +38,19 @@ def from_config(cls, config: dict) -> "InMemoryVectorStore": Returns: An initialized instance of the InMemoryVectorStore class. + + Raises: + ValidationError: The metadata_store configuration doesn't follow the expected format. + InvalidConfigError: The metadata_store class can't be found or is not the correct type. """ + store = ( + MetadataStore.subclass_from_config(ObjectContructionConfig.model_validate(config["metadata_store"])) + if "metadata_store" in config + else None + ) return cls( default_options=VectorStoreOptions(**config.get("default_options", {})), - metadata_store=get_metadata_store(config.get("metadata_store")), + metadata_store=store, ) @traceable diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py index 506dbcf6..72e1e584 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py @@ -4,11 +4,11 @@ import qdrant_client from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import Distance, Filter, VectorParams +from typing_extensions import Self from ragbits.core.audit import traceable -from ragbits.core.metadata_stores import get_metadata_store from ragbits.core.metadata_stores.base import MetadataStore -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions @@ -41,24 +41,24 @@ def __init__( self._distance_method = distance_method @classmethod - def from_config(cls, config: dict) -> "QdrantVectorStore": + def from_config(cls, config: dict) -> Self: """ - Creates and returns an instance of the QdrantVectorStore class from the given configuration. + Initializes the class with the provided configuration. Args: - config: A dictionary containing the configuration for initializing the QdrantVectorStore instance. + config: A dictionary containing configuration details for the class. Returns: - An initialized instance of the QdrantVectorStore class. + An instance of the class initialized with the provided configuration. + + Raises: + ValidationError: The client or metadata_store configuration doesn't follow the expected format. + InvalidConfigError: The client or metadata_store class can't be found or is not the correct type. """ - client_cls = get_cls_from_config(config["client"]["type"], qdrant_client) - return cls( - client=client_cls(**config["client"].get("config", {})), - index_name=config["index_name"], - distance_method=config.get("distance_method", Distance.COSINE), - default_options=VectorStoreOptions(**config.get("default_options", {})), - metadata_store=get_metadata_store(config.get("metadata_store")), - ) + client_options = ObjectContructionConfig.model_validate(config["client"]) + client_cls = get_cls_from_config(client_options.type, qdrant_client) + config["client"] = client_cls(**client_options.config) + return super().from_config(config) @traceable async def store(self, entries: list[VectorStoreEntry]) -> None: diff --git a/packages/ragbits-core/tests/unit/embeddings/test_from_config.py b/packages/ragbits-core/tests/unit/embeddings/test_from_config.py new file mode 100644 index 00000000..c0d56a21 --- /dev/null +++ b/packages/ragbits-core/tests/unit/embeddings/test_from_config.py @@ -0,0 +1,28 @@ +from ragbits.core.embeddings import Embeddings, NoopEmbeddings +from ragbits.core.embeddings.litellm import LiteLLMEmbeddings +from ragbits.core.utils.config_handling import ObjectContructionConfig + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings", + "config": { + "model": "some_model", + "options": { + "option1": "value1", + "option2": "value2", + }, + }, + } + ) + embedding = Embeddings.subclass_from_config(config) + assert isinstance(embedding, LiteLLMEmbeddings) + assert embedding.model == "some_model" + assert embedding.options == {"option1": "value1", "option2": "value2"} + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "NoopEmbeddings"}) + embedding = Embeddings.subclass_from_config(config) + assert isinstance(embedding, NoopEmbeddings) diff --git a/packages/ragbits-core/tests/unit/llms/test_from_config.py b/packages/ragbits-core/tests/unit/llms/test_from_config.py new file mode 100644 index 00000000..6901afb6 --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/test_from_config.py @@ -0,0 +1,33 @@ +from ragbits.core.llms import LLM +from ragbits.core.llms.clients.litellm import LiteLLMOptions +from ragbits.core.llms.litellm import LiteLLM +from ragbits.core.utils.config_handling import ObjectContructionConfig + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.llms.litellm:LiteLLM", + "config": { + "model_name": "some_model", + "use_structured_output": True, + "default_options": { + "frequency_penalty": 0.2, + "n": 42, + }, + }, + } + ) + llm: LLM = LLM.subclass_from_config(config) + assert isinstance(llm, LiteLLM) + assert llm.model_name == "some_model" + assert llm.use_structured_output is True + assert isinstance(llm.default_options, LiteLLMOptions) + assert llm.default_options.frequency_penalty == 0.2 + assert llm.default_options.n == 42 + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "LiteLLM"}) + llm: LLM = LLM.subclass_from_config(config) + assert isinstance(llm, LiteLLM) diff --git a/packages/ragbits-core/tests/unit/metadata_stores/test_from_config.py b/packages/ragbits-core/tests/unit/metadata_stores/test_from_config.py new file mode 100644 index 00000000..68838106 --- /dev/null +++ b/packages/ragbits-core/tests/unit/metadata_stores/test_from_config.py @@ -0,0 +1,19 @@ +from ragbits.core.metadata_stores.base import MetadataStore +from ragbits.core.metadata_stores.in_memory import InMemoryMetadataStore +from ragbits.core.utils.config_handling import ObjectContructionConfig + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.metadata_stores:InMemoryMetadataStore", + } + ) + store = MetadataStore.subclass_from_config(config) + assert isinstance(store, InMemoryMetadataStore) + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "InMemoryMetadataStore"}) + store = MetadataStore.subclass_from_config(config) + assert isinstance(store, InMemoryMetadataStore) diff --git a/packages/ragbits-core/tests/unit/utils/test_config_handling.py b/packages/ragbits-core/tests/unit/utils/test_config_handling.py new file mode 100644 index 00000000..8d8253aa --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/test_config_handling.py @@ -0,0 +1,64 @@ +import sys + +import pytest + +from ragbits.core.utils.config_handling import InvalidConfigError, ObjectContructionConfig, WithConstructionConfig + + +class ExampleClassWithConfigMixin(WithConstructionConfig): + default_module = sys.modules[__name__] + + def __init__(self, foo: str, bar: int) -> None: + self.foo = foo + self.bar = bar + + +class ExampleSubclass(ExampleClassWithConfigMixin): ... + + +class ExampleWithNoDefaultModule(WithConstructionConfig): + def __init__(self, foo: str, bar: int) -> None: + self.foo = foo + self.bar = bar + + +def test_defacult_from_config(): + config = {"foo": "foo", "bar": 1} + instance = ExampleClassWithConfigMixin.from_config(config) + assert instance.foo == "foo" + assert instance.bar == 1 + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ExampleSubclass", + "config": {"foo": "foo", "bar": 1}, + } + ) + instance = ExampleClassWithConfigMixin.subclass_from_config(config) + assert isinstance(instance, ExampleSubclass) + assert instance.foo == "foo" + assert instance.bar == 1 + + +def test_incorrect_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ExampleWithNoDefaultModule", # Not a subclass of ExampleClassWithConfigMixin + "config": {"foo": "foo", "bar": 1}, + } + ) + with pytest.raises(InvalidConfigError): + ExampleClassWithConfigMixin.subclass_from_config(config) + + +def test_no_default_module(): + config = ObjectContructionConfig.model_validate( + { + "type": "ExampleWithNoDefaultModule", + "config": {"foo": "foo", "bar": 1}, + } + ) + with pytest.raises(InvalidConfigError): + ExampleWithNoDefaultModule.subclass_from_config(config) diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py b/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py new file mode 100644 index 00000000..3206e3e4 --- /dev/null +++ b/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py @@ -0,0 +1,89 @@ +from chromadb import ClientAPI +from qdrant_client import AsyncQdrantClient +from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal + +from ragbits.core.metadata_stores.in_memory import InMemoryMetadataStore +from ragbits.core.utils.config_handling import ObjectContructionConfig +from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions +from ragbits.core.vector_stores.chroma import ChromaVectorStore +from ragbits.core.vector_stores.in_memory import InMemoryVectorStore +from ragbits.core.vector_stores.qdrant import QdrantVectorStore + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.vector_stores:InMemoryVectorStore", + "config": { + "metadata_store": { + "type": "ragbits.core.metadata_stores:InMemoryMetadataStore", + }, + "default_options": { + "k": 10, + "max_distance": 0.22, + }, + }, + } + ) + store = VectorStore.subclass_from_config(config) + assert isinstance(store, InMemoryVectorStore) + assert isinstance(store._default_options, VectorStoreOptions) + assert store._default_options.k == 10 + assert store._default_options.max_distance == 0.22 + assert isinstance(store._metadata_store, InMemoryMetadataStore) + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "InMemoryVectorStore"}) + store = VectorStore.subclass_from_config(config) + assert isinstance(store, InMemoryVectorStore) + + +def test_subclass_from_config_chroma_client(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.vector_stores.chroma:ChromaVectorStore", + "config": { + "client": {"type": "EphemeralClient"}, + "index_name": "some_index", + "default_options": { + "k": 10, + "max_distance": 0.22, + }, + }, + } + ) + store = VectorStore.subclass_from_config(config) + assert isinstance(store, ChromaVectorStore) + assert store._index_name == "some_index" + assert isinstance(store._client, ClientAPI) + assert store._default_options.k == 10 + assert store._default_options.max_distance == 0.22 + + +def test_subclass_from_config_drant_client(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.core.vector_stores.qdrant:QdrantVectorStore", + "config": { + "client": { + "type": "AsyncQdrantClient", + "config": { + "location": ":memory:", + }, + }, + "index_name": "some_index", + "default_options": { + "k": 10, + "max_distance": 0.22, + }, + }, + } + ) + store = VectorStore.subclass_from_config(config) + assert isinstance(store, QdrantVectorStore) + assert store._index_name == "some_index" + assert isinstance(store._client, AsyncQdrantClient) + assert isinstance(store._client._client, AsyncQdrantLocal) + assert store._default_options.k == 10 + assert store._default_options.max_distance == 0.22 diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 104ebb61..5db3f289 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -5,8 +5,9 @@ from pydantic import BaseModel, Field from ragbits.core.audit import traceable -from ragbits.core.embeddings import Embeddings, EmbeddingType, get_embeddings -from ragbits.core.vector_stores import VectorStore, get_vector_store +from ragbits.core.embeddings import Embeddings, EmbeddingType +from ragbits.core.utils.config_handling import ObjectContructionConfig +from ragbits.core.vector_stores import VectorStore from ragbits.core.vector_stores.base import VectorStoreOptions from ragbits.document_search.documents.document import Document, DocumentMeta from ragbits.document_search.documents.element import Element, ImageElement @@ -15,13 +16,10 @@ from ragbits.document_search.ingestion.processor_strategies import ( ProcessingExecutionStrategy, SequentialProcessing, - get_processing_strategy, ) from ragbits.document_search.ingestion.providers.base import BaseProvider -from ragbits.document_search.retrieval.rephrasers import get_rephraser from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rerankers import get_reranker from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions from ragbits.document_search.retrieval.rerankers.noop import NoopReranker @@ -36,6 +34,19 @@ class SearchConfig(BaseModel): embedder_kwargs: dict[str, Any] = Field(default_factory=dict) +class DocumentSearchConfig(BaseModel): + """ + Schema for for the dict taken by DocumentSearch.from_config method. + """ + + embedder: ObjectContructionConfig + vector_store: ObjectContructionConfig + rephraser: ObjectContructionConfig = ObjectContructionConfig(type="NoopQueryRephraser") + reranker: ObjectContructionConfig = ObjectContructionConfig(type="NoopReranker") + processing_strategy: ObjectContructionConfig = ObjectContructionConfig(type="SequentialProcessing") + providers: dict[str, ObjectContructionConfig] = {} + + class DocumentSearch: """ A main entrypoint to the DocumentSearch functionality. @@ -78,19 +89,24 @@ def from_config(cls, config: dict) -> "DocumentSearch": Creates and returns an instance of the DocumentSearch class from the given configuration. Args: - config: A dictionary containing the configuration for initializing the DocumentSearch instance. + config: A configuration object containing the configuration for initializing the DocumentSearch instance. Returns: DocumentSearch: An initialized instance of the DocumentSearch class. + + Raises: + ValidationError: If the configuration doesn't follow the expected format. + InvalidConfigError: If one of the specified classes can't be found or is not the correct type. """ - embedder = get_embeddings(config["embedder"]) - query_rephraser = get_rephraser(config.get("rephraser")) - reranker = get_reranker(config.get("reranker")) - vector_store = get_vector_store(config["vector_store"]) - processing_strategy = get_processing_strategy(config.get("processing_strategy")) - - providers_config_dict: dict = config.get("providers", {}) - providers_config = DocumentProcessorRouter.from_dict_to_providers_config(providers_config_dict) + model = DocumentSearchConfig.model_validate(config) + + embedder = Embeddings.subclass_from_config(model.embedder) + query_rephraser = QueryRephraser.subclass_from_config(model.rephraser) + reranker = Reranker.subclass_from_config(model.reranker) + vector_store = VectorStore.subclass_from_config(model.vector_store) + processing_strategy = ProcessingExecutionStrategy.subclass_from_config(model.processing_strategy) + + providers_config = DocumentProcessorRouter.from_dict_to_providers_config(model.providers) document_processor_router = DocumentProcessorRouter.from_config(providers_config) return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router, processing_strategy) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py index 9a5cf054..69e17864 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/document_processor.py @@ -2,8 +2,8 @@ from collections.abc import Callable from typing import cast +from ragbits.core.utils.config_handling import ObjectContructionConfig from ragbits.document_search.documents.document import DocumentMeta, DocumentType -from ragbits.document_search.ingestion.providers import get_provider from ragbits.document_search.ingestion.providers.base import BaseProvider from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider @@ -47,29 +47,25 @@ def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | Ba self._providers = providers @staticmethod - def from_dict_to_providers_config(dict_config: dict) -> ProvidersConfig: + def from_dict_to_providers_config(dict_config: dict[str, ObjectContructionConfig]) -> ProvidersConfig: """ - Creates ProvidersConfig from dictionary config. - Example of the dictionary config: - { - "txt": { - { - "type": "UnstructuredProvider" - } - } - } + Creates ProvidersConfig from dictionary that maps document types to the provider configuration. Args: dict_config: The dictionary with configuration. Returns: ProvidersConfig object. + + Raises: + InvalidConfigError: If a provider class can't be found or is not the correct type. """ providers_config = {} for document_type, config in dict_config.items(): providers_config[DocumentType(document_type)] = cast( - Callable[[], BaseProvider] | BaseProvider, get_provider(config) + Callable[[], BaseProvider] | BaseProvider, + BaseProvider.subclass_from_config(config), ) return providers_config diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/__init__.py index b231ed3c..4fcddec9 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/__init__.py @@ -1,34 +1,6 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config - from .base import ProcessingExecutionStrategy from .batched import BatchedAsyncProcessing from .distributed import DistributedProcessing from .sequential import SequentialProcessing __all__ = ["BatchedAsyncProcessing", "DistributedProcessing", "ProcessingExecutionStrategy", "SequentialProcessing"] - - -def get_processing_strategy(config: dict | None = None) -> ProcessingExecutionStrategy: - """ - Initializes and returns a ProcessingExecutionStrategy object based on the provided configuration. - - Args: - config: A dictionary containing configuration details for the ProcessingExecutionStrategy. - - Returns: - An instance of the specified ProcessingExecutionStrategy class, initialized with the provided config - (if any) or default arguments. - - Raises: - KeyError: If the provided configuration does not contain a valid "type" key. - InvalidConfigurationError: If the provided configuration is invalid. - NotImplementedError: If the specified ProcessingExecutionStrategy class cannot be created from - the provided configuration. - """ - if config is None: - return SequentialProcessing() - - strategy_cls = get_cls_from_config(config["type"], sys.modules[__name__]) - return strategy_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/base.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/base.py index f2c94936..96459afa 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/processor_strategies/base.py @@ -1,16 +1,17 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import ClassVar -from typing_extensions import Self - +from ragbits.core.utils.config_handling import WithConstructionConfig from ragbits.document_search.documents.document import Document, DocumentMeta from ragbits.document_search.documents.element import Element from ragbits.document_search.documents.sources import Source +from ragbits.document_search.ingestion import processor_strategies from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.base import BaseProvider -class ProcessingExecutionStrategy(ABC): +class ProcessingExecutionStrategy(WithConstructionConfig, ABC): """ Base class for processing execution strategies that define how documents are processed to become elements. @@ -19,18 +20,7 @@ class ProcessingExecutionStrategy(ABC): the processing is executed. """ - @classmethod - def from_config(cls, config: dict) -> Self: - """ - Creates and returns an instance of the ProcessingExecutionStrategy subclass from the given configuration. - - Args: - config: A dictionary containing the configuration for initializing the instance. - - Returns: - An initialized instance of the ProcessingExecutionStrategy subclass. - """ - return cls(**config) + default_module: ClassVar = processor_strategies @staticmethod async def to_document_meta(document: DocumentMeta | Document | Source) -> DocumentMeta: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py index b013a551..930d9bda 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/__init__.py @@ -1,27 +1,4 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config - from .base import BaseProvider from .dummy import DummyProvider -__all__ = ["BaseProvider", "DummyProvider", "get_provider"] - -module = sys.modules[__name__] - - -def get_provider(provider_config: dict) -> BaseProvider: - """ - Initializes and returns a Provider object based on the provided configuration. - - Args: - provider_config : A dictionary containing configuration details for the provider. - - Returns: - An instance of the specified Provider class, initialized with the provided config - (if any) or default arguments. - """ - provider_cls = get_cls_from_config(provider_config["type"], module) - config = provider_config.get("config", {}) - - return provider_cls(**config) +__all__ = ["BaseProvider", "DummyProvider"] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py index 2b99a9bc..15656db9 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/base.py @@ -1,7 +1,10 @@ from abc import ABC, abstractmethod +from typing import ClassVar +from ragbits.core.utils.config_handling import WithConstructionConfig from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element +from ragbits.document_search.ingestion import providers class DocumentTypeNotSupportedError(Exception): @@ -14,11 +17,13 @@ def __init__(self, provider_name: str, document_type: DocumentType) -> None: super().__init__(message) -class BaseProvider(ABC): +class BaseProvider(WithConstructionConfig, ABC): """ A base class for the document processing providers. """ + default_module: ClassVar = providers + SUPPORTED_DOCUMENT_TYPES: set[DocumentType] @abstractmethod diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 3b7b109f..9a4e1578 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,6 +1,3 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser @@ -12,33 +9,4 @@ "QueryRephraser", "QueryRephraserInput", "QueryRephraserPrompt", - "get_rephraser", ] - -module = sys.modules[__name__] - - -def get_rephraser(config: dict | None = None) -> QueryRephraser: - """ - Initializes and returns a QueryRephraser object based on the provided configuration. - - Args: - config: A dictionary containing configuration details for the QueryRephraser. - - Returns: - An instance of the specified QueryRephraser class, initialized with the provided config - (if any) or default arguments. - - Raises: - KeyError: If the configuration dictionary does not contain a "type" key. - ValueError: If an invalid rephraser class is specified in the configuration. - """ - if config is None: - return NoopQueryRephraser() - - rephraser_cls = get_cls_from_config(config["type"], module) - - if not issubclass(rephraser_cls, QueryRephraser): - raise ValueError(f"Invalid rephraser class: {rephraser_cls}") - - return rephraser_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 3eba3820..88859d79 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,11 +1,17 @@ from abc import ABC, abstractmethod +from typing import ClassVar +from ragbits.core.utils.config_handling import WithConstructionConfig +from ragbits.document_search.retrieval import rephrasers -class QueryRephraser(ABC): + +class QueryRephraser(WithConstructionConfig, ABC): """ Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ + default_module: ClassVar = rephrasers + @abstractmethod async def rephrase(self, query: str) -> list[str]: """ @@ -17,16 +23,3 @@ async def rephrase(self, query: str) -> list[str]: Returns: The rephrased queries. """ - - @classmethod - def from_config(cls, config: dict) -> "QueryRephraser": - """ - Create an instance of `QueryRephraser` from a configuration dictionary. - - Args: - config: A dictionary containing configuration settings for the rephraser. - - Returns: - An instance of the rephraser class initialized with the provided configuration. - """ - return cls(**config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index a7627dd0..0bec4a68 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,9 +1,9 @@ from typing import Any from ragbits.core.audit import traceable -from ragbits.core.llms import get_llm from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt +from ragbits.core.utils.config_handling import ObjectContructionConfig from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.prompts import ( QueryRephraserInput, @@ -61,9 +61,14 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": An instance of the rephraser class initialized with the provided configuration. Raises: - KeyError: If the configuration dictionary does not contain the required keys. - ValueError: If the prompt class is not a subclass of `Prompt` or the LLM class is not a subclass of `LLM`. + ValidationError: If the LLM or prompt configuration doesn't follow the expected format. + InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. + ValueError: If the prompt class is not a subclass of `Prompt`. + """ - llm = get_llm(config["llm"]) - prompt_cls = get_rephraser_prompt(prompt) if (prompt := config.get("prompt")) else None + llm: LLM = LLM.subclass_from_config(ObjectContructionConfig.model_validate(config["llm"])) + prompt_cls = None + if "prompt" in config: + prompt_config = ObjectContructionConfig.model_validate(config["prompt"]) + prompt_cls = get_rephraser_prompt(prompt_config.type) return cls(llm=llm, prompt=prompt_cls) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py index a9026279..7e577a83 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py @@ -1,30 +1,4 @@ -import sys - -from ragbits.core.utils.config_handling import get_cls_from_config from ragbits.document_search.retrieval.rerankers.base import Reranker from ragbits.document_search.retrieval.rerankers.noop import NoopReranker __all__ = ["NoopReranker", "Reranker"] - - -def get_reranker(config: dict | None = None) -> Reranker: - """ - Initializes and returns a Reranker object based on the provided configuration. - - Args: - config: A dictionary containing configuration details for the Reranker. - - Returns: - An instance of the specified Reranker class, initialized with the provided config - (if any) or default arguments. - - Raises: - KeyError: If the provided configuration does not contain a valid "type" key. - InvalidConfigurationError: If the provided configuration is invalid. - NotImplementedError: If the specified Reranker class cannot be created from the provided configuration. - """ - if config is None: - return NoopReranker() - - reranker_cls = get_cls_from_config(config["type"], sys.modules[__name__]) - return reranker_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py index 11b0c86e..a7e08a95 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import ClassVar from pydantic import BaseModel +from typing_extensions import Self +from ragbits.core.utils.config_handling import WithConstructionConfig from ragbits.document_search.documents.element import Element +from ragbits.document_search.retrieval import rerankers class RerankerOptions(BaseModel): @@ -15,11 +19,13 @@ class RerankerOptions(BaseModel): max_chunks_per_doc: int | None = None -class Reranker(ABC): +class Reranker(WithConstructionConfig, ABC): """ Reranks elements retrieved from vector store. """ + default_module: ClassVar = rerankers + def __init__(self, default_options: RerankerOptions | None = None) -> None: """ Constructs a new Reranker instance. @@ -30,20 +36,19 @@ def __init__(self, default_options: RerankerOptions | None = None) -> None: self._default_options = default_options or RerankerOptions() @classmethod - def from_config(cls, config: dict) -> "Reranker": + def from_config(cls, config: dict) -> Self: """ - Creates and returns an instance of the Reranker class from the given configuration. + Initializes the class with the provided configuration. Args: - config: A dictionary containing the configuration for initializing the Reranker instance. + config: A dictionary containing configuration details for the class. Returns: - An initialized instance of the Reranker class. - - Raises: - NotImplementedError: If the class cannot be created from the provided configuration. + An instance of the class initialized with the provided configuration. """ - raise NotImplementedError(f"Cannot create class {cls.__name__} from config.") + default_options = config.pop("default_options", None) + options = RerankerOptions(**default_options) if default_options else None + return cls(**config, default_options=options) @abstractmethod async def rerank( diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py index a27e982f..aa877a9c 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py @@ -23,22 +23,6 @@ def __init__(self, model: str, default_options: RerankerOptions | None = None) - super().__init__(default_options) self.model = model - @classmethod - def from_config(cls, config: dict) -> "LiteLLMReranker": - """ - Creates and returns an instance of the LiteLLMReranker class from the given configuration. - - Args: - config: A dictionary containing the configuration for initializing the LiteLLMReranker instance. - - Returns: - An initialized instance of the LiteLLMReranker class. - """ - return cls( - model=config["model"], - default_options=RerankerOptions(**config.get("default_options", {})), - ) - @traceable async def rerank( self, diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py index fe4686de..681fa7e4 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py @@ -10,19 +10,6 @@ class NoopReranker(Reranker): A no-op reranker that does not change the order of the elements. """ - @classmethod - def from_config(cls, config: dict) -> "NoopReranker": - """ - Creates and returns an instance of the NoopReranker class from the given configuration. - - Args: - config: A dictionary containing the configuration for initializing the NoopReranker instance. - - Returns: - An initialized instance of the NoopReranker class. - """ - return cls(default_options=RerankerOptions(**config.get("default_options", {}))) - @traceable async def rerank( # noqa: PLR6301 self, diff --git a/packages/ragbits-document-search/tests/unit/test_providers.py b/packages/ragbits-document-search/tests/unit/test_providers.py index 5975610e..55e9c781 100644 --- a/packages/ragbits-document-search/tests/unit/test_providers.py +++ b/packages/ragbits-document-search/tests/unit/test_providers.py @@ -3,8 +3,10 @@ import pytest +from ragbits.core.utils.config_handling import ObjectContructionConfig from ragbits.document_search.documents.document import DocumentMeta, DocumentType -from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError +from ragbits.document_search.ingestion.providers.base import BaseProvider, DocumentTypeNotSupportedError +from ragbits.document_search.ingestion.providers.dummy import DummyProvider from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider @@ -50,3 +52,17 @@ async def test_unstructured_provider_raises_value_error_when_server_url_not_set( ) assert str(err.value) == "Either pass api_server argument or set the UNSTRUCTURED_SERVER_URL environment variable" + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + {"type": "ragbits.document_search.ingestion.providers:DummyProvider"} + ) + embedding = BaseProvider.subclass_from_config(config) + assert isinstance(embedding, DummyProvider) + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "DummyProvider"}) + embedding = BaseProvider.subclass_from_config(config) + assert isinstance(embedding, DummyProvider) diff --git a/packages/ragbits-document-search/tests/unit/test_rephrasers.py b/packages/ragbits-document-search/tests/unit/test_rephrasers.py new file mode 100644 index 00000000..cb2cb989 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_rephrasers.py @@ -0,0 +1,57 @@ +from ragbits.core.llms.litellm import LiteLLM +from ragbits.core.utils.config_handling import ObjectContructionConfig +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser +from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser +from ragbits.document_search.retrieval.rephrasers.prompts import QueryRephraserPrompt + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + {"type": "ragbits.document_search.retrieval.rephrasers:NoopQueryRephraser"} + ) + rephraser = QueryRephraser.subclass_from_config(config) + assert isinstance(rephraser, NoopQueryRephraser) + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "NoopQueryRephraser"}) + rephraser = QueryRephraser.subclass_from_config(config) + assert isinstance(rephraser, NoopQueryRephraser) + + +def test_subclass_from_config_llm(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.document_search.retrieval.rephrasers.llm:LLMQueryRephraser", + "config": { + "llm": { + "type": "ragbits.core.llms.litellm:LiteLLM", + "config": {"model_name": "some_model"}, + }, + }, + } + ) + rephraser = QueryRephraser.subclass_from_config(config) + assert isinstance(rephraser, LLMQueryRephraser) + assert isinstance(rephraser._llm, LiteLLM) + assert rephraser._llm.model_name == "some_model" + + +def test_subclass_from_config_llm_prompt(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.document_search.retrieval.rephrasers.llm:LLMQueryRephraser", + "config": { + "llm": { + "type": "ragbits.core.llms.litellm:LiteLLM", + "config": {"model_name": "some_model"}, + }, + "prompt": {"type": "QueryRephraserPrompt"}, + }, + } + ) + rephraser = QueryRephraser.subclass_from_config(config) + assert isinstance(rephraser, LLMQueryRephraser) + assert isinstance(rephraser._llm, LiteLLM) + assert issubclass(rephraser._prompt, QueryRephraserPrompt) diff --git a/packages/ragbits-document-search/tests/unit/test_rerankers.py b/packages/ragbits-document-search/tests/unit/test_rerankers.py index 3a557f3d..fa483b0d 100644 --- a/packages/ragbits-document-search/tests/unit/test_rerankers.py +++ b/packages/ragbits-document-search/tests/unit/test_rerankers.py @@ -2,12 +2,12 @@ from collections.abc import Sequence from unittest.mock import patch -import pytest - +from ragbits.core.utils.config_handling import ObjectContructionConfig from ragbits.document_search.documents.document import DocumentMeta from ragbits.document_search.documents.element import Element, TextElement from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker +from ragbits.document_search.retrieval.rerankers.noop import NoopReranker class CustomReranker(Reranker): @@ -22,10 +22,8 @@ async def rerank( # noqa: PLR6301 def test_custom_reranker_from_config() -> None: - with pytest.raises(NotImplementedError) as exc_info: - CustomReranker.from_config({}) - - assert "Cannot create class CustomReranker from config" in str(exc_info.value) + reranker = CustomReranker.from_config({}) + assert isinstance(reranker, CustomReranker) def test_litellm_reranker_from_config() -> None: @@ -80,3 +78,49 @@ async def test_litellm_reranker_rerank() -> None: top_n=2, max_chunks_per_doc=None, ) + + +def test_subclass_from_config(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.document_search.retrieval.rerankers:NoopReranker", + "config": { + "default_options": { + "top_n": 12, + "max_chunks_per_doc": 42, + }, + }, + } + ) + reranker = Reranker.subclass_from_config(config) + assert isinstance(reranker, NoopReranker) + assert isinstance(reranker._default_options, RerankerOptions) + assert reranker._default_options.top_n == 12 + assert reranker._default_options.max_chunks_per_doc == 42 + + +def test_subclass_from_config_default_path(): + config = ObjectContructionConfig.model_validate({"type": "NoopReranker"}) + reranker = Reranker.subclass_from_config(config) + assert isinstance(reranker, NoopReranker) + + +def test_subclass_from_config_llm(): + config = ObjectContructionConfig.model_validate( + { + "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", + "config": { + "model": "some_model", + "default_options": { + "top_n": 12, + "max_chunks_per_doc": 42, + }, + }, + } + ) + reranker = Reranker.subclass_from_config(config) + assert isinstance(reranker, LiteLLMReranker) + assert isinstance(reranker._default_options, RerankerOptions) + assert reranker.model == "some_model" + assert reranker._default_options.top_n == 12 + assert reranker._default_options.max_chunks_per_doc == 42