Skip to content

Commit

Permalink
feat(document-search): async unstructured api (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
akonarski-ds authored Oct 8, 2024
1 parent ec46bec commit 5398d20
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 76 deletions.
3 changes: 2 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ classifiers = [
dependencies = [
"numpy~=1.24.0",
"ragbits",
"unstructured>=0.15.12",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
]

[project.optional-dependencies]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.api import partition_via_api
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
Expand All @@ -21,7 +21,7 @@
DEFAULT_CHUNKING_KWARGS: dict = {}

UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY"
UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL"
UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL"


class UnstructuredProvider(BaseProvider):
Expand Down Expand Up @@ -50,15 +50,47 @@ class UnstructuredProvider(BaseProvider):
DocumentType.XML,
}

def __init__(self, partition_kwargs: Optional[dict] = None, chunking_kwargs: Optional[dict] = None):
def __init__(
    self,
    partition_kwargs: Optional[dict] = None,
    chunking_kwargs: Optional[dict] = None,
    api_key: Optional[str] = None,
    api_server: Optional[str] = None,
) -> None:
    """Initialize the UnstructuredProvider.

    Args:
        partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API
            documentation for the available options:
            https://docs.unstructured.io/api-reference/api-services/api-parameters
            When omitted (or empty), DEFAULT_PARTITION_KWARGS is used.
        chunking_kwargs: The additional arguments for the chunking. When omitted (or empty),
            DEFAULT_CHUNKING_KWARGS is used.
        api_key: The API key to use for the Unstructured API. If not specified, the UNSTRUCTURED_API_KEY
            environment variable will be used.
        api_server: The API server URL to use for the Unstructured API. If not specified, the
            UNSTRUCTURED_SERVER_URL environment variable will be used.
    """
    # NOTE: `or` (rather than an `is None` check) means an explicitly passed empty dict
    # also selects the module-level defaults.
    self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS
    self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS
    self.api_key = api_key
    self.api_server = api_server
    # The API client is not built here; the `client` property creates it on first access,
    # so missing credentials only surface when the client is actually needed.
    self._client: Optional[UnstructuredClient] = None

@property
def client(self) -> UnstructuredClient:
    """Get the UnstructuredClient instance, creating it on first access.

    The API key and server URL passed to the constructor take precedence; when one of them
    was not passed, the corresponding environment variable is read instead.

    Returns:
        The UnstructuredClient instance.

    Raises:
        ValueError: If no API key was passed and the UNSTRUCTURED_API_KEY environment variable is not set.
        ValueError: If no API server was passed and the UNSTRUCTURED_SERVER_URL environment variable
            is not set.
    """
    if self._client is None:
        # Resolve the credentials only now, so constructing the provider never fails
        # because of missing configuration.
        api_key = _set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV)
        api_server = _set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV)
        self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server)
    return self._client

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""Process the document using the Unstructured API.
Expand All @@ -70,27 +102,24 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
The list of elements extracted from the document.
Raises:
ValueError: If the UNSTRUCTURED_API_KEY or UNSTRUCTURED_API_URL environment variables are not set.
DocumentTypeNotSupportedError: If the document type is not supported.
"""
self.validate_document_type(document_meta.document_type)
if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set")
if (api_url := os.getenv(UNSTRUCTURED_API_URL_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set")

document = await document_meta.fetch()

# TODO: Currently this is a blocking call. It should be made async.
elements = partition_via_api(
file=BytesIO(document.local_path.read_bytes()),
metadata_filename=document.local_path.name,
api_key=api_key,
api_url=api_url,
**self.partition_kwargs,
res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
}
}
)
elements = chunk_elements(elements, **self.chunking_kwargs)
elements = chunk_elements(elements_from_dicts(res.elements), **self.chunking_kwargs)
return [_to_text_element(element, document_meta) for element in elements]


Expand All @@ -99,3 +128,11 @@ def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta)
document_meta=document_meta,
content=element.text,
)


def _set_or_raise(name: str, value: Optional[str], env_var: str) -> str:
    """Return the explicitly passed value, falling back to an environment variable.

    Args:
        name: The name of the argument, used only in the error message.
        value: The explicitly passed value. Returned unchanged when it is not None.
        env_var: The name of the environment variable to read when value is None.

    Returns:
        The passed value, or the value of the environment variable.

    Raises:
        ValueError: If value is None and the environment variable is not set.
    """
    if value is not None:
        return value
    if (env_value := os.getenv(env_var)) is None:
        raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable")
    return env_value
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from ragbits.document_search.ingestion.providers.unstructured import (
DEFAULT_PARTITION_KWARGS,
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UNSTRUCTURED_SERVER_URL_ENV,
UnstructuredProvider,
)

from ..helpers import env_vars_not_set


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
Expand All @@ -30,7 +30,7 @@ async def test_document_processor_processes_text_document_with_unstructured_prov


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_md_document_with_unstructured_provider():
Expand All @@ -44,7 +44,7 @@ async def test_document_processor_processes_md_document_with_unstructured_provid


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_default_partition_kwargs():
Expand All @@ -58,7 +58,7 @@ async def test_unstructured_provider_document_with_default_partition_kwargs():


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_custom_partition_kwargs():
Expand Down
20 changes: 10 additions & 10 deletions packages/ragbits-document-search/tests/unit/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from unittest.mock import patch

import pytest

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError
from ragbits.document_search.ingestion.providers.unstructured import (
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UnstructuredProvider,
)
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider


@pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES)
Expand All @@ -21,20 +20,21 @@ def test_unsupported_provider_validates_supported_document_types_fails():
assert "Document type unknown is not supported by the UnstructuredProvider" in str(err.value)


@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

assert f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set" in str(err.value)
assert "Either pass api_key argument or set the UNSTRUCTURED_API_KEY environment variable" == str(err.value)


async def test_unstructured_provider_raises_value_error_when_api_url_not_set(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv(UNSTRUCTURED_API_KEY_ENV, "dummy_key")
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_server_url_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
await UnstructuredProvider(api_key="api_key").process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

assert f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set" in str(err.value)
assert "Either pass api_server argument or set the UNSTRUCTURED_SERVER_URL environment variable" == str(err.value)
64 changes: 22 additions & 42 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5398d20

Please sign in to comment.