Skip to content

Commit

Permalink
suport images in prompts
Browse files Browse the repository at this point in the history
fix without touching precise mypy requirements

add auto detection images from fields

remove raise from generate_raw docstring

remove add user message

pr comments

use prompt structure in unstructured image describer

pr comments

remove redundancies

remove redundant argument from llm generate

fix import

remove redundant type igonre

do not assume role user for last message in formating chat for llm

image fileds definition

fix import
  • Loading branch information
kdziedzic68 committed Oct 28, 2024
1 parent 2d92dae commit d49944d
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 30 deletions.
10 changes: 8 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings as wrngs
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Generic, Optional, Type, cast, overload

from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, OutputT
from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT

from .clients.base import LLMClient, LLMClientOptions, LLMOptions

Expand Down Expand Up @@ -70,7 +71,7 @@ async def generate_raw(
options = (self.default_options | options) if options else self.default_options

response = await self.client.call(
conversation=prompt.chat,
conversation=self._format_chat_for_llm(prompt),
options=options,
json_mode=prompt.json_mode,
output_schema=prompt.output_schema(),
Expand Down Expand Up @@ -119,3 +120,8 @@ async def generate(
return prompt.parse_response(response)

return cast(OutputT, response)

def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
if prompt.list_images():
wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}")
return prompt.chat
28 changes: 27 additions & 1 deletion packages/ragbits-core/src/ragbits/core/llms/litellm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import warnings
from functools import cached_property
from typing import Optional

Expand All @@ -8,7 +10,7 @@
except ImportError:
HAS_LITELLM = False

from ragbits.core.prompt.base import BasePrompt
from ragbits.core.prompt.base import BasePrompt, ChatFormat

from .base import LLM
from .clients.litellm import LiteLLMClient, LiteLLMOptions
Expand Down Expand Up @@ -83,3 +85,27 @@ def count_tokens(self, prompt: BasePrompt) -> int:
Number of tokens in the prompt.
"""
return sum(litellm.token_counter(model=self.model_name, text=message["content"]) for message in prompt.chat)

def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat:
images = prompt.list_images()
chat = prompt.chat
if images:
if not litellm.supports_vision(self.model_name):
warnings.warn(
message=f"Model {self.model_name} does not support vision. Image input would be ignored",
category=UserWarning,
)
return chat
user_message_content = [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64.b64encode(im).decode('utf-8')}"},
}
for im in images
]
last_message = chat[-1]
if last_message["role"] == "user":
user_message_content = [{"type": "text", "text": last_message["content"]}] + user_message_content
chat = chat[:-1]
chat.append({"role": "user", "content": user_message_content})
return chat
13 changes: 11 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABCMeta, abstractmethod
from typing import Dict, Generic, Optional, Type
from typing import Any, Dict, Generic, Optional, Type

from pydantic import BaseModel
from typing_extensions import TypeVar

ChatFormat = list[dict[str, str]]
ChatFormat = list[dict[str, Any]]
OutputT = TypeVar("OutputT", default=str)


Expand Down Expand Up @@ -37,6 +37,15 @@ def output_schema(self) -> Optional[Dict | Type[BaseModel]]:
"""
return None

def list_images(self) -> list[bytes]:
"""
Returns the schema of the list of images compatible with llm apis
Returns:
list of dictionaries
"""

return []


class BasePromptWithParser(Generic[OutputT], BasePrompt, metaclass=ABCMeta):
"""
Expand Down
21 changes: 21 additions & 0 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
output_type: Type[OutputT]
system_prompt_template: Optional[Template]
user_prompt_template: Template
image_input_fields: Optional[list[str]] = None

@classmethod
def _get_io_types(cls) -> Tuple:
Expand Down Expand Up @@ -74,6 +75,17 @@ def _render_template(cls, template: Template, input_data: Optional[InputT]) -> s
context = input_data.model_dump(serialize_as_any=True)
return template.render(**context)

@classmethod
def _get_images_from_input_data(cls, input_data: Optional[InputT]) -> list[bytes]:
images = []
if isinstance(input_data, BaseModel):
image_input_fields = cls.image_input_fields or []
for field in image_input_fields:
images_for_field = getattr(input_data, field)
if images_for_field:
images.extend(images_for_field)
return images

@classmethod
def _format_message(cls, message: str) -> str:
return textwrap.dedent(message).strip()
Expand Down Expand Up @@ -119,6 +131,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
)
self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)
self.images = self._get_images_from_input_data(input_data)

# Additional few shot examples that can be added dynamically using methods
# (in opposite to the static `few_shots` attribute which is defined in the class)
Expand Down Expand Up @@ -181,6 +194,14 @@ def list_few_shots(self) -> ChatFormat:
result.append({"role": "assistant", "content": assistant_message})
return result

def list_images(self) -> list[bytes]:
"""
Returns the schema of the list of images compatible with llm apis
Returns:
list of dictionaries
"""
return self.images

def output_schema(self) -> Optional[Dict | Type[BaseModel]]:
"""
Returns the schema of the desired output. Can be used to request structured output from the LLM API
Expand Down
Binary file added packages/ragbits-core/tests/test-images/test.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions packages/ragbits-core/tests/unit/prompts/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pydantic
import pytest

Expand All @@ -14,6 +16,14 @@ class _PromptInput(pydantic.BaseModel):
age: int


class _ImagePromptInput(pydantic.BaseModel):
"""
Input format for the TestImagePrompt
"""

images: list[bytes]


class _PromptOutput(pydantic.BaseModel):
"""
Output format for the TestPrompt.
Expand Down Expand Up @@ -96,6 +106,20 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
TestPrompt()


def test_image_prompt():
"Tests the prompt creation of images"
with open(Path(__file__).parent.parent.parent / "test-images" / "test.png", "rb") as f:
image_bytes = f.read()
image_list = [image_bytes]

class ImagePrompt(Prompt):
user_prompt = "What is on this image?"
image_input_fields = ["images"]

prompt = ImagePrompt(_ImagePromptInput(images=image_list))
assert len(prompt.list_images()) == 1


def test_prompt_with_no_input_type():
"""Test that a prompt can be created with no input type."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from typing import Optional

from PIL import Image
from pydantic import BaseModel
from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import Prompt
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider
Expand All @@ -19,6 +21,16 @@
)

DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini"
DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image."


class _ImagePrompt(Prompt):
user_prompt: str = DEFAULT_IMAGE_QUESTION_PROMPT
image_input_fields: list[str] = ["images"]


class _ImagePromptInput(BaseModel):
images: list[bytes]


class UnstructuredImageProvider(UnstructuredDefaultProvider):
Expand Down Expand Up @@ -79,7 +91,8 @@ async def _to_image_element(
)

img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
image_description = await self.image_summarizer.get_image_description(img_bytes)
prompt = _ImagePrompt(_ImagePromptInput(images=[img_bytes]))
image_description = await self.image_summarizer.get_image_description(prompt=prompt)
return ImageElement(
description=image_description,
ocr_extracted_text=element.text,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import base64
import io
import os
import warnings as wrngs
from typing import Optional

from PIL import Image
from unstructured.documents.elements import Element as UnstructuredElement

from ragbits.core.llms.base import LLM
from ragbits.core.prompt.base import BasePrompt
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.element import TextElement

Expand Down Expand Up @@ -85,35 +86,18 @@ class ImageDescriber:
Describes images content using an LLM
"""

DEFAULT_PROMPT = "Describe the content of the image."

def __init__(self, llm: LLM):
self.llm = llm

async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] = DEFAULT_PROMPT) -> str:
async def get_image_description(self, prompt: BasePrompt) -> str:
"""
Provides summary of the image (passed as bytes)
Provides summary of the image passed with prompt
Args:
image_bytes: bytes of the image
prompt: prompt to be used
prompt: BasePrompt an instance of a prompt
Returns:
summary of the image
"""
img_base64 = base64.b64encode(image_bytes).decode("utf-8")

# TODO make this use prompt structure from ragbits core once there is a support for images
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": f"{prompt}"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
},
],
}
]
return await self.llm.client.call(messages, self.llm.default_options) # type: ignore
if not prompt.list_images():
wrngs.warn(message="Image data not provided", category=UserWarning)
return await self.llm.generate(prompt=prompt)

0 comments on commit d49944d

Please sign in to comment.