Skip to content

Commit

Permalink
feat(llms): default values for default_llm_factories (#209)
Browse files Browse the repository at this point in the history
Co-authored-by: Ludwik Trammer <[email protected]>
  • Loading branch information
kdziedzic68 and ludwiktrammer authored Nov 28, 2024
1 parent 9523e89 commit 0c4ef7b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 61 deletions.
8 changes: 4 additions & 4 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ class CoreConfig(BaseModel):
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a functions that returns LLM objects, e.g. "my_project.llms.get_llm"
default_llm_factories: dict[LLMType, str | None] = {
LLMType.TEXT: None,
LLMType.VISION: None,
LLMType.STRUCTURED_OUTPUT: None,
default_llm_factories: dict[LLMType, str] = {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory",
}


Expand Down
34 changes: 13 additions & 21 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,14 @@ def get_llm_from_factory(factory_path: str) -> LLM:
factory_path (str): The path to the factory function.
Returns:
LLM: An instance of the LLM.
LLM: An instance of the LLM class.
"""
module_name, function_name = factory_path.rsplit(".", 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
return function()


def has_default_llm(llm_type: LLMType = LLMType.TEXT) -> bool:
    """
    Tell whether a default LLM factory has been configured for the given LLM type.

    Args:
        llm_type: Which kind of LLM to look up (defaults to ``LLMType.TEXT``).

    Returns:
        bool: Whether the default LLM factory is set.
    """
    # A missing key and an explicit None both mean "no factory configured".
    configured_factory = core_config.default_llm_factories.get(llm_type)
    return configured_factory is not None


def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
Expand All @@ -43,15 +32,8 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
Returns:
LLM: An instance of the default LLM.
Raises:
ValueError: If the default LLM factory is not set or expected llm type is not defined in config
"""
if llm_type not in core_config.default_llm_factories:
raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.")
factory = core_config.default_llm_factories[llm_type]
if factory is None:
raise ValueError("Default LLM factory is not set")

return get_llm_from_factory(factory)


Expand All @@ -61,7 +43,7 @@ def simple_litellm_factory() -> LLM:
default options, and assumes that the API key is set in the environment.
Returns:
LLM: An instance of the LiteLLM.
LLM: An instance of the LiteLLM class.
"""
return LiteLLM()

Expand All @@ -72,6 +54,16 @@ def simple_litellm_vision_factory() -> LLM:
default options, and assumes that the API key is set in the environment.
Returns:
LLM: An instance of the LiteLLM.
LLM: An instance of the LiteLLM class.
"""
return LiteLLM(model_name="gpt-4o-mini")


def simple_litellm_structured_output_factory() -> LLM:
    """
    A basic LLM factory that creates a LiteLLM instance with support for structured output.

    Returns:
        LLM: An instance of the LiteLLM class.
    """
    return LiteLLM(
        model_name="gpt-4o-mini-2024-07-18",
        use_structured_output=True,
    )

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from pathlib import Path

from PIL import Image
Expand All @@ -8,8 +7,7 @@
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.factory import get_default_llm, has_default_llm
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.llms.factory import get_default_llm
from ragbits.core.prompt import Prompt
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
Expand All @@ -22,7 +20,6 @@
)

DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image."
DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini"


class _ImagePrompt(Prompt):
Expand Down Expand Up @@ -96,15 +93,7 @@ async def _to_image_element(
img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
prompt = _ImagePrompt(_ImagePromptInput(images=[img_bytes]))
if self.image_describer is None:
if self._llm is not None:
llm_to_use = self._llm
elif has_default_llm(LLMType.VISION):
llm_to_use = get_default_llm(LLMType.VISION)
else:
warnings.warn(
f"Vision LLM was not provided, setting default option to {DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL}"
)
llm_to_use = LiteLLM(DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL)
llm_to_use = self._llm if self._llm is not None else get_default_llm(LLMType.VISION)
self.image_describer = ImageDescriber(llm_to_use)
image_description = await self.image_describer.get_image_description(prompt=prompt)
return ImageElement(
Expand Down

0 comments on commit 0c4ef7b

Please sign in to comment.