Skip to content

Commit

Permalink
feat(llms): option to set a default LLM factory
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 14, 2024
1 parent cc33d4c commit 3a6d5ac
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 17 deletions.
3 changes: 3 additions & 0 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@ class CoreConfig(BaseModel):
# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm"
default_llm_factory: str | None = None


core_config = get_config_instance(CoreConfig, subproject="core")
60 changes: 60 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM


def get_llm_from_factory(factory_path: str) -> LLM:
    """
    Build an LLM instance by invoking a user-provided factory function.

    Args:
        factory_path (str): Dotted path to the factory function,
            e.g. "my_project.llms.get_llm".

    Returns:
        LLM: The instance produced by calling the factory.
    """
    # Split into "package.module" and the attribute that names the factory.
    module_path, factory_name = factory_path.rsplit(".", 1)
    factory = getattr(importlib.import_module(module_path), factory_name)
    return factory()


def has_default_llm() -> bool:
    """
    Determine whether a default LLM factory has been configured.

    Returns:
        bool: True when `default_llm_factory` is set in the core configuration.
    """
    factory_path = core_config.default_llm_factory
    return factory_path is not None


def get_default_llm() -> LLM:
    """
    Create the default LLM instance from the factory configured in core settings.

    Returns:
        LLM: The default LLM instance.

    Raises:
        ValueError: If no default LLM factory is configured.
    """
    factory_path = core_config.default_llm_factory

    # Fail fast with a clear error when no factory was configured.
    if factory_path is None:
        raise ValueError("Default LLM factory is not set")

    return get_llm_from_factory(factory_path)


def simple_litellm_factory() -> LLM:
    """
    Create a LiteLLM instance with the default model and default options.

    Assumes the required API key is already present in the environment.

    Returns:
        LLM: A LiteLLM instance built with default settings.
    """
    return LiteLLM()
35 changes: 18 additions & 17 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
from ragbits.core.llms import LLM
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery

