Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat(llms): option to set a default LLM factory #101

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@ class CoreConfig(BaseModel):
# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm"
default_llm_factory: str | None = None


core_config = get_config_instance(CoreConfig, subproject="core")
60 changes: 60 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM


def get_llm_from_factory(factory_path: str) -> LLM:
    """
    Get an instance of an LLM using a factory function specified by the user.

    Args:
        factory_path (str): Fully qualified dotted path to the factory function,
            e.g. "my_project.llms.get_llm".

    Returns:
        LLM: An instance of the LLM produced by calling the factory.

    Raises:
        ValueError: If `factory_path` has no module part (no dot separator).
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the function is not defined in the module.
    """
    # Split "package.module.function" into module path and function name.
    module_name, _, function_name = factory_path.rpartition(".")
    if not module_name:
        # rsplit-unpacking would raise a cryptic ValueError here; fail with a
        # message that tells the user what a valid factory path looks like.
        raise ValueError(f"Invalid factory path {factory_path!r}: expected '<module>.<function>'")

    module = importlib.import_module(module_name)
    function = getattr(module, function_name)
    return function()


def has_default_llm() -> bool:
    """
    Tell whether a default LLM factory has been configured.

    Returns:
        bool: True if `default_llm_factory` is set in the core configuration,
            False otherwise.
    """
    factory_path = core_config.default_llm_factory
    return factory_path is not None


def get_default_llm() -> LLM:
    """
    Build the default LLM from the factory function configured in core config.

    Returns:
        LLM: An instance of the default LLM.

    Raises:
        ValueError: If the default LLM factory is not set.
    """
    factory_path = core_config.default_llm_factory

    # No configured factory means there is nothing sensible to construct.
    if factory_path is None:
        raise ValueError("Default LLM factory is not set")
    return get_llm_from_factory(factory_path)


def simple_litellm_factory() -> LLM:
    """
    Minimal LLM factory: builds a LiteLLM with the default model and default
    options, relying on the API key being available in the environment.

    Returns:
        LLM: An instance of the LiteLLM.
    """
    llm = LiteLLM()
    return llm
35 changes: 18 additions & 17 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
from ragbits.core.llms import LLM
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery

Expand All @@ -30,14 +30,12 @@ class PromptState:
Attributes:
prompts (list): A list containing discovered prompts.
rendered_prompt (Prompt): The most recently rendered Prompt instance.
llm_model_name (str): The name of the selected LLM model.
llm_api_key (str | None): The API key for the chosen LLM model.
llm (LLM): The LLM instance to be used for generating responses.
"""

prompts: list = field(default_factory=list)
rendered_prompt: Prompt | None = None
llm_model_name: str | None = None
llm_api_key: str | None = None
llm: LLM | None = None


def render_prompt(index: int, system_prompt: str, user_prompt: str, state: PromptState, *args: Any) -> PromptState:
Expand Down Expand Up @@ -99,15 +97,17 @@ def send_prompt_to_llm(state: PromptState) -> str:

Returns:
str: The response generated by the LLM.
"""
assert state.llm_model_name is not None, "LLM model name is not set."
llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key)

Raises:
ValueError: If the LLM model is not configured.
"""
assert state.rendered_prompt is not None, "Prompt has not been rendered yet."

if state.llm is None:
raise ValueError("LLM model is not configured.")

try:
response = asyncio.run(
llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions())
)
response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))
except Exception as e: # pylint: disable=broad-except
response = str(e)

Expand Down Expand Up @@ -136,7 +136,8 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:


def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand All @@ -163,9 +164,8 @@ def lab_app( # pylint: disable=missing-param-doc
with gr.Blocks() as gr_app:
prompts_state = gr.State(
PromptState(
llm_model_name=llm_model,
llm_api_key=llm_api_key,
prompts=list(prompts),
llm=get_llm_from_factory(llm_factory) if llm_factory else None,
)
)

Expand Down Expand Up @@ -220,14 +220,15 @@ def show_split(index: int, state: gr.State) -> None:
)
gr.Textbox(label="Rendered User Prompt", value=rendered_user_prompt, interactive=False)

llm_enabled = state.llm_model_name is not None
llm_enabled = state.llm is not None
prompt_ready = state.rendered_prompt is not None
llm_request_button = gr.Button(
value="Send to LLM",
interactive=llm_enabled and prompt_ready,
)
gr.Markdown(
"To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled
"To enable this button set an LLM factory function in CLI options or your pyproject.toml",
visible=not llm_enabled,
)
gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready)
llm_prompt_response = gr.Textbox(lines=10, label="LLM response")
Expand Down
5 changes: 5 additions & 0 deletions packages/ragbits-core/tests/unit/llms/factory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys
from pathlib import Path

# Put the parent "llms" test directory on sys.path so that sibling test
# modules can be imported as "factory.<module>".
_llms_dir = Path(__file__).parent.parent
sys.path.append(str(_llms_dir))
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ragbits.core.config import core_config
from ragbits.core.llms.factory import get_default_llm
from ragbits.core.llms.litellm import LiteLLM


def test_get_default_llm(monkeypatch):
    """
    Test that get_default_llm builds the LLM via the factory path configured
    in core_config.
    """
    # Point the config at the mock factory defined in
    # test_get_llm_from_factory.py (importable as "factory...." thanks to the
    # sys.path tweak in this package's __init__.py).
    monkeypatch.setattr(core_config, "default_llm_factory", "factory.test_get_llm_from_factory.mock_llm_factory")

    llm = get_default_llm()
    assert isinstance(llm, LiteLLM)
    assert llm.model_name == "mock_model"
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.llms.litellm import LiteLLM


def mock_llm_factory() -> LiteLLM:
    """
    Factory used by the tests: builds a LiteLLM configured with a mock
    model name.

    Returns:
        LiteLLM: An instance of the LiteLLM.
    """
    llm = LiteLLM(model_name="mock_model")
    return llm


def test_get_llm_from_factory():
    """
    get_llm_from_factory should import and call the factory at the given
    dotted path, returning its product.
    """
    factory_path = "factory.test_get_llm_from_factory.mock_llm_factory"
    llm = get_llm_from_factory(factory_path)

    assert isinstance(llm, LiteLLM)
    assert llm.model_name == "mock_model"
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from ragbits.core.config import core_config
from ragbits.core.llms.factory import has_default_llm


def test_has_default_llm(monkeypatch):
    """
    has_default_llm should report False when no default LLM factory is
    configured.
    """
    monkeypatch.setattr(core_config, "default_llm_factory", None)

    assert has_default_llm() is False


def test_has_default_llm_false(monkeypatch):
    """
    Test the has_default_llm function when the default LLM factory is set.
    """
    # NOTE(review): despite the "_false" suffix, this covers the factory-is-set
    # (True) case — consider renaming to test_has_default_llm_true.
    monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm")

    assert has_default_llm() is True
14 changes: 2 additions & 12 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading