Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad-czarnota-ds committed Oct 21, 2024
1 parent d52de07 commit 0be9224
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.documents.elements import ElementType
from unstructured.partition.auto import partition
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured.utils import set_or_raise, to_text_element
from ragbits.document_search.ingestion.providers.unstructured.utils import check_required_argument, to_text_element

DEFAULT_PARTITION_KWARGS: dict = {
"strategy": "hi_res",
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
variable will be used.
api_server: The API server URL to use for the Unstructured API. If not specified, the
UNSTRUCTURED_SERVER_URL environment variable will be used.
use_api: whether to use unstructured API
use_api: whether to use Unstructured API, otherwise use local version of Unstructured library
ignore_images: if True images will be skipped
"""
self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS
Expand All @@ -95,8 +96,10 @@ def client(self) -> UnstructuredClient:
"""
if self._client is not None:
return self._client
api_key = set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV)
api_server = set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV)
api_key = check_required_argument(arg_name="api_key", value=self.api_key, fallback_env=UNSTRUCTURED_API_KEY_ENV)
api_server = check_required_argument(
arg_name="api_server", value=self.api_server, fallback_env=UNSTRUCTURED_SERVER_URL_ENV
)
self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server)
return self._client

Expand Down Expand Up @@ -145,8 +148,8 @@ async def _chunk_and_convert(
if self.__class__ == UnstructuredDefaultProvider:
chunked_elements = chunk_elements(elements, **self.chunking_kwargs)
return [to_text_element(element, document_meta) for element in chunked_elements]
image_elements = [e for e in elements if e.category == "Image"]
other_elements = [e for e in elements if e.category != "Image"]
image_elements = [e for e in elements if e.category == ElementType.IMAGE]
other_elements = [e for e in elements if e.category != ElementType.IMAGE]
chunked_other_elements = chunk_elements(other_elements, **self.chunking_kwargs)

text_elements: list[Element] = [to_text_element(element, document_meta) for element in chunked_other_elements]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def __init__(
async def _to_image_element(
self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path
) -> ImageElement:
image_coordinates = extract_image_coordinates(element)
top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element)
image = Image.open(document_path).convert("RGB")
img_bytes = crop_and_convert_to_bytes(image, *image_coordinates)
img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
image_description = await self.image_summarizer.get_image_description(img_bytes)
return ImageElement(
description=image_description,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,14 @@ def __init__(
async def _to_image_element(
self, element: UnstructuredElement, document_meta: DocumentMeta, document_path: Path
) -> ImageElement:
image_coordinates = extract_image_coordinates(element)
top_x, top_y, bottom_x, bottom_y = extract_image_coordinates(element)
page_number = element.metadata.page_number
image = convert_from_path(document_path, first_page=page_number, last_page=page_number)[0]

new_system = CoordinateSystem(image.width, image.height)
new_system.orientation = Orientation.SCREEN
x0, y0 = element.metadata.coordinates.system.convert_coordinates_to_new_system(
new_system, image_coordinates[0], image_coordinates[1]
)
x1, y1 = element.metadata.coordinates.system.convert_coordinates_to_new_system(
new_system, image_coordinates[2], image_coordinates[3]
)
x0, y0 = element.metadata.coordinates.system.convert_coordinates_to_new_system(new_system, top_x, top_y)
x1, y1 = element.metadata.coordinates.system.convert_coordinates_to_new_system(new_system, bottom_x, bottom_y)

img_bytes = crop_and_convert_to_bytes(image, x0, y0, x1, y1)
image_description = await self.image_summarizer.get_image_description(img_bytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -
)


def set_or_raise(name: str, value: Optional[str], env_var: str) -> str:
def check_required_argument(value: Optional[str], arg_name: str, fallback_env: str) -> str:
"""
Checks if given environment variable is set and returns it or raises an error
Args:
name: name of the variable
arg_name: name of the variable
value: optional default value
env_var: name of the environment variable to get
fallback_env: name of the environment variable to get
Raises:
ValueError: if environment variable is not set
Expand All @@ -45,8 +45,8 @@ def set_or_raise(name: str, value: Optional[str], env_var: str) -> str:
"""
if value is not None:
return value
if (env_value := os.getenv(env_var)) is None:
raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable")
if (env_value := os.getenv(fallback_env)) is None:
raise ValueError(f"Either pass {arg_name} argument or set the {fallback_env} environment variable")
return env_value


Expand Down Expand Up @@ -87,10 +87,10 @@ class ImageDescriber:

DEFAULT_PROMPT = "Describe the content of the image."

def __init__(self, model_name: str, model_options: LiteLLMOptions):
self.model_name = model_name
self.model: Optional[LiteLLM] = None
self.model_options = model_options
def __init__(self, llm_name: str, llm_options: LiteLLMOptions):
self.llm_name = llm_name
self.llm: Optional[LiteLLM] = None
self.llm_options = llm_options

async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] = DEFAULT_PROMPT) -> str:
"""
Expand All @@ -103,8 +103,8 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str]
Returns:
summary of the image
"""
if not self.model:
self.model = LiteLLM(self.model_name)
if not self.llm:
self.llm = LiteLLM(self.llm_name)

img_base64 = base64.b64encode(image_bytes).decode("utf-8")

Expand All @@ -121,4 +121,4 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str]
],
}
]
return await self.model.client.call(messages, self.model_options)
return await self.llm.client.call(messages, self.llm_options)

0 comments on commit 0be9224

Please sign in to comment.