Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): async unstructured api #37

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ classifiers = [
dependencies = [
"numpy~=1.24.0",
"ragbits",
"unstructured>=0.15.12",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
]

[project.optional-dependencies]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.api import partition_via_api
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
Expand All @@ -21,7 +21,7 @@
DEFAULT_CHUNKING_KWARGS: dict = {}

UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY"
UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL"
UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL"


class UnstructuredProvider(BaseProvider):
Expand Down Expand Up @@ -50,15 +50,47 @@ class UnstructuredProvider(BaseProvider):
DocumentType.XML,
}

def __init__(self, partition_kwargs: Optional[dict] = None, chunking_kwargs: Optional[dict] = None):
def __init__(
    self,
    partition_kwargs: Optional[dict] = None,
    chunking_kwargs: Optional[dict] = None,
    api_key: Optional[str] = None,
    api_server: Optional[str] = None,
) -> None:
    """Initialize the UnstructuredProvider.

    Only stores the configuration. The API client is not created here; it is built on first access
    of the ``client`` property, which is also where missing credentials are reported.

    Args:
        partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API
            documentation for the available options:
            https://docs.unstructured.io/api-reference/api-services/api-parameters
            When omitted or empty, DEFAULT_PARTITION_KWARGS is used.
        chunking_kwargs: The additional arguments for the chunking. When omitted or empty,
            DEFAULT_CHUNKING_KWARGS is used.
        api_key: The API key to use for the Unstructured API. If not specified, the UNSTRUCTURED_API_KEY
            environment variable will be used.
        api_server: The API server URL to use for the Unstructured API. If not specified, the
            UNSTRUCTURED_SERVER_URL environment variable will be used.
    """
    # `or` treats an empty dict the same as None, so both fall back to the module-level defaults.
    self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS
    self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS
    self.api_key = api_key
    self.api_server = api_server
    # Populated lazily by the `client` property; annotated so the later assignment type-checks.
    self._client: Optional[UnstructuredClient] = None

@property
def client(self) -> UnstructuredClient:
    """Get the UnstructuredClient instance, creating and caching it on first access.

    The API key is resolved before the server URL, so when both are missing the error
    raised is the one about the API key.

    Returns:
        The UnstructuredClient instance.

    Raises:
        ValueError: If no api_key was passed to the provider and the UNSTRUCTURED_API_KEY
            environment variable is not set.
        ValueError: If no api_server was passed to the provider and the UNSTRUCTURED_SERVER_URL
            environment variable is not set.
    """
    if self._client is None:
        # Keyword arguments are evaluated left to right: api_key first, then api_server.
        self._client = UnstructuredClient(
            api_key_auth=_set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV),
            server_url=_set_or_raise(
                name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV
            ),
        )
    return self._client

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""Process the document using the Unstructured API.
Expand All @@ -70,27 +102,24 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
The list of elements extracted from the document.

Raises:
ValueError: If the UNSTRUCTURED_API_KEY or UNSTRUCTURED_API_URL environment variables are not set.
DocumentTypeNotSupportedError: If the document type is not supported.

"""
self.validate_document_type(document_meta.document_type)
if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set")
if (api_url := os.getenv(UNSTRUCTURED_API_URL_ENV)) is None:
raise ValueError(f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set")

document = await document_meta.fetch()

# TODO: Currently this is a blocking call. It should be made async.
elements = partition_via_api(
file=BytesIO(document.local_path.read_bytes()),
metadata_filename=document.local_path.name,
api_key=api_key,
api_url=api_url,
**self.partition_kwargs,
res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
}
}
)
elements = chunk_elements(elements, **self.chunking_kwargs)
elements = chunk_elements(elements_from_dicts(res.elements), **self.chunking_kwargs)
return [_to_text_element(element, document_meta) for element in elements]


Expand All @@ -99,3 +128,11 @@ def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta)
document_meta=document_meta,
content=element.text,
)


def _set_or_raise(name: str, value: Optional[str], env_var: str) -> str:
    """Resolve a required setting from an explicit value or, failing that, an environment variable.

    Args:
        name: Name of the argument the value was passed as; used only in the error message.
        value: Explicitly supplied value. Returned unchanged whenever it is not None.
        env_var: Name of the environment variable consulted when ``value`` is None.

    Returns:
        The explicit value if one was given, otherwise the value of the environment variable.

    Raises:
        ValueError: If ``value`` is None and the environment variable is not set.
    """
    if value is None:
        # Nothing was passed explicitly, so fall back to the environment.
        value = os.getenv(env_var)
        if value is None:
            raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable")
    return value
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from ragbits.document_search.ingestion.providers.unstructured import (
DEFAULT_PARTITION_KWARGS,
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UNSTRUCTURED_SERVER_URL_ENV,
UnstructuredProvider,
)

from ..helpers import env_vars_not_set


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
Expand All @@ -30,7 +30,7 @@ async def test_document_processor_processes_text_document_with_unstructured_prov


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_md_document_with_unstructured_provider():
Expand All @@ -44,7 +44,7 @@ async def test_document_processor_processes_md_document_with_unstructured_provid


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_default_partition_kwargs():
Expand All @@ -58,7 +58,7 @@ async def test_unstructured_provider_document_with_default_partition_kwargs():


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_unstructured_provider_document_with_custom_partition_kwargs():
Expand Down
20 changes: 10 additions & 10 deletions packages/ragbits-document-search/tests/unit/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from unittest.mock import patch

import pytest

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError
from ragbits.document_search.ingestion.providers.unstructured import (
UNSTRUCTURED_API_KEY_ENV,
UNSTRUCTURED_API_URL_ENV,
UnstructuredProvider,
)
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider


@pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES)
Expand All @@ -21,20 +20,21 @@ def test_unsupported_provider_validates_supported_document_types_fails():
assert "Document type unknown is not supported by the UnstructuredProvider" in str(err.value)


@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

assert f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set" in str(err.value)
assert "Either pass api_key argument or set the UNSTRUCTURED_API_KEY environment variable" == str(err.value)


async def test_unstructured_provider_raises_value_error_when_api_url_not_set(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv(UNSTRUCTURED_API_KEY_ENV, "dummy_key")
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_server_url_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
await UnstructuredProvider(api_key="api_key").process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

assert f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set" in str(err.value)
assert "Either pass api_server argument or set the UNSTRUCTURED_SERVER_URL environment variable" == str(err.value)
64 changes: 22 additions & 42 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading