Skip to content

Commit

Permalink
feat(document-search): add support for images (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad-czarnota-ds authored Oct 22, 2024
1 parent 2465486 commit defd0b2
Show file tree
Hide file tree
Showing 16 changed files with 401 additions and 67 deletions.
5 changes: 1 addition & 4 deletions examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.element import TextElement


class QueryWithContext(BaseModel):
Expand Down Expand Up @@ -124,9 +123,7 @@ async def _handle_message(
if not self._documents_ingested:
yield self.NO_DOCUMENTS_INGESTED_MESSAGE
results = await self.document_search.search(message[-1])
prompt = RAGPrompt(
QueryWithContext(query=message, context=[i.content for i in results if isinstance(i, TextElement)])
)
prompt = RAGPrompt(QueryWithContext(query=message, context=[i.get_key() for i in results]))
response = await self._llm.generate(prompt)
yield response.answer

Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABCMeta, abstractmethod
from typing import Dict, Generic, List, Optional, Type
from typing import Dict, Generic, Optional, Type

from pydantic import BaseModel
from typing_extensions import TypeVar

ChatFormat = List[Dict[str, str]]
ChatFormat = list[dict[str, str]]
OutputT = TypeVar("OutputT", default=str)


Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ dependencies = [
"numpy~=1.24.0",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
"ragbits-core==0.1.0",
"pdf2image>=1.17.0",
"ragbits-core==0.1.0"
]

[project.optional-dependencies]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class DocumentType(str, Enum):
RTF = "rtf"
TSV = "tsv"
XML = "xml"
JPG = "jpg"
PNG = "png"

UNKNOWN = "unknown"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,23 @@ def get_key(self) -> str:
The key.
"""
return self.content


class ImageElement(Element):
    """
    A document element that carries an image together with its textual renderings.
    """

    element_type: str = "image"
    description: str
    ocr_extracted_text: str
    image_bytes: bytes

    def get_key(self) -> str:
        """
        Build the text key which will be used to generate the vector for this element.

        Returns:
            The description and the OCR-extracted text joined by a single space.
        """
        return f"{self.description} {self.ocr_extracted_text}"
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,34 @@
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers import get_provider
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider
from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider
from ragbits.document_search.ingestion.providers.unstructured.images import UnstructuredImageProvider
from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider

ProvidersConfig = dict[DocumentType, BaseProvider]


DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
DocumentType.TXT: UnstructuredProvider(),
DocumentType.MD: UnstructuredProvider(),
DocumentType.PDF: UnstructuredProvider(),
DocumentType.DOCX: UnstructuredProvider(),
DocumentType.DOC: UnstructuredProvider(),
DocumentType.PPTX: UnstructuredProvider(),
DocumentType.PPT: UnstructuredProvider(),
DocumentType.XLSX: UnstructuredProvider(),
DocumentType.XLS: UnstructuredProvider(),
DocumentType.CSV: UnstructuredProvider(),
DocumentType.HTML: UnstructuredProvider(),
DocumentType.EPUB: UnstructuredProvider(),
DocumentType.ORG: UnstructuredProvider(),
DocumentType.ODT: UnstructuredProvider(),
DocumentType.RST: UnstructuredProvider(),
DocumentType.RTF: UnstructuredProvider(),
DocumentType.TSV: UnstructuredProvider(),
DocumentType.XML: UnstructuredProvider(),
DocumentType.TXT: UnstructuredDefaultProvider(),
DocumentType.MD: UnstructuredDefaultProvider(),
DocumentType.PDF: UnstructuredPdfProvider(),
DocumentType.DOCX: UnstructuredDefaultProvider(),
DocumentType.DOC: UnstructuredDefaultProvider(),
DocumentType.PPTX: UnstructuredDefaultProvider(),
DocumentType.PPT: UnstructuredDefaultProvider(),
DocumentType.XLSX: UnstructuredDefaultProvider(),
DocumentType.XLS: UnstructuredDefaultProvider(),
DocumentType.CSV: UnstructuredDefaultProvider(),
DocumentType.HTML: UnstructuredDefaultProvider(),
DocumentType.EPUB: UnstructuredDefaultProvider(),
DocumentType.ORG: UnstructuredDefaultProvider(),
DocumentType.ODT: UnstructuredDefaultProvider(),
DocumentType.RST: UnstructuredDefaultProvider(),
DocumentType.RTF: UnstructuredDefaultProvider(),
DocumentType.TSV: UnstructuredDefaultProvider(),
DocumentType.XML: UnstructuredDefaultProvider(),
DocumentType.JPG: UnstructuredImageProvider(),
DocumentType.PNG: UnstructuredImageProvider(),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@

from .base import BaseProvider
from .dummy import DummyProvider
from .unstructured import UnstructuredProvider

__all__ = ["BaseProvider", "DummyProvider", "UnstructuredProvider"]
from .unstructured.default import UnstructuredDefaultProvider
from .unstructured.images import UnstructuredImageProvider
from .unstructured.pdf import UnstructuredPdfProvider

__all__ = [
"BaseProvider",
"DummyProvider",
"UnstructuredDefaultProvider",
"UnstructuredPdfProvider",
"UnstructuredImageProvider",
]

module = sys.modules[__name__]


def get_provider(provider_config: dict) -> BaseProvider:
"""
Initializes and returns an Provider object based on the provided configuration.
Initializes and returns a Provider object based on the provided configuration.
Args:
provider_config : A dictionary containing configuration details for the provider.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from io import BytesIO
from pathlib import Path
from typing import Optional

from unstructured.chunking.basic import chunk_elements
Expand All @@ -9,8 +9,9 @@
from unstructured_client import UnstructuredClient

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured.utils import check_required_argument, to_text_element

DEFAULT_PARTITION_KWARGS: dict = {
"strategy": "hi_res",
Expand All @@ -26,15 +27,14 @@
UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL"


class UnstructuredProvider(BaseProvider):
class UnstructuredDefaultProvider(BaseProvider):
"""
A provider that uses the Unstructured API to process the documents.
A provider that uses the Unstructured API or local SDK to process the documents.
"""

SUPPORTED_DOCUMENT_TYPES = {
DocumentType.TXT,
DocumentType.MD,
DocumentType.PDF,
DocumentType.DOCX,
DocumentType.DOC,
DocumentType.PPTX,
Expand All @@ -59,8 +59,9 @@ def __init__(
api_key: Optional[str] = None,
api_server: Optional[str] = None,
use_api: bool = False,
ignore_images: bool = False,
) -> None:
"""Initialize the UnstructuredProvider.
"""Initialize the UnstructuredDefaultProvider.
Args:
partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API documentation
Expand All @@ -70,13 +71,16 @@ def __init__(
variable will be used.
api_server: The API server URL to use for the Unstructured API. If not specified, the
UNSTRUCTURED_SERVER_URL environment variable will be used.
use_api: whether to use Unstructured API, otherwise use local version of Unstructured library
ignore_images: if True images will be skipped
"""
self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS
self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS
self.api_key = api_key
self.api_server = api_server
self.use_api = use_api
self._client = None
self.ignore_images = ignore_images

@property
def client(self) -> UnstructuredClient:
Expand All @@ -91,8 +95,10 @@ def client(self) -> UnstructuredClient:
"""
if self._client is not None:
return self._client
api_key = _set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV)
api_server = _set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV)
api_key = check_required_argument(arg_name="api_key", value=self.api_key, fallback_env=UNSTRUCTURED_API_KEY_ENV)
api_server = check_required_argument(
arg_name="api_server", value=self.api_server, fallback_env=UNSTRUCTURED_SERVER_URL_ENV
)
self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server)
return self._client

Expand Down Expand Up @@ -120,6 +126,7 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
"coordinates": True,
**self.partition_kwargs,
}
}
Expand All @@ -132,20 +139,14 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
**self.partition_kwargs,
)

