diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py index 63f69e94..a329de93 100644 --- a/packages/ragbits-core/src/ragbits/core/config.py +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -13,10 +13,10 @@ class CoreConfig(BaseModel): prompt_path_pattern: str = "**/prompt_*.py" # Path to a functions that returns LLM objects, e.g. "my_project.llms.get_llm" - default_llm_factories: dict[LLMType, str | None] = { - LLMType.TEXT: None, - LLMType.VISION: None, - LLMType.STRUCTURED_OUTPUT: None, + default_llm_factories: dict[LLMType, str] = { + LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory", + LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory", + LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_structured_output_factory", } diff --git a/packages/ragbits-core/src/ragbits/core/llms/factory.py b/packages/ragbits-core/src/ragbits/core/llms/factory.py index 03b6db09..9405a976 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/factory.py +++ b/packages/ragbits-core/src/ragbits/core/llms/factory.py @@ -13,7 +13,7 @@ def get_llm_from_factory(factory_path: str) -> LLM: factory_path (str): The path to the factory function. Returns: - LLM: An instance of the LLM. + LLM: An instance of the LLM class. """ module_name, function_name = factory_path.rsplit(".", 1) module = importlib.import_module(module_name) @@ -21,17 +21,6 @@ def get_llm_from_factory(factory_path: str) -> LLM: return function() -def has_default_llm(llm_type: LLMType = LLMType.TEXT) -> bool: - """ - Check if the default LLM factory is set in the configuration. - - Returns: - bool: Whether the default LLM factory is set. - """ - default_factory = core_config.default_llm_factories.get(llm_type, None) - return default_factory is not None - - def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM: """ Get an instance of the default LLM using the factory function @@ -43,15 +32,8 @@ def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM: Returns: LLM: An instance of the default LLM. - Raises: - ValueError: If the default LLM factory is not set or expected llm type is not defined in config """ - if llm_type not in core_config.default_llm_factories: - raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.") factory = core_config.default_llm_factories[llm_type] - if factory is None: - raise ValueError("Default LLM factory is not set") - return get_llm_from_factory(factory) @@ -61,7 +43,7 @@ def simple_litellm_factory() -> LLM: default options, and assumes that the API key is set in the environment. Returns: - LLM: An instance of the LiteLLM. + LLM: An instance of the LiteLLM class. """ return LiteLLM() @@ -72,6 +54,16 @@ def simple_litellm_vision_factory() -> LLM: default options, and assumes that the API key is set in the environment. Returns: - LLM: An instance of the LiteLLM. + LLM: An instance of the LiteLLM class. """ return LiteLLM(model_name="gpt-4o-mini") + + +def simple_litellm_structured_output_factory() -> LLM: + """ + A basic LLM factory that creates an LiteLLM instance with the support for structured output. + + Returns: + LLM: An instance of the LiteLLM class. + """ + return LiteLLM(model_name="gpt-4o-mini-2024-07-18", use_structured_output=True) diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py deleted file mode 100644 index 152da8ba..00000000 --- a/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest - -from ragbits.core.config import core_config -from ragbits.core.llms.base import LLMType -from ragbits.core.llms.factory import has_default_llm - - -def test_has_default_llm(monkeypatch: pytest.MonkeyPatch) -> None: - """ - Test the has_default_llm function when the default LLM factory is not set. - """ - monkeypatch.setattr(core_config, "default_llm_factories", {}) - - assert has_default_llm() is False - - -def test_has_default_llm_false(monkeypatch: pytest.MonkeyPatch) -> None: - """ - Test the has_default_llm function when the default LLM factory is set. - """ - monkeypatch.setattr(core_config, "default_llm_factories", {LLMType.TEXT: "my_project.llms.get_llm"}) - - assert has_default_llm() is True diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py index 674a1584..f0f6c269 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py @@ -1,4 +1,3 @@ -import warnings from pathlib import Path from PIL import Image @@ -8,8 +7,7 @@ from unstructured.documents.elements import ElementType from ragbits.core.llms.base import LLM, LLMType -from ragbits.core.llms.factory import get_default_llm, has_default_llm -from ragbits.core.llms.litellm import LiteLLM +from ragbits.core.llms.factory import get_default_llm from ragbits.core.prompt import Prompt from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element, ImageElement @@ -22,7 +20,6 @@ ) DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image." -DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini" class _ImagePrompt(Prompt): @@ -96,15 +93,7 @@ async def _to_image_element( img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y) prompt = _ImagePrompt(_ImagePromptInput(images=[img_bytes])) if self.image_describer is None: - if self._llm is not None: - llm_to_use = self._llm - elif has_default_llm(LLMType.VISION): - llm_to_use = get_default_llm(LLMType.VISION) - else: - warnings.warn( - f"Vision LLM was not provided, setting default option to {DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL}" - ) - llm_to_use = LiteLLM(DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL) + llm_to_use = self._llm if self._llm is not None else get_default_llm(LLMType.VISION) self.image_describer = ImageDescriber(llm_to_use) image_description = await self.image_describer.get_image_description(prompt=prompt) return ImageElement(