Skip to content

Commit

Permalink
refactor: option for classes to create instances from factory path (#240
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ludwiktrammer authored Dec 9, 2024
1 parent 0b6e1e1 commit 439e445
Show file tree
Hide file tree
Showing 22 changed files with 105 additions and 108 deletions.
6 changes: 3 additions & 3 deletions examples/evaluation/document-search/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hydra
from omegaconf import DictConfig, OmegaConf

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path
from ragbits.evaluate.loaders import dataloader_factory
from ragbits.evaluate.metrics import metric_set_factory
from ragbits.evaluate.optimizer import Optimizer
Expand All @@ -21,12 +21,12 @@ def main(config: DictConfig) -> None:
config: Hydra configuration.
"""
dataloader = dataloader_factory(config.data)
pipeline_class = get_cls_from_config(config.pipeline.type, module)
pipeline_class = import_by_path(config.pipeline.type, module)
metrics = metric_set_factory(config.metrics)
callback_configurators = None
if getattr(config, "callbacks", None):
callback_configurators = [
get_cls_from_config(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks
import_by_path(callback_cfg.type, module)(callback_cfg.args) for callback_cfg in config.callbacks
]

optimization_cfg = OmegaConf.create({"direction": "maximize", "n_trials": 10})
Expand Down
6 changes: 2 additions & 4 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt


Expand Down Expand Up @@ -91,13 +91,11 @@ def execute(
Raises:
ValueError: If `llm_factory` is not provided.
"""
from ragbits.core.llms.factory import get_llm_from_factory

prompt = _render(prompt_path=prompt_path, payload=payload)

if llm_factory is None:
raise ValueError("`llm_factory` must be provided")
llm = get_llm_from_factory(llm_factory)
llm: LLM = LLM.subclass_from_factory(llm_factory)

llm_output = asyncio.run(llm.generate(prompt))
response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
Expand Down
6 changes: 3 additions & 3 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class CoreConfig(BaseModel):

# Paths to functions that return LLM objects, e.g. "my_project.llms:get_llm"
default_llm_factories: dict[LLMType, str] = {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory",
LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory",
}


Expand Down
20 changes: 1 addition & 19 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.litellm import LiteLLM


def get_llm_from_factory(factory_path: str) -> LLM:
"""
Get an instance of an LLM using a factory function specified by the user.
Args:
factory_path (str): The path to the factory function.
Returns:
LLM: An instance of the LLM class.
"""
module_name, function_name = factory_path.rsplit(".", 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
return function()


def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
Expand All @@ -34,7 +16,7 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
factory = core_config.default_llm_factories[llm_type]
return get_llm_from_factory(factory)
return LLM.subclass_from_factory(factory)


def simple_litellm_factory() -> LLM:
Expand Down
3 changes: 1 addition & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ragbits.core.config import core_config
from ragbits.core.llms import LLM
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery

Expand Down Expand Up @@ -166,7 +165,7 @@ def lab_app( # pylint: disable=missing-param-doc
prompts_state = gr.State(
PromptState(
prompts=list(prompts),
llm=get_llm_from_factory(llm_factory) if llm_factory else None,
llm=LLM.subclass_from_factory(llm_factory) if llm_factory else None,
)
)

Expand Down
55 changes: 39 additions & 16 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,39 @@ class InvalidConfigError(Exception):
"""


def get_cls_from_config(cls_path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns a class based on the given type string. The class can be either in the
default module or a specified module if provided in the type string.
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
If the first part is omitted, the default module is used.
Args:
cls_path: A string representing the path to the class or object. This can either be a
path implicitly referencing the default module or a full path (module.submodule:ClassName)
if the class is located in a different module.
default_module: The default module to search for the class if no specific module
is provided in the type string.
path: A string representing the path to the object. This can either be a
path implicitly referencing the default module or a full path (module.submodule:object_name)
if the object is located in a different module.
default_module: The default module to search for the object if no specific module
is provided in the path string.
Returns:
Any: The object retrieved from the specified or default module.
Raises:
InvalidConfigError: The requested class is not found under the specified module
InvalidConfigError: The requested object is not found under the specified module
"""
if ":" in cls_path:
if ":" in path:
try:
module_stringified, object_stringified = cls_path.split(":")
module_stringified, object_stringified = path.split(":")
module = import_module(module_stringified)
return getattr(module, object_stringified)
except AttributeError as err:
raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err
raise InvalidConfigError(f"{object_stringified} not found in module {module_stringified}") from err

if default_module is None:
raise InvalidConfigError("Given type string does not contain a module and no default module provided")
raise InvalidConfigError("Not provided a full path and no default module specified")

try:
return getattr(default_module, cls_path)
return getattr(default_module, path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err
raise InvalidConfigError(f"{path} not found in module {default_module}") from err


class ObjectContructionConfig(BaseModel):
Expand Down Expand Up @@ -83,12 +83,35 @@ def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
Raises:
InvalidConfigError: The class can't be found or is not a subclass of the current class.
"""
subclass = get_cls_from_config(config.type, cls.default_module)
subclass = import_by_path(config.type, cls.default_module)
if not issubclass(subclass, cls):
raise InvalidConfigError(f"{subclass} is not a subclass of {cls}")

return subclass.from_config(config.config)

@classmethod
def subclass_from_factory(cls, factory_path: str) -> Self:
    """
    Creates an instance of the class using the provided factory function. The factory may
    return an instance of a subclass of the class.

    Args:
        factory_path: A string representing the path to the factory function
            in the format of "module.submodule:factory_name".

    Returns:
        The object returned by the factory function.

    Raises:
        InvalidConfigError: The factory can't be found, is not callable, or the object
            it returns is not an instance of the current class.
    """
    factory = import_by_path(factory_path, cls.default_module)
    # import_by_path returns whatever object is stored under the path, so reject
    # non-callables here instead of letting the call below fail with a bare TypeError.
    if not callable(factory):
        raise InvalidConfigError(f"The object at {factory_path} is not callable and can't be used as a factory")

    obj = factory()
    if not isinstance(obj, cls):
        raise InvalidConfigError(f"The object returned by factory {factory_path} is not an instance of {cls}")
    return obj

@classmethod
def from_config(cls, config: dict) -> Self:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ragbits.core.audit import traceable
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
from ragbits.core.utils.dict_transformations import flatten_dict, unflatten_dict
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery

Expand Down Expand Up @@ -59,7 +59,7 @@ def from_config(cls, config: dict) -> Self:
InvalidConfigError: The client or metadata_store class can't be found or is not the correct type.
"""
client_options = ObjectContructionConfig.model_validate(config["client"])
client_cls = get_cls_from_config(client_options.type, chromadb)
client_cls = import_by_path(client_options.type, chromadb)
config["client"] = client_cls(**client_options.config)
return super().from_config(config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ragbits.core.audit import traceable
from ragbits.core.metadata_stores.base import MetadataStore
from ragbits.core.utils.config_handling import ObjectContructionConfig, get_cls_from_config
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions


Expand Down Expand Up @@ -56,7 +56,7 @@ def from_config(cls, config: dict) -> Self:
InvalidConfigError: The client or metadata_store class can't be found or is not the correct type.
"""
client_options = ObjectContructionConfig.model_validate(config["client"])
client_cls = get_cls_from_config(client_options.type, qdrant_client)
client_cls = import_by_path(client_options.type, qdrant_client)
config["client"] = client_cls(**client_options.config)
return super().from_config(config)

Expand Down
5 changes: 0 additions & 5 deletions packages/ragbits-core/tests/unit/llms/factory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
import sys
from pathlib import Path

# Add "llms" to sys.path
sys.path.append(str(Path(__file__).parent.parent))
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@
from ragbits.core.llms.litellm import LiteLLM


def mock_llm_factory() -> LiteLLM:
    """
    Build a LiteLLM instance configured with a placeholder model name.

    Returns:
        LiteLLM: The instance created with `model_name="mock_model"`.
    """
    mock_llm = LiteLLM(model_name="mock_model")
    return mock_llm


def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the get_default_llm function.
"""
monkeypatch.setattr(
core_config, "default_llm_factories", {LLMType.TEXT: "factory.test_get_llm_from_factory.mock_llm_factory"}
core_config, "default_llm_factories", {LLMType.TEXT: "unit.llms.factory.test_get_default_llm:mock_llm_factory"}
)

llm = get_default_llm()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def test_get_config_instance_factories():
)

assert config.default_llm_factories == {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.TEXT: "ragbits.core.llms.factory:simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory:simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_vision_factory",
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "bad_factory_project"

[tool.ragbits.core.default_llm_factories]
non_existing = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
non_existing = "ragbits.core.llms.factory:simple_litellm_factory"
vision = "ragbits.core.llms.factory:simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory"
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "factory_project"

[tool.ragbits.core.default_llm_factories]
text = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
text = "ragbits.core.llms.factory:simple_litellm_factory"
vision = "ragbits.core.llms.factory:simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory:simple_litellm_vision_factory"
16 changes: 16 additions & 0 deletions packages/ragbits-core/tests/unit/utils/test_config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def __init__(self, foo: str, bar: int) -> None:
self.bar = bar


def example_factory() -> ExampleClassWithConfigMixin:
    """
    Factory function referenced by path in the tests; builds an ExampleSubclass
    with fixed constructor arguments.
    """
    foo, bar = "aligator", 42
    return ExampleSubclass(foo, bar)


def test_defacult_from_config():
config = {"foo": "foo", "bar": 1}
instance = ExampleClassWithConfigMixin.from_config(config)
Expand Down Expand Up @@ -62,3 +66,15 @@ def test_no_default_module():
)
with pytest.raises(InvalidConfigError):
ExampleWithNoDefaultModule.subclass_from_config(config)


def test_subclass_from_factory():
    """
    The object produced by the factory at the given path is returned as-is,
    including when it is an instance of a subclass.
    """
    factory_path = "unit.utils.test_config_handling:example_factory"
    created = ExampleClassWithConfigMixin.subclass_from_factory(factory_path)

    assert isinstance(created, ExampleSubclass)
    assert (created.foo, created.bar) == ("aligator", 42)


def test_subclass_from_factory_incorrect_class():
    """
    Requesting the factory through ExampleWithNoDefaultModule is expected to be
    rejected with InvalidConfigError.
    """
    factory_path = "unit.utils.test_config_handling:example_factory"
    with pytest.raises(InvalidConfigError):
        ExampleWithNoDefaultModule.subclass_from_factory(factory_path)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel

from ragbits.core.prompt.prompt import Prompt
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path

module = sys.modules[__name__]

Expand Down Expand Up @@ -46,7 +46,7 @@ def get_rephraser_prompt(prompt: str) -> type[Prompt[QueryRephraserInput, Any]]:
Raises:
ValueError: If the prompt class is not a subclass of `Prompt`.
"""
prompt_cls = get_cls_from_config(prompt, module)
prompt_cls = import_by_path(prompt, module)

if not issubclass(prompt_cls, Prompt):
raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import neptune

from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.utils.config_handling import import_by_path

from .base import CallbackConfigurator

Expand All @@ -21,6 +21,6 @@ def get_callback(self) -> Callable:
Returns:
Callable: configured neptune callback
"""
callback_class = get_cls_from_config(self.config.callback_type, module)
callback_class = import_by_path(self.config.callback_type, module)
run = neptune.init_run(project=self.config.project)
return callback_class(run)
Loading

0 comments on commit 439e445

Please sign in to comment.