elements = chunk_elements(elements, **self.chunking_kwargs)
return [_to_text_element(element, document_meta) for element in elements]
return await self._chunk_and_convert(elements, document_meta, document.local_path)


def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -> TextElement:
    """
    Wrap an Unstructured element's text in a TextElement.

    Args:
        element: The Unstructured element whose `text` becomes the content.
        document_meta: Metadata of the document the element belongs to.

    Returns:
        A TextElement holding the element text and the given document metadata.
    """
    text_content = element.text
    return TextElement(content=text_content, document_meta=document_meta)


def _set_or_raise(name: str, value: Optional[str], env_var: str) -> str:
if value is not None:
return value
if (env_value := os.getenv(env_var)) is None:
raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable")
return env_value
async def _chunk_and_convert(  # pylint: disable=unused-argument
    self,
    elements: list[UnstructuredElement],
    document_meta: DocumentMeta,
    document_path: Path,
) -> list[Element]:
    """
    Chunk the partitioned elements and convert every chunk into a text element.

    Args:
        elements: Elements produced by partitioning the document.
        document_meta: Metadata of the processed document, attached to every result.
        document_path: Local path of the processed document; not used by this implementation.

    Returns:
        One text element per chunk, in the order the chunks were produced.
    """
    converted: list[Element] = []
    for chunk in chunk_elements(elements, **self.chunking_kwargs):
        converted.append(to_text_element(chunk, document_meta))
    return converted
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Optional

from PIL import Image
from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider
from ragbits.document_search.ingestion.providers.unstructured.utils import (
ImageDescriber,
crop_and_convert_to_bytes,
extract_image_coordinates,
to_text_element,
)

DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini"


