Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): user should be able to configure default types of default… #153

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rich import print as pprint

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.prompt.prompt import Prompt


Expand Down Expand Up @@ -38,7 +39,7 @@ def register(app: typer.Typer) -> None:
@prompts_app.command()
def lab(
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down Expand Up @@ -73,15 +74,16 @@ def render(prompt_path: str, payload: str | None = None) -> None:

@prompts_app.command(name="exec")
def execute(
prompt_path: str, payload: str | None = None, llm_factory: str | None = core_config.default_llm_factory
prompt_path: str,
payload: str | None = None,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Executes a prompt using the specified prompt class and LLM factory.

Raises:
ValueError: If `llm_factory` is not provided.
"""

from ragbits.core.llms.factory import get_llm_from_factory

prompt = _render(prompt_path=prompt_path, payload=payload)
Expand Down
9 changes: 7 additions & 2 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel

from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance


Expand All @@ -11,8 +12,12 @@ class CoreConfig(BaseModel):
# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm"
default_llm_factory: str | None = None
# Paths to functions that return LLM objects, e.g. "my_project.llms.get_llm"
default_llm_factories: dict[LLMType, str | None] = {
LLMType.TEXT: None,
LLMType.VISION: None,
LLMType.STRUCTURED_OUTPUT: None,
}


core_config = get_config_instance(CoreConfig, subproject="core")
11 changes: 11 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Generic, cast, overload
Expand All @@ -7,6 +8,16 @@
from .clients.base import LLMClient, LLMClientOptions, LLMOptions


class LLMType(enum.Enum):
    """
    Enumeration of LLM categories, grouped by the model features they support.

    Each member can be used as a key when registering a default LLM factory
    for that capability.
    """

    # Plain text generation.
    TEXT = "text"
    # Image-understanding (multimodal) models.
    VISION = "vision"
    # Models capable of emitting structured output.
    STRUCTURED_OUTPUT = "structured_output"


class LLM(Generic[LLMClientOptions], ABC):
"""
Abstract class for interaction with Large Language Model.
Expand Down
29 changes: 23 additions & 6 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.litellm import LiteLLM


Expand All @@ -21,28 +21,34 @@ def get_llm_from_factory(factory_path: str) -> LLM:
return function()


def has_default_llm(llm_type: LLMType = LLMType.TEXT) -> bool:
    """
    Check if the default LLM factory is set in the configuration.

    Args:
        llm_type: The LLM type to look up, defaults to text.

    Returns:
        bool: Whether the default LLM factory is set.
    """
    return core_config.default_llm_factories.get(llm_type) is not None


def get_default_llm() -> LLM:
def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
specified in the configuration.

Args:
llm_type: type of the LLM to get, defaults to text

Returns:
LLM: An instance of the default LLM.

Raises:
ValueError: If the default LLM factory is not set.
ValueError: If the default LLM factory is not set or the expected LLM type is not defined in the config
"""
factory = core_config.default_llm_factory
if llm_type not in core_config.default_llm_factories:
raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.")
factory = core_config.default_llm_factories[llm_type]
if factory is None:
raise ValueError("Default LLM factory is not set")

Expand All @@ -58,3 +64,14 @@ def simple_litellm_factory() -> LLM:
LLM: An instance of the LiteLLM.
"""
return LiteLLM()


def simple_litellm_vision_factory() -> LLM:
    """
    A basic LLM factory that creates a LiteLLM instance backed by a
    vision-capable model, using default options; assumes that the API key
    is set in the environment.

    Returns:
        LLM: An instance of the LiteLLM.
    """
    return LiteLLM(model_name="gpt-4o-mini")
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ragbits.core.config import core_config
from ragbits.core.llms import LLM
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery
Expand Down Expand Up @@ -137,7 +138,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:

def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down
16 changes: 15 additions & 1 deletion packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import enum
from pathlib import Path
from typing import Any, TypeVar

import tomli
from pydantic import BaseModel

from ragbits.core.llms.base import LLMType


def find_pyproject(current_dir: Path | None = None) -> Path:
"""
Expand Down Expand Up @@ -64,7 +67,7 @@ def get_config_instance(
model: type[ConfigModelT], subproject: str | None = None, current_dir: Path | None = None
) -> ConfigModelT:
"""
Creates an instace of pydantic model loaded with the configuration from pyproject.toml.
Creates an instance of pydantic model loaded with the configuration from pyproject.toml.

Args:
model (Type[BaseModel]): The pydantic model to instantiate.
Expand All @@ -81,4 +84,15 @@ def get_config_instance(
config = get_ragbits_config(current_dir)
if subproject:
config = config.get(subproject, {})
if "default_llm_factories" in config:
config["default_llm_factories"] = {
_resolve_enum_member(k): v for k, v in config["default_llm_factories"].items()
}
return model(**config)


def _resolve_enum_member(enum_string: str) -> enum.Enum:
    """
    Map a raw string key from pyproject.toml to the corresponding LLMType member.

    Args:
        enum_string: The value of the enum member, e.g. "text" or "vision".

    Returns:
        enum.Enum: The matching LLMType member.

    Raises:
        ValueError: If the string does not match any LLMType value.
    """
    try:
        return LLMType(enum_string)
    except ValueError as err:
        # Include the offending key to make the misconfiguration easy to locate;
        # the original message is kept as a prefix for backward compatibility.
        raise ValueError(
            f"Unsupported LLMType value provided in default_llm_factories in pyproject.toml: {enum_string!r}"
        ) from err
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

konrad-czarnota-ds marked this conversation as resolved.
Show resolved Hide resolved
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import get_default_llm
from ragbits.core.llms.litellm import LiteLLM

Expand All @@ -9,7 +10,9 @@ def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the get_llm_from_factory function.
"""
monkeypatch.setattr(core_config, "default_llm_factory", "factory.test_get_llm_from_factory.mock_llm_factory")
monkeypatch.setattr(
core_config, "default_llm_factories", {LLMType.TEXT: "factory.test_get_llm_from_factory.mock_llm_factory"}
)

llm = get_default_llm()
assert isinstance(llm, LiteLLM)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import pytest

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import has_default_llm


def test_has_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the has_default_llm function when the default LLM factory is not set.
"""
monkeypatch.setattr(core_config, "default_llm_factory", None)
monkeypatch.setattr(core_config, "default_llm_factories", {})

assert has_default_llm() is False

Expand All @@ -17,6 +18,6 @@ def test_has_default_llm_false(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the has_default_llm function when the default LLM factory is set.
"""
monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm")
monkeypatch.setattr(core_config, "default_llm_factories", {LLMType.TEXT: "my_project.llms.get_llm"})

assert has_default_llm() is True
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pathlib import Path

import pytest
from pydantic import BaseModel

from ragbits.core.config import CoreConfig
from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance

projects_dir = Path(__file__).parent / "testprojects"
Expand Down Expand Up @@ -66,3 +69,30 @@ def test_get_config_instance_no_file():
)

assert config == OptionalHappyProjectConfig()


def test_get_config_instance_factories():
    """Test that default LLMs are loaded correctly"""
    loaded = get_config_instance(
        CoreConfig,
        subproject="core",
        current_dir=projects_dir / "factory_project",
    )

    expected = {
        LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
        LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
        LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_vision_factory",
    }
    assert loaded.default_llm_factories == expected


def test_get_config_instance_bad_factories():
    """Test that non-existing LLM defined in pyproject raises error"""
    expected_fragment = "Unsupported LLMType value provided in default_llm_factories in pyproject.toml"

    with pytest.raises(ValueError) as exc_info:
        get_config_instance(
            CoreConfig,
            subproject="core",
            current_dir=projects_dir / "bad_factory_project",
        )

    assert expected_fragment in str(exc_info.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[project]
name = "bad_factory_project"

[tool.ragbits.core.default_llm_factories]
non_existing = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[project]
name = "factory_project"

[tool.ragbits.core.default_llm_factories]
text = "ragbits.core.llms.factory.simple_litellm_factory"
vision = "ragbits.core.llms.factory.simple_litellm_vision_factory"
structured_output = "ragbits.core.llms.factory.simple_litellm_vision_factory"
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import warnings
from pathlib import Path

from PIL import Image
from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.factory import get_default_llm, has_default_llm
from ragbits.core.llms.litellm import LiteLLM
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
Expand All @@ -17,7 +19,7 @@
to_text_element,
)

DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini"
DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini"


class UnstructuredImageProvider(UnstructuredDefaultProvider):
Expand Down Expand Up @@ -53,7 +55,8 @@ def __init__(
llm: llm to use
"""
super().__init__(partition_kwargs, chunking_kwargs, api_key, api_server, use_api)
self.image_summarizer = ImageDescriber(llm or LiteLLM(DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL))
self.image_describer: ImageDescriber | None = None
self._llm = llm

async def _chunk_and_convert(
self, elements: list[UnstructuredElement], document_meta: DocumentMeta, document_path: Path
Expand All @@ -79,7 +82,18 @@ async def _to_image_element(
)

img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
image_description = await self.image_summarizer.get_image_description(img_bytes)
if self.image_describer is None:
if self._llm is not None:
llm_to_use = self._llm
elif has_default_llm(LLMType.VISION):
llm_to_use = get_default_llm(LLMType.VISION)
else:
warnings.warn(
f"Vision LLM was not provided, setting default option to {DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL}"
)
llm_to_use = LiteLLM(DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL)
self.image_describer = ImageDescriber(llm_to_use)
image_description = await self.image_describer.get_image_description(img_bytes)
return ImageElement(
description=image_description,
ocr_extracted_text=element.text,
Expand Down
Loading