Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llms): default values for default_llm_factories #209

Merged
merged 7 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ class CoreConfig(BaseModel):
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a functions that returns LLM objects, e.g. "my_project.llms.get_llm"
default_llm_factories: dict[LLMType, str | None] = {
LLMType.TEXT: None,
LLMType.VISION: None,
LLMType.STRUCTURED_OUTPUT: None,
default_llm_factories: dict[LLMType, str] = {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory",
}


Expand Down
34 changes: 13 additions & 21 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,14 @@ def get_llm_from_factory(factory_path: str) -> LLM:
factory_path (str): The path to the factory function.

Returns:
LLM: An instance of the LLM.
LLM: An instance of the LLM class.
"""
module_name, function_name = factory_path.rsplit(".", 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
return function()


def has_default_llm(llm_type: LLMType = LLMType.TEXT) -> bool:
    """
    Check if the default LLM factory is set in the configuration.

    Returns:
        bool: Whether the default LLM factory is set.
    """
    # A missing key and an explicit None both mean "no factory configured".
    return core_config.default_llm_factories.get(llm_type) is not None


def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
Expand All @@ -43,15 +32,8 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
Returns:
LLM: An instance of the default LLM.

Raises:
ValueError: If the default LLM factory is not set or expected llm type is not defined in config
"""
if llm_type not in core_config.default_llm_factories:
raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.")
factory = core_config.default_llm_factories[llm_type]
if factory is None:
raise ValueError("Default LLM factory is not set")

return get_llm_from_factory(factory)


Expand All @@ -61,7 +43,7 @@ def simple_litellm_factory() -> LLM:
default options, and assumes that the API key is set in the environment.

Returns:
LLM: An instance of the LiteLLM.
LLM: An instance of the LiteLLM class.
"""
return LiteLLM()

Expand All @@ -72,6 +54,16 @@ def simple_litellm_vision_factory() -> LLM:
default options, and assumes that the API key is set in the environment.

Returns:
LLM: An instance of the LiteLLM.
LLM: An instance of the LiteLLM class.
"""
return LiteLLM(model_name="gpt-4o-mini")


def simple_litellm_structured_output_factory() -> LLM:
    """
    A basic LLM factory that creates an LiteLLM instance with the support for structured output.

    Returns:
        LLM: An instance of the LiteLLM class.
    """
    # Pin the model snapshot that supports structured output and enable it.
    llm = LiteLLM(model_name="gpt-4o-mini-2024-07-18", use_structured_output=True)
    return llm

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from pathlib import Path

from PIL import Image
Expand All @@ -8,8 +7,7 @@
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.factory import get_default_llm, has_default_llm
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.llms.factory import get_default_llm
from ragbits.core.prompt import Prompt
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
Expand All @@ -22,7 +20,6 @@
)

DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image."
DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini"


class _ImagePrompt(Prompt):
Expand Down Expand Up @@ -96,15 +93,7 @@ async def _to_image_element(
img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
prompt = _ImagePrompt(_ImagePromptInput(images=[img_bytes]))
if self.image_describer is None:
if self._llm is not None:
llm_to_use = self._llm
elif has_default_llm(LLMType.VISION):
llm_to_use = get_default_llm(LLMType.VISION)
else:
warnings.warn(
f"Vision LLM was not provided, setting default option to {DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL}"
)
llm_to_use = LiteLLM(DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL)
llm_to_use = self._llm if self._llm is not None else get_default_llm(LLMType.VISION)
self.image_describer = ImageDescriber(llm_to_use)
image_description = await self.image_describer.get_image_description(prompt=prompt)
return ImageElement(
Expand Down
Loading