class UnstructuredImageProvider(UnstructuredDefaultProvider):
    """
    A specialized provider that handles PNG and JPG images using the Unstructured API or local SDK.

    Elements categorised as images are cropped out of the source image, described by an LLM and
    returned as image elements; all remaining elements are chunked and returned as text elements.
    """

    SUPPORTED_DOCUMENT_TYPES = {
        DocumentType.JPG,
        DocumentType.PNG,
    }

    def __init__(
        self,
        partition_kwargs: Optional[dict] = None,
        chunking_kwargs: Optional[dict] = None,
        api_key: Optional[str] = None,
        api_server: Optional[str] = None,
        use_api: bool = False,
        llm: Optional[LLM] = None,
        ignore_images: bool = False,
    ) -> None:
        """Initialize the UnstructuredImageProvider.

        Args:
            partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API documentation
                for the available options: https://docs.unstructured.io/api-reference/api-services/api-parameters
            chunking_kwargs: The additional arguments for the chunking.
            api_key: The API key to use for the Unstructured API. If not specified, the UNSTRUCTURED_API_KEY environment
                variable will be used.
            api_server: The API server URL to use for the Unstructured API. If not specified, the
                UNSTRUCTURED_SERVER_URL environment variable will be used.
            use_api: Whether to use the Unstructured API, otherwise the local version of the Unstructured library
                is used.
            llm: The LLM used to describe images. If not specified, a LiteLLM instance for the
                DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL model is created.
            ignore_images: If True, image elements are skipped and only text elements are returned.
        """
        # ignore_images is forwarded by keyword so the positional order of the parent signature is untouched.
        super().__init__(
            partition_kwargs,
            chunking_kwargs,
            api_key,
            api_server,
            use_api,
            ignore_images=ignore_images,
        )
        self.image_summarizer = ImageDescriber(llm or LiteLLM(DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL))

    async def _chunk_and_convert(
        self, elements: list[UnstructuredElement], document_meta: DocumentMeta, document_path: Path
    ) -> list[Element]:
        """
        Convert partitioned elements into text elements and, unless disabled, image elements.

        Args:
            elements: Elements produced by partitioning the document.
            document_meta: Metadata of the processed document, attached to every result.
            document_path: Local path of the processed document, used to load the image for cropping.

        Returns:
            The text elements followed by the image elements. Only the text elements are returned
            when `ignore_images` is set.
        """
        image_elements = [e for e in elements if e.category == ElementType.IMAGE]
        other_elements = [e for e in elements if e.category != ElementType.IMAGE]
        # Only non-image elements are chunked; every image element is converted on its own.
        chunked_other_elements = chunk_elements(other_elements, **self.chunking_kwargs)

        text_elements: list[Element] = [to_text_element(element, document_meta) for element in chunked_other_elements]
        if self.ignore_images:
            return text_elements
        # Images are awaited one at a time, so they are described sequentially in their original order.
        return text_elements + [
            await self._to_image_element(element, document_meta, document_path) for element in image_elements
        ]

    async def _to_image_element(
        self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path
    ) -> ImageElement:
        """
        Build an image element from a single Unstructured image element.

        Args:
            element: The Unstructured element categorised as an image.
            document_meta: Metadata of the processed document.
            document_path: Local path of the document the image is cropped from.

        Returns:
            An image element holding the cropped image bytes, the description produced by the
            image summarizer and the element's text as the OCR-extracted text.
        """
        top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element)
        image = self._load_document_as_image(document_path)
        # The coordinates pass through `_convert_coordinates` before cropping; here it is the identity.
        top_x, top_y, bottom_x, bottom_y = self._convert_coordinates(
            top_x, top_y, bottom_x, bottom_y, image.width, image.height, element
        )

        img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
        image_description = await self.image_summarizer.get_image_description(img_bytes)
        return ImageElement(
            description=image_description,
            ocr_extracted_text=element.text,
            image_bytes=img_bytes,
            document_meta=document_meta,
        )

    @staticmethod
    def _load_document_as_image(
        document_path: Path, page: Optional[int] = None  # pylint: disable=unused-argument
    ) -> Image.Image:
        """
        Open the document as an RGB image.

        Args:
            document_path: Local path of the image file.
            page: Not used by this implementation.

        Returns:
            The loaded image converted to RGB mode.
        """
        # `Image` is the PIL module here, so the image class is spelled `Image.Image`.
        return Image.open(document_path).convert("RGB")

    @staticmethod
    def _convert_coordinates(
        # pylint: disable=unused-argument
        top_x: float,
        top_y: float,
        bottom_x: float,
        bottom_y: float,
        image_width: int,
        image_height: int,
        element: UnstructuredElement,
    ) -> tuple[float, float, float, float]:
        """
        Return the crop coordinates unchanged.

        The image size and the element are accepted but ignored here.
        NOTE(review): presumably an override point for subclasses that need rescaling; confirm.

        Args:
            top_x: X coordinate of the first corner of the crop box.
            top_y: Y coordinate of the first corner of the crop box.
            bottom_x: X coordinate of the opposite corner of the crop box.
            bottom_y: Y coordinate of the opposite corner of the crop box.
            image_width: Width of the loaded image; not used by this implementation.
            image_height: Height of the loaded image; not used by this implementation.
            element: The Unstructured element the coordinates belong to; not used by this implementation.

        Returns:
            The tuple `(top_x, top_y, bottom_x, bottom_y)` exactly as received.
        """
        return top_x, top_y, bottom_x, bottom_y
Loading

0 comments on commit defd0b2

Please sign in to comment.