From c5ca2796aa44e149cc13c67848a98a2d4a914fe1 Mon Sep 17 00:00:00 2001 From: Konrad Czarnota Date: Tue, 22 Oct 2024 16:04:30 +0200 Subject: [PATCH] Change LiteLLM in unstructured providers to more generic LLM --- .../ingestion/providers/unstructured/images.py | 14 +++++--------- .../ingestion/providers/unstructured/utils.py | 13 ++++--------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py index f137724f..466ace1d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py @@ -6,7 +6,8 @@ from unstructured.documents.elements import Element as UnstructuredElement from unstructured.documents.elements import ElementType -from ragbits.core.llms.litellm import LiteLLMOptions +from ragbits.core.llms.base import LLM +from ragbits.core.llms.litellm import LiteLLM from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element, ImageElement from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider @@ -18,7 +19,6 @@ ) DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini" -DEFAULT_LLM_OPTIONS = LiteLLMOptions() class UnstructuredImageProvider(UnstructuredDefaultProvider): @@ -38,8 +38,7 @@ def __init__( api_key: Optional[str] = None, api_server: Optional[str] = None, use_api: bool = False, - llm_model_name: Optional[str] = None, - llm_options: Optional[LiteLLMOptions] = None, + llm: Optional[LLM] = None, ) -> None: """Initialize the UnstructuredPdfProvider. @@ -51,13 +50,10 @@ def __init__( variable will be used. 
api_server: The API server URL to use for the Unstructured API. If not specified, the UNSTRUCTURED_SERVER_URL environment variable will be used. - llm_model_name: name of LLM model to be used. - llm_options: llm lite options to be used. + llm: The LLM used to describe images. If not specified, a LiteLLM with the default image summarization model is used. """ super().__init__(partition_kwargs, chunking_kwargs, api_key, api_server, use_api) - self.image_summarizer = ImageDescriber( - llm_model_name or DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL, llm_options or DEFAULT_LLM_OPTIONS - ) + self.image_summarizer = ImageDescriber(llm or LiteLLM(DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL)) async def _chunk_and_convert( self, elements: list[UnstructuredElement], document_meta: DocumentMeta, document_path: Path diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py index 6c4c6131..5347168d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/utils.py @@ -6,7 +6,7 @@ from PIL import Image from unstructured.documents.elements import Element as UnstructuredElement -from ragbits.core.llms.litellm import LiteLLM, LiteLLMOptions +from ragbits.core.llms.base import LLM from ragbits.document_search.documents.document import DocumentMeta from ragbits.document_search.documents.element import TextElement @@ -87,10 +87,8 @@ class ImageDescriber: DEFAULT_PROMPT = "Describe the content of the image." 
- def __init__(self, llm_name: str, llm_options: LiteLLMOptions): - self.llm_name = llm_name - self.llm: Optional[LiteLLM] = None - self.llm_options = llm_options + def __init__(self, llm: LLM): + self.llm = llm async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] = DEFAULT_PROMPT) -> str: """ @@ -103,9 +101,6 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] Returns: summary of the image """ - if not self.llm: - self.llm = LiteLLM(self.llm_name) - img_base64 = base64.b64encode(image_bytes).decode("utf-8") # TODO make this use prompt structure from ragbits core once there is a support for images @@ -121,4 +116,4 @@ async def get_image_description(self, image_bytes: bytes, prompt: Optional[str] ], } ] - return await self.llm.client.call(messages, self.llm_options) # type: ignore + return await self.llm.client.call(messages, self.llm.default_options) # type: ignore