From 439e4459e3137d67fe733f30a9210ff34ba495fc Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Mon, 9 Dec 2024 13:54:42 +0100 Subject: [PATCH] refactor: option for classes to create instances from factory path (#240) --- .../evaluation/document-search/optimize.py | 6 +- packages/ragbits-core/src/ragbits/core/cli.py | 6 +- .../ragbits-core/src/ragbits/core/config.py | 6 +- .../src/ragbits/core/llms/factory.py | 20 +------ .../src/ragbits/core/prompt/lab/app.py | 3 +- .../src/ragbits/core/utils/config_handling.py | 55 +++++++++++++------ .../src/ragbits/core/vector_stores/chroma.py | 4 +- .../src/ragbits/core/vector_stores/qdrant.py | 4 +- .../tests/unit/llms/factory/__init__.py | 5 -- .../unit/llms/factory/test_get_default_llm.py | 12 +++- .../llms/factory/test_get_llm_from_factory.py | 22 -------- .../unit/utils/pyproject/test_get_instace.py | 6 +- .../bad_factory_project/pyproject.toml | 6 +- .../factory_project/pyproject.toml | 6 +- .../tests/unit/utils/test_config_handling.py | 16 ++++++ .../retrieval/rephrasers/prompts.py | 4 +- .../src/ragbits/evaluate/callbacks/neptune.py | 4 +- .../evaluate/dataset_generator/pipeline.py | 8 +-- .../tasks/corpus_generation.py | 6 +- .../tasks/text_generation/base.py | 6 +- .../src/ragbits/evaluate/loaders/__init__.py | 4 +- .../src/ragbits/evaluate/metrics/__init__.py | 4 +- 22 files changed, 105 insertions(+), 108 deletions(-) delete mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py diff --git a/examples/evaluation/document-search/optimize.py b/examples/evaluation/document-search/optimize.py index b3709709..521e4e41 100644 --- a/examples/evaluation/document-search/optimize.py +++ b/examples/evaluation/document-search/optimize.py @@ -3,7 +3,7 @@ import hydra from omegaconf import DictConfig, OmegaConf -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path from ragbits.evaluate.loaders import dataloader_factory from ragbits.evaluate.metrics import metric_set_factory from ragbits.evaluate.optimizer import Optimizer @@ -21,12 +21,12 @@ def main(config: DictConfig) -> None: config: Hydra configuration. """ dataloader = dataloader_factory(config.data) - pipeline_class = get_cls_from_config(config.pipeline.type, module) + pipeline_class = import_by_path(config.pipeline.type, module) metrics = metric_set_factory(config.metrics) callback_configurators = None if getattr(config, "callbacks", None): callback_configurators = [ - get_cls_from_config(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks + import_by_path(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks ] optimization_cfg = OmegaConf.create({"direction": "maximize", "n_trials": 10}) diff --git a/packages/ragbits-core/src/ragbits/core/cli.py b/packages/ragbits-core/src/ragbits/core/cli.py index 7f8ad1d9..4e962b76 100644 --- a/packages/ragbits-core/src/ragbits/core/cli.py +++ b/packages/ragbits-core/src/ragbits/core/cli.py @@ -10,7 +10,7 @@ from ragbits.cli.app import CLI from ragbits.core.config import core_config -from ragbits.core.llms.base import LLMType +from ragbits.core.llms.base import LLM, LLMType from ragbits.core.prompt.prompt import ChatFormat, Prompt @@ -91,13 +91,11 @@ def execute( Raises: ValueError: If `llm_factory` is not provided. """ - from ragbits.core.llms.factory import get_llm_from_factory - prompt = _render(prompt_path=prompt_path, payload=payload) if llm_factory is None: raise ValueError("`llm_factory` must be provided") - llm = get_llm_from_factory(llm_factory) + llm: LLM = LLM.subclass_from_factory(llm_factory) llm_output = asyncio.run(llm.generate(prompt)) response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output) diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py index a329de93..77168f2a 100644 --- a/packages/ragbits-core/src/ragbits/core/config.py +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -14,9 +14,9 @@ class CoreConfig(BaseModel): # Path to a functions that returns LLM objects, e.g. "my_project.llms.get_llm" default_llm_factories: dict[LLMType, str] = { - LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory", - LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory", - LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory", + LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory", + LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory", + LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory", } diff --git a/packages/ragbits-core/src/ragbits/core/llms/factory.py b/packages/ragbits-core/src/ragbits/core/llms/factory.py index 9405a976..584cc2f1 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/factory.py +++ b/packages/ragbits-core/src/ragbits/core/llms/factory.py @@ -1,26 +1,8 @@ -import importlib - from ragbits.core.config import core_config from ragbits.core.llms.base import LLM, LLMType from ragbits.core.llms.litellm import LiteLLM -def get_llm_from_factory(factory_path: str) -> LLM: - """ - Get an instance of an LLM using a factory function specified by the user. - - Args: - factory_path (str): The path to the factory function. - - Returns: - LLM: An instance of the LLM class. - """ - module_name, function_name = factory_path.rsplit(".", 1) - module = importlib.import_module(module_name) - function = getattr(module, function_name) - return function() - - def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM: """ Get an instance of the default LLM using the factory function @@ -34,7 +16,7 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM: """ factory = core_config.default_llm_factories[llm_type] - return get_llm_from_factory(factory) + return LLM.subclass_from_factory(factory) def simple_litellm_factory() -> LLM: diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 8d430b38..1fd7846c 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -16,7 +16,6 @@ from ragbits.core.config import core_config from ragbits.core.llms import LLM from ragbits.core.llms.base import LLMType -from ragbits.core.llms.factory import get_llm_from_factory from ragbits.core.prompt import Prompt from ragbits.core.prompt.discovery import PromptDiscovery @@ -166,7 +165,7 @@ def lab_app( # pylint: disable=missing-param-doc prompts_state = gr.State( PromptState( prompts=list(prompts), - llm=get_llm_from_factory(llm_factory) if llm_factory else None, + llm=LLM.subclass_from_factory(llm_factory) if llm_factory else None, ) ) diff --git a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py index 398b1f6b..861ffff4 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py +++ b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py @@ -13,39 +13,39 @@ class InvalidConfigError(Exception): """ -def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401 +def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401 """ - Retrieves and returns a class based on the given type string. The class can be either in the - default module or a specified module if provided in the type string. + Retrieves and returns an object based on the string in the format of "module.submodule:object_name". + If the first part is ommited, the default module is used. Args: - cls_path: A string representing the path to the class or object. This can either be a - path implicitly referencing the default module or a full path (module.submodule:ClassName) - if the class is located in a different module. - default_module: The default module to search for the class if no specific module - is provided in the type string. + path: A string representing the path to the object. This can either be a + path implicitly referencing the default module or a full path (module.submodule:object_name) + if the object is located in a different module. + default_module: The default module to search for the object if no specific module + is provided in the path string. Returns: Any: The object retrieved from the specified or default module. Raises: - InvalidConfigError: The requested class is not found under the specified module + InvalidConfigError: The requested object is not found under the specified module """ - if ":" in cls_path: + if ":" in path: try: - module_stringified, object_stringified = cls_path.split(":") + module_stringified, object_stringified = path.split(":") module = import_module(module_stringified) return getattr(module, object_stringified) except AttributeError as err: - raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err + raise InvalidConfigError(f"{object_stringified} not found in module {module_stringified}") from err if default_module is None: - raise InvalidConfigError("Given type string does not contain a module and no default module provided") + raise InvalidConfigError("Not provided a full path and no default module specified") try: - return getattr(default_module, cls_path) + return getattr(default_module, path) except AttributeError as err: - raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err + raise InvalidConfigError(f"{path} not found in module {default_module}") from err class ObjectContructionConfig(BaseModel): @@ -83,12 +83,35 @@ def subclass_from_config(cls, config: ObjectContructionConfig) -> Self: Raises: InvalidConfigError: The class can't be found or is not a subclass of the current class. """ - subclass = get_cls_from_config(config.type, cls.default_module) + subclass = import_by_path(config.type, cls.default_module) if not issubclass(subclass, cls): raise InvalidConfigError(f"{subclass} is not a subclass of {cls}") return subclass.from_config(config.config) + @classmethod + def subclass_from_factory(cls, factory_path: str) -> Self: + """ + Creates the class using the provided factory function. May return a subclass of the class, + if requested by the factory. + + Args: + factory_path: A string representing the path to the factory function + in the format of "module.submodule:factory_name". + + Returns: + An instance of the class initialized with the provided factory function. + + Raises: + InvalidConfigError: The factory can't be found or the object returned + is not a subclass of the current class. + """ + factory = import_by_path(factory_path, cls.default_module) + obj = factory() + if not isinstance(obj, cls): + raise InvalidConfigError(f"The object returned by factory {factory_path} is not an instance of {cls}") + return obj + @classmethod def from_config(cls, config: dict) -> Self: """ diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py index 5a1a71b5..d89e4489 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py @@ -6,7 +6,7 @@ from ragbits.core.audit import traceable from ragbits.core.metadata_stores.base import MetadataStore -from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config +from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path from ragbits.core.utils.dict_transformations import flatten_dict, unflatten_dict from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery @@ -59,7 +59,7 @@ def from_config(cls, config: dict) -> Self: InvalidConfigError: The client or metadata_store class can't be found or is not the correct type. """ client_options = ObjectContructionConfig.model_validate(config["client"]) - client_cls = get_cls_from_config(client_options.type, chromadb) + client_cls = import_by_path(client_options.type, chromadb) config["client"] = client_cls(**client_options.config) return super().from_config(config) diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py index 72e1e584..122c21c0 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py @@ -8,7 +8,7 @@ from ragbits.core.audit import traceable from ragbits.core.metadata_stores.base import MetadataStore -from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config +from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions @@ -56,7 +56,7 @@ def from_config(cls, config: dict) -> Self: InvalidConfigError: The client or metadata_store class can't be found or is not the correct type. """ client_options = ObjectContructionConfig.model_validate(config["client"]) - client_cls = get_cls_from_config(client_options.type, qdrant_client) + client_cls = import_by_path(client_options.type, qdrant_client) config["client"] = client_cls(**client_options.config) return super().from_config(config) diff --git a/packages/ragbits-core/tests/unit/llms/factory/__init__.py b/packages/ragbits-core/tests/unit/llms/factory/__init__.py index a3559f0c..e69de29b 100644 --- a/packages/ragbits-core/tests/unit/llms/factory/__init__.py +++ b/packages/ragbits-core/tests/unit/llms/factory/__init__.py @@ -1,5 +0,0 @@ -import sys -from pathlib import Path - -# Add "llms" to sys.path -sys.path.append(str(Path(__file__).parent.parent)) diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py index 005fa1b2..d2568c87 100644 --- a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py @@ -6,12 +6,22 @@ from ragbits.core.llms.litellm import LiteLLM +def mock_llm_factory() -> LiteLLM: + """ + A mock LLM factory that creates a LiteLLM instance with a mock model name. + + Returns: + LiteLLM: An instance of the LiteLLM. + """ + return LiteLLM(model_name="mock_model") + + def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None: """ Test the get_llm_from_factory function. """ monkeypatch.setattr( - core_config, "default_llm_factories", {LLMType.TEXT: "factory.test_get_llm_from_factory.mock_llm_factory"} + core_config, "default_llm_factories", {LLMType.TEXT: "unit.llms.factory.test_get_default_llm:mock_llm_factory"} ) llm = get_default_llm() diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py deleted file mode 100644 index 8d2a948c..00000000 --- a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -from ragbits.core.llms.factory import get_llm_from_factory -from ragbits.core.llms.litellm import LiteLLM - - -def mock_llm_factory() -> LiteLLM: - """ - A mock LLM factory that creates a LiteLLM instance with a mock model name. - - Returns: - LiteLLM: An instance of the LiteLLM. - """ - return LiteLLM(model_name="mock_model") - - -def test_get_llm_from_factory(): - """ - Test the get_llm_from_factory function. - """ - llm = get_llm_from_factory("factory.test_get_llm_from_factory.mock_llm_factory") - - assert isinstance(llm, LiteLLM) - assert llm.model_name == "mock_model" diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py index 3dea0c5b..6263bf4c 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py @@ -80,9 +80,9 @@ def test_get_config_instance_factories(): ) assert config.default_llm_factories == { - LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory", - LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory", - LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_vision_factory", + LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory", + LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory", + LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_vision_factory", } diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml index b8839a25..3f42569e 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml +++ b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/bad_factory_project/pyproject.toml @@ -2,6 +2,6 @@ name = "bad_factory_project" [tool.ragbits.core.default_llm_factories] -non_existing = "ragbits.core.llms.factory.simple_litellm_factory" -vision = "ragbits.core.llms.factory.simple_litellm_vision_factory" -structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory" \ No newline at end of file +non_existing = "ragbits.core.llms.factory:simple_litellm_factory" +vision = "ragbits.core.llms.factory:simple_litellm_vision_factory" +structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory" diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml index 1e1a605e..bed6c8ad 100644 --- a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml +++ b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/factory_project/pyproject.toml @@ -2,6 +2,6 @@ name = "factory_project" [tool.ragbits.core.default_llm_factories] -text = "ragbits.core.llms.factory.simple_litellm_factory" -vision = "ragbits.core.llms.factory.simple_litellm_vision_factory" -structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory" \ No newline at end of file +text = "ragbits.core.llms.factory:simple_litellm_factory" +vision = "ragbits.core.llms.factory:simple_litellm_vision_factory" +structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory" diff --git a/packages/ragbits-core/tests/unit/utils/test_config_handling.py b/packages/ragbits-core/tests/unit/utils/test_config_handling.py index 8d8253aa..dd2ca8ce 100644 --- a/packages/ragbits-core/tests/unit/utils/test_config_handling.py +++ b/packages/ragbits-core/tests/unit/utils/test_config_handling.py @@ -22,6 +22,10 @@ def __init__(self, foo: str, bar: int) -> None: self.bar = bar +def example_factory() -> ExampleClassWithConfigMixin: + return ExampleSubclass("aligator", 42) + + def test_defacult_from_config(): config = {"foo": "foo", "bar": 1} instance = ExampleClassWithConfigMixin.from_config(config) @@ -62,3 +66,15 @@ def test_no_default_module(): ) with pytest.raises(InvalidConfigError): ExampleWithNoDefaultModule.subclass_from_config(config) + + +def test_subclass_from_factory(): + instance = ExampleClassWithConfigMixin.subclass_from_factory("unit.utils.test_config_handling:example_factory") + assert isinstance(instance, ExampleSubclass) + assert instance.foo == "aligator" + assert instance.bar == 42 + + +def test_subclass_from_factory_incorrect_class(): + with pytest.raises(InvalidConfigError): + ExampleWithNoDefaultModule.subclass_from_factory("unit.utils.test_config_handling:example_factory") diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py index 1f1e0c90..c53f9c88 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from ragbits.core.prompt.prompt import Prompt -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path module = sys.modules[__name__] @@ -46,7 +46,7 @@ def get_rephraser_prompt(prompt: str) -> type[Prompt[QueryRephraserInput, Any]]: Raises: ValueError: If the prompt class is not a subclass of `Prompt`. """ - prompt_cls = get_cls_from_config(prompt, module) + prompt_cls = import_by_path(prompt, module) if not issubclass(prompt_cls, Prompt): raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}") diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/callbacks/neptune.py b/packages/ragbits-evaluate/src/ragbits/evaluate/callbacks/neptune.py index af2638ca..e9fe7674 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/callbacks/neptune.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/callbacks/neptune.py @@ -3,7 +3,7 @@ import neptune -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path from .base import CallbackConfigurator @@ -21,6 +21,6 @@ def get_callback(self) -> Callable: Returns: Callable: configured neptune callback """ - callback_class = get_cls_from_config(self.config.callback_type, module) + callback_class = import_by_path(self.config.callback_type, module) run = neptune.init_run(project=self.config.project) return callback_class(run) diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/pipeline.py b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/pipeline.py index bb95e151..f1bb3c8b 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/pipeline.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/pipeline.py @@ -7,7 +7,7 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path module = sys.modules[__name__] @@ -120,14 +120,14 @@ def _parse_pipeline_steps(self) -> list[Step]: tasks = [] for task_config in self.config.tasks: llm_config = task_config.llm - llm = get_cls_from_config(llm_config.provider_type, module)(**llm_config.kwargs) + llm = import_by_path(llm_config.provider_type, module)(**llm_config.kwargs) task_kwargs: dict[Any, Any] = {"llm": llm} task_kwargs.update(task_config.kwargs or {}) # type: ignore - task = get_cls_from_config(task_config.type, module)(**task_kwargs) + task = import_by_path(task_config.type, module)(**task_kwargs) tasks.append(task) filter_types = getattr(task_config, "filters", None) or [] for filter_type in filter_types: - filter = get_cls_from_config(filter_type, module)(tasks[-1]) + filter = import_by_path(filter_type, module)(tasks[-1]) tasks.append(filter) return tasks diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/corpus_generation.py b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/corpus_generation.py index 9fe203d1..b1a659bb 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/corpus_generation.py @@ -7,7 +7,7 @@ from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path module = sys.modules[__name__] @@ -23,9 +23,7 @@ def __init__( ): super().__init__() self._llm = llm - self._prompt_class = ( - get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class - ) + self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class self._num_per_topic = num_per_topic @property diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/text_generation/base.py b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/text_generation/base.py index ba9a7154..5a0df65b 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/tasks/text_generation/base.py @@ -6,7 +6,7 @@ from distilabel.steps.tasks import TextGeneration from ragbits.core.prompt import ChatFormat, Prompt -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path module = sys.modules[__name__] @@ -18,9 +18,7 @@ def __init__(self, llm: LLM, inputs: list[str], outputs: list[str], prompt_class super().__init__(llm=llm) self._inputs = inputs self._outputs = outputs - self._prompt_class = ( - get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class - ) + self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class @property def inputs(self) -> list[str]: diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/loaders/__init__.py b/packages/ragbits-evaluate/src/ragbits/evaluate/loaders/__init__.py index 0430ada8..f65da00f 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/loaders/__init__.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/loaders/__init__.py @@ -2,7 +2,7 @@ from omegaconf import DictConfig -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path from .base import DataLoader @@ -17,5 +17,5 @@ def dataloader_factory(config: DictConfig) -> DataLoader: Returns: DataLoader """ - dataloader_class = get_cls_from_config(config.type, module) + dataloader_class = import_by_path(config.type, module) return dataloader_class(config.options) diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/metrics/__init__.py b/packages/ragbits-evaluate/src/ragbits/evaluate/metrics/__init__.py index 6289bdf4..408249a6 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/metrics/__init__.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/metrics/__init__.py @@ -2,7 +2,7 @@ from omegaconf import ListConfig -from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.core.utils.config_handling import import_by_path from .base import MetricSet @@ -19,6 +19,6 @@ def metric_set_factory(cfg: ListConfig) -> MetricSet: """ metrics = [] for metric_cfg in cfg: - metric_module = get_cls_from_config(metric_cfg.type, module) + metric_module = import_by_path(metric_cfg.type, module) metrics.append(metric_module(metric_cfg)) return MetricSet(*metrics)