Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): allow using a local instance of unstructured #74

Merged
Merged 3 commits on Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.auto import partition
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
chunking_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
api_server: Optional[str] = None,
use_api: bool = False,
) -> None:
"""Initialize the UnstructuredProvider.

Expand All @@ -72,6 +75,7 @@ def __init__(
self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS
self.api_key = api_key
self.api_server = api_server
self.use_api = use_api
self._client = None

@property
Expand Down Expand Up @@ -108,18 +112,27 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
self.validate_document_type(document_meta.document_type)
document = await document_meta.fetch()

res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
if self.use_api:
res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
}
}
}
)
elements = chunk_elements(elements_from_dicts(res.elements), **self.chunking_kwargs)
)
elements = elements_from_dicts(res.elements)
else:
elements = partition(
file=BytesIO(document.local_path.read_bytes()),
metadata_filename=document.local_path.name,
**self.partition_kwargs,
)

elements = chunk_elements(elements, **self.chunking_kwargs)
return [_to_text_element(element, document_meta) for element in elements]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@
from ..helpers import env_vars_not_set


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"config",
[
{},
pytest.param(
{DocumentType.TXT: UnstructuredProvider(use_api=True)},
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
document_processor = DocumentProcessorRouter.from_config()
async def test_document_processor_processes_text_document_with_unstructured_provider(config):
document_processor = DocumentProcessorRouter.from_config(config)
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")

elements = await document_processor.get_provider(document_meta).process(document_meta)
Expand All @@ -43,28 +52,46 @@ async def test_document_processor_processes_md_document_with_unstructured_provid
assert elements[0].content == "Ragbits\n\nRepository for internal experiment with our upcoming LLM framework."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"use_api",
[
False,
pytest.param(
True,
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_unstructured_provider_document_with_default_partition_kwargs():
async def test_unstructured_provider_document_with_default_partition_kwargs(use_api):
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
unstructured_provider = UnstructuredProvider()
unstructured_provider = UnstructuredProvider(use_api=use_api)
elements = await unstructured_provider.process(document_meta)

assert unstructured_provider.partition_kwargs == DEFAULT_PARTITION_KWARGS
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"use_api",
[
False,
pytest.param(
True,
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_unstructured_provider_document_with_custom_partition_kwargs():
async def test_unstructured_provider_document_with_custom_partition_kwargs(use_api):
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
partition_kwargs = {"languages": ["pl"], "strategy": "fast"}
unstructured_provider = UnstructuredProvider(partition_kwargs=partition_kwargs)
unstructured_provider = UnstructuredProvider(use_api=use_api, partition_kwargs=partition_kwargs)
elements = await unstructured_provider.process(document_meta)

assert unstructured_provider.partition_kwargs == partition_kwargs
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-document-search/tests/unit/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_unsupported_provider_validates_supported_document_types_fails():
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
await UnstructuredProvider(use_api=True).process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

Expand All @@ -33,7 +33,7 @@ async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_server_url_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider(api_key="api_key").process(
await UnstructuredProvider(api_key="api_key", use_api=True).process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

Expand Down
Loading