Expand All @@ -30,14 +30,12 @@ class PromptState:
Attributes:
prompts (list): A list containing discovered prompts.
rendered_prompt (Prompt): The most recently rendered Prompt instance.
llm_model_name (str): The name of the selected LLM model.
llm_api_key (str | None): The API key for the chosen LLM model.
llm (LLM): The LLM instance to be used for generating responses.
"""

prompts: list = field(default_factory=list)
rendered_prompt: Prompt | None = None
llm_model_name: str | None = None
llm_api_key: str | None = None
llm: LLM | None = None


def render_prompt(index: int, system_prompt: str, user_prompt: str, state: PromptState, *args: Any) -> PromptState:
Expand Down Expand Up @@ -99,15 +97,17 @@ def send_prompt_to_llm(state: PromptState) -> str:
Returns:
str: The response generated by the LLM.
"""
assert state.llm_model_name is not None, "LLM model name is not set."
llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key)
Raises:
ValueError: If the LLM model is not configured.
"""
assert state.rendered_prompt is not None, "Prompt has not been rendered yet."

if state.llm is None:
raise ValueError("LLM model is not configured.")

try:
response = asyncio.run(
llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions())
)
response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))
except Exception as e: # pylint: disable=broad-except
response = str(e)

Expand Down Expand Up @@ -136,7 +136,8 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:


def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand All @@ -163,9 +164,8 @@ def lab_app( # pylint: disable=missing-param-doc
with gr.Blocks() as gr_app:
prompts_state = gr.State(
PromptState(
llm_model_name=llm_model,
llm_api_key=llm_api_key,
prompts=list(prompts),
llm=get_llm_from_factory(llm_factory) if llm_factory else None,
)
)

Expand Down Expand Up @@ -220,14 +220,15 @@ def show_split(index: int, state: gr.State) -> None:
)
gr.Textbox(label="Rendered User Prompt", value=rendered_user_prompt, interactive=False)

llm_enabled = state.llm_model_name is not None
llm_enabled = state.llm is not None
prompt_ready = state.rendered_prompt is not None
llm_request_button = gr.Button(
value="Send to LLM",
interactive=llm_enabled and prompt_ready,
)
gr.Markdown(
"To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled
"To enable this button set an LLM factory function in CLI options or your pyproject.toml",
visible=not llm_enabled,
)
gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready)
llm_prompt_response = gr.Textbox(lines=10, label="LLM response")
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ragbits.core.config import core_config
from ragbits.core.llms.factory import get_default_llm
from ragbits.core.llms.litellm import LiteLLM


def test_get_default_llm(monkeypatch):
"""
Test the get_default_llm function.
"""
monkeypatch.setattr(
core_config, "default_llm_factory", "tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory"
)

llm = get_default_llm()

Check failure on line 14 in packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_get_default_llm

ModuleNotFoundError: No module named 'tests.unit.llms'
Raw output
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x7f28662192d0>

    def test_get_default_llm(monkeypatch):
        """
        Test the get_llm_from_factory function.
        """
        monkeypatch.setattr(
            core_config, "default_llm_factory", "tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory"
        )
    
>       llm = get_default_llm()

packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py:14: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
packages/ragbits-core/src/ragbits/core/llms/factory.py:49: in get_default_llm
    return get_llm_from_factory(factory)
packages/ragbits-core/src/ragbits/core/llms/factory.py:19: in get_llm_from_factory
    module = importlib.import_module(module_name)
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:992: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:241: in _call_with_frames_removed
    ???
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:992: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:241: in _call_with_frames_removed
    ???
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

name = 'tests.unit.llms', import_ = <function _gcd_import at 0x7f2925d63400>

>   ???
E   ModuleNotFoundError: No module named 'tests.unit.llms'

<frozen importlib._bootstrap>:1004: ModuleNotFoundError
assert isinstance(llm, LiteLLM)
assert llm.model_name == "mock_model"
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.llms.litellm import LiteLLM


def mock_llm_factory() -> LiteLLM:
    """
    Test-only factory: build a LiteLLM configured with a mock model name.

    Returns:
        LiteLLM: A LiteLLM instance whose model_name is "mock_model".
    """
    return LiteLLM(model_name="mock_model")


def test_get_llm_from_factory():
"""
Test the get_llm_from_factory function.
"""
llm = get_llm_from_factory("tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory")

Check failure on line 19 in packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_get_llm_from_factory

ModuleNotFoundError: No module named 'tests.unit.llms'
Raw output
def test_get_llm_from_factory():
        """
        Test the get_llm_from_factory function.
        """
>       llm = get_llm_from_factory("tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory")

packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py:19: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
packages/ragbits-core/src/ragbits/core/llms/factory.py:19: in get_llm_from_factory
    module = importlib.import_module(module_name)
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:992: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:241: in _call_with_frames_removed
    ???
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:992: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:241: in _call_with_frames_removed
    ???
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

name = 'tests.unit.llms', import_ = <function _gcd_import at 0x7f2925d63400>

>   ???
E   ModuleNotFoundError: No module named 'tests.unit.llms'

<frozen importlib._bootstrap>:1004: ModuleNotFoundError

assert isinstance(llm, LiteLLM)
assert llm.model_name == "mock_model"
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from ragbits.core.config import core_config
from ragbits.core.llms.factory import has_default_llm


def test_has_default_llm(monkeypatch):
    """
    Verify has_default_llm returns False when no default LLM factory is configured.
    """
    # Simulate an unset factory in the core configuration.
    monkeypatch.setattr(core_config, "default_llm_factory", None)

    assert has_default_llm() is False


# NOTE(review): misleading name — this test asserts the True case (factory IS set),
# while its sibling test_has_default_llm covers the False case. Consider renaming
# to test_has_default_llm_true in a follow-up (kept as-is here to preserve the
# collected test id).
def test_has_default_llm_false(monkeypatch):
    """
    Test the has_default_llm function when the default LLM factory is set.
    """
    # Any non-None dotted path makes has_default_llm report True.
    monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm")

    assert has_default_llm() is True

0 comments on commit 3a6d5ac

Please sign in to comment.