Skip to content

Commit

Permalink
Swap sync with async Unstructured api call
Browse files Browse the repository at this point in the history
  • Loading branch information
akonarski-ds committed Sep 24, 2024
1 parent 1d743fb commit 6b0c604
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 147 deletions.
3 changes: 2 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ classifiers = [
dependencies = [
"numpy~=1.24.0",
"ragbits",
"unstructured>=0.15.12",
"unstructured>=0.15.13",
"unstructured-client==0.26.0b2",
]

[tool.uv]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.api import partition_via_api
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
Expand All @@ -18,7 +18,7 @@
}

UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY"
UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL"
UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL"


class UnstructuredProvider(BaseProvider):
Expand Down Expand Up @@ -55,6 +55,29 @@ def __init__(self, partition_kwargs: Optional[dict] = None):
for the available options: https://docs.unstructured.io/api-reference/api-services/api-parameters
"""
self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS
self._client = None

@property
def client(self) -> UnstructuredClient:
"""Get the UnstructuredClient instance. If the client is not initialized, it will be created.
Returns:
The UnstructuredClient instance.
Raises:
ValueError: If the UNSTRUCTURED_API_KEY_ENV environment variable is not set.
ValueError: If the UNSTRUCTURED_SERVER_URL_ENV environment variable is not set.
"""
if self._client is not None:
return self._client
if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None:
print(api_key)
print("I should raise here")
raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set")
if (server_url := os.getenv(UNSTRUCTURED_SERVER_URL_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_SERVER_URL_ENV} environment variable is not set")
self._client = UnstructuredClient(api_key_auth=api_key, server_url=server_url)
return self._client

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""Process the document using the Unstructured API.
Expand All @@ -66,27 +89,24 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
The list of elements extracted from the document.
Raises:
ValueError: If the UNSTRUCTURED_API_KEY or UNSTRUCTURED_API_URL environment variables are not set.
DocumentTypeNotSupportedError: If the document type is not supported.
"""
self.validate_document_type(document_meta.document_type)
if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set")
if (api_url := os.getenv(UNSTRUCTURED_API_URL_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set")

document = await document_meta.fetch()

# TODO: Currently this is a blocking call. It should be made async.
elements = partition_via_api(
file=BytesIO(document.local_path.read_bytes()),
metadata_filename=document.local_path.name,
api_key=api_key,
api_url=api_url,
**self.partition_kwargs,
res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
}
}
)
return [_to_text_element(element, document_meta) for element in elements]
return [_to_text_element(element, document_meta) for element in elements_from_dicts(res.elements)]


def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -> TextElement:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from ragbits.document_search.ingestion.providers.unstructured import (
DEFAULT_PARTITION_KWARGS,
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UNSTRUCTURED_SERVER_URL_ENV,
UnstructuredProvider,
)

from ..helpers import env_vars_not_set


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
Expand All @@ -26,11 +26,11 @@ async def test_document_processor_processes_text_document_with_unstructured_prov

assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredProvider)
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George"
assert elements[0].content == "Name of Peppa's brother is George."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_md_document_with_unstructured_provider():
Expand All @@ -44,7 +44,7 @@ async def test_document_processor_processes_md_document_with_unstructured_provid


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_default_partition_kwargs():
Expand All @@ -58,7 +58,7 @@ async def test_unstructured_provider_document_with_default_partition_kwargs():


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_custom_partition_kwargs():
Expand Down
12 changes: 7 additions & 5 deletions packages/ragbits-document-search/tests/unit/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
from unittest.mock import patch

import pytest
from dotenv import load_dotenv

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError
from ragbits.document_search.ingestion.providers.unstructured import (
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UNSTRUCTURED_SERVER_URL_ENV,
UnstructuredProvider,
)

load_dotenv()


@pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES)
def test_unsupported_provider_validates_supported_document_types_passes(document_type: DocumentType):
Expand All @@ -24,6 +24,7 @@ def test_unsupported_provider_validates_supported_document_types_fails():
assert "Document type unknown is not supported by the UnstructuredProvider" in str(err.value)


@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
Expand All @@ -33,11 +34,12 @@ async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
assert f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set" in str(err.value)


@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_url_not_set(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv(UNSTRUCTURED_API_KEY_ENV, "dummy_key")
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

assert f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set" in str(err.value)
assert f"{UNSTRUCTURED_SERVER_URL_ENV} environment variable is not set" in str(err.value)
Loading

0 comments on commit 6b0c604

Please sign in to comment.