diff --git a/packages/ragbits-document-search/pyproject.toml b/packages/ragbits-document-search/pyproject.toml index 6820f0bf..83ae5304 100644 --- a/packages/ragbits-document-search/pyproject.toml +++ b/packages/ragbits-document-search/pyproject.toml @@ -34,7 +34,8 @@ classifiers = [ dependencies = [ "numpy~=1.24.0", "ragbits", - "unstructured>=0.15.12", + "unstructured>=0.15.13", + "unstructured-client>=0.26.0", ] [project.optional-dependencies] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py index 2e81b8ab..8a6be941 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py @@ -1,10 +1,10 @@ import os -from io import BytesIO from typing import Optional from unstructured.chunking.basic import chunk_elements from unstructured.documents.elements import Element as UnstructuredElement -from unstructured.partition.api import partition_via_api +from unstructured.staging.base import elements_from_dicts +from unstructured_client import UnstructuredClient from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element, TextElement @@ -21,7 +21,7 @@ DEFAULT_CHUNKING_KWARGS: dict = {} UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY" -UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL" +UNSTRUCTURED_SERVER_URL_ENV = "UNSTRUCTURED_SERVER_URL" class UnstructuredProvider(BaseProvider): @@ -50,15 +50,47 @@ class UnstructuredProvider(BaseProvider): DocumentType.XML, } - def __init__(self, partition_kwargs: Optional[dict] = None, chunking_kwargs: Optional[dict] = None): + def __init__( + self, + partition_kwargs: Optional[dict] = None, + chunking_kwargs: Optional[dict] = None, + api_key: Optional[str] = None, + api_server: Optional[str] = None, + ) -> None: """Initialize the UnstructuredProvider. Args: partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API documentation for the available options: https://docs.unstructured.io/api-reference/api-services/api-parameters + chunking_kwargs: The additional arguments for the chunking. + api_key: The API key to use for the Unstructured API. If not specified, the UNSTRUCTURED_API_KEY environment + variable will be used. + api_server: The API server URL to use for the Unstructured API. If not specified, the + UNSTRUCTURED_SERVER_URL environment variable will be used. """ self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS + self.api_key = api_key + self.api_server = api_server + self._client = None + + @property + def client(self) -> UnstructuredClient: + """Get the UnstructuredClient instance. If the client is not initialized, it will be created. + + Returns: + The UnstructuredClient instance. + + Raises: + ValueError: If the UNSTRUCTURED_API_KEY_ENV environment variable is not set. + ValueError: If the UNSTRUCTURED_SERVER_URL_ENV environment variable is not set. + """ + if self._client is not None: + return self._client + api_key = _set_or_raise(name="api_key", value=self.api_key, env_var=UNSTRUCTURED_API_KEY_ENV) + api_server = _set_or_raise(name="api_server", value=self.api_server, env_var=UNSTRUCTURED_SERVER_URL_ENV) + self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server) + return self._client async def process(self, document_meta: DocumentMeta) -> list[Element]: """Process the document using the Unstructured API. @@ -70,27 +102,24 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]: The list of elements extracted from the document. Raises: - ValueError: If the UNSTRUCTURED_API_KEY or UNSTRUCTURED_API_URL environment variables are not set. DocumentTypeNotSupportedError: If the document type is not supported. """ self.validate_document_type(document_meta.document_type) - if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None: - raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set") - if (api_url := os.getenv(UNSTRUCTURED_API_URL_ENV)) is None: - raise ValueError(f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set") - document = await document_meta.fetch() - # TODO: Currently this is a blocking call. It should be made async. - elements = partition_via_api( - file=BytesIO(document.local_path.read_bytes()), - metadata_filename=document.local_path.name, - api_key=api_key, - api_url=api_url, - **self.partition_kwargs, + res = await self.client.general.partition_async( + request={ + "partition_parameters": { + "files": { + "content": document.local_path.read_bytes(), + "file_name": document.local_path.name, + }, + **self.partition_kwargs, + } + } ) - elements = chunk_elements(elements, **self.chunking_kwargs) + elements = chunk_elements(elements_from_dicts(res.elements), **self.chunking_kwargs) return [_to_text_element(element, document_meta) for element in elements] @@ -99,3 +128,11 @@ def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) document_meta=document_meta, content=element.text, ) + + +def _set_or_raise(name: str, value: Optional[str], env_var: str) -> str: + if value is not None: + return value + if (env_value := os.getenv(env_var)) is None: + raise ValueError(f"Either pass {name} argument or set the {env_var} environment variable") + return env_value diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index b7827b85..65b2088e 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -7,7 +7,7 @@ from ragbits.document_search.ingestion.providers.unstructured import ( DEFAULT_PARTITION_KWARGS, UNSTRUCTURED_API_KEY_ENV, - UNSTRUCTURED_API_URL_ENV, + UNSTRUCTURED_SERVER_URL_ENV, UnstructuredProvider, ) @@ -15,7 +15,7 @@ @pytest.mark.skipif( - env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), + env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), reason="Unstructured API environment variables not set", ) async def test_document_processor_processes_text_document_with_unstructured_provider(): @@ -30,7 +30,7 @@ async def test_document_processor_processes_text_document_with_unstructured_prov @pytest.mark.skipif( - env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), + env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), reason="Unstructured API environment variables not set", ) async def test_document_processor_processes_md_document_with_unstructured_provider(): @@ -44,7 +44,7 @@ async def test_document_processor_processes_md_document_with_unstructured_provid @pytest.mark.skipif( - env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), + env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), reason="Unstructured API environment variables not set", ) async def test_unstructured_provider_document_with_default_partition_kwargs(): @@ -58,7 +58,7 @@ async def test_unstructured_provider_document_with_default_partition_kwargs(): @pytest.mark.skipif( - env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), + env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]), reason="Unstructured API environment variables not set", ) async def test_unstructured_provider_document_with_custom_partition_kwargs(): diff --git a/packages/ragbits-document-search/tests/unit/test_providers.py b/packages/ragbits-document-search/tests/unit/test_providers.py index 8b6eb9d9..7793d16b 100644 --- a/packages/ragbits-document-search/tests/unit/test_providers.py +++ b/packages/ragbits-document-search/tests/unit/test_providers.py @@ -1,12 +1,11 @@ +import os +from unittest.mock import patch + import pytest from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError -from ragbits.document_search.ingestion.providers.unstructured import ( - UNSTRUCTURED_API_KEY_ENV, - UNSTRUCTURED_API_URL_ENV, - UnstructuredProvider, -) +from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider @pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES) @@ -21,20 +20,21 @@ def test_unsupported_provider_validates_supported_document_types_fails(): assert "Document type unknown is not supported by the UnstructuredProvider" in str(err.value) +@patch.dict(os.environ, {}, clear=True) async def test_unstructured_provider_raises_value_error_when_api_key_not_set(): with pytest.raises(ValueError) as err: await UnstructuredProvider().process( DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") ) - assert f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set" in str(err.value) + assert "Either pass api_key argument or set the UNSTRUCTURED_API_KEY environment variable" == str(err.value) -async def test_unstructured_provider_raises_value_error_when_api_url_not_set(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv(UNSTRUCTURED_API_KEY_ENV, "dummy_key") +@patch.dict(os.environ, {}, clear=True) +async def test_unstructured_provider_raises_value_error_when_server_url_not_set(): with pytest.raises(ValueError) as err: - await UnstructuredProvider().process( + await UnstructuredProvider(api_key="api_key").process( DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") ) - assert f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set" in str(err.value) + assert "Either pass api_server argument or set the UNSTRUCTURED_SERVER_URL environment variable" == str(err.value) diff --git a/uv.lock b/uv.lock index 2bd8b3a8..8279842a 100644 --- a/uv.lock +++ b/uv.lock @@ -608,7 +608,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version <= '3.11'" }, + { name = "tomli", marker = "python_full_version == '3.11'" }, ] [[package]] @@ -666,18 +666,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686 }, ] -[[package]] -name = "deepdiff" -version = "8.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "orderly-set" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/62/ba/aced1d6a7d988ca1b6f9b274faed7dafc7356a733e45a457819bddcf2dbc/deepdiff-8.0.1.tar.gz", hash = "sha256:245599a4586ab59bb599ca3517a9c42f3318ff600ded5e80a3432693c8ec3c4b", size = 427721 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/46/01673060e83277a863baf0909b387cd809865cba2d5e7213db76516bedd9/deepdiff-8.0.1-py3-none-any.whl", hash = "sha256:42e99004ce603f9a53934c634a57b04ad5900e0d8ed0abb15e635767489cbc05", size = 82741 }, -] - [[package]] name = "deprecated" version = "1.2.14" @@ -726,6 +714,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/56/4ddf8b36aa4b52077045b17ffb8958f3360b250df4143d1482d9d5bb54d5/emoji-2.14.0-py3-none-any.whl", hash = "sha256:fcc936bf374b1aec67dda5303ae99710ba88cc9cdce2d1a71c5f2204e6d78799", size = 586897 }, ] +[[package]] +name = "eval-type-backport" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/23/ca/1601a9fa588867fe2ab6c19ed4c936929160d08a86597adf61bbd443fe57/eval_type_backport-0.2.0.tar.gz", hash = "sha256:68796cfbc7371ebf923f03bdf7bef415f3ec098aeced24e054b253a0e78f7b37", size = 8977 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/ac/aa3d8e0acbcd71140420bc752d7c9779cf3a2a3bb1d7ef30944e38b2cd39/eval_type_backport-0.2.0-py3-none-any.whl", hash = "sha256:ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933", size = 5855 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -2167,15 +2164,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ad/2e/36097c0a4d0115b8c7e377c90bab7783ac183bc5cb4071308f8959454311/opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb", size = 6946 }, ] -[[package]] -name = "orderly-set" -version = "5.2.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c8/71/5408fee86ce5408132a3ece6eff61afa2c25d5b37cd76bc100a9a4a4d8dd/orderly_set-5.2.2.tar.gz", hash = "sha256:52a18b86aaf3f5d5a498bbdb27bf3253a4e5c57ab38e5b7a56fa00115cd28448", size = 19103 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/71/6f9554919da608cb5bcf709822a9644ba4785cc7856e01ea375f6d808774/orderly_set-5.2.2-py3-none-any.whl", hash = "sha256:f7a37c95a38c01cdfe41c3ffb62925a318a2286ea0a41790c057fc802aec54da", size = 11621 }, -] - [[package]] name = "orjson" version = "3.10.7" @@ -2740,14 +2728,14 @@ wheels = [ [[package]] name = "python-dateutil" -version = "2.9.0.post0" +version = "2.8.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } +sdist = { url = "https://files.pythonhosted.org/packages/4c/c4/13b4776ea2d76c115c1d1b84579f3764ee6d57204f6be27119f13a61d0a9/python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", size = 357324 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, + { url = "https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9", size = 247702 }, ] [[package]] @@ -2927,6 +2915,7 @@ dependencies = [ { name = "numpy" }, { name = "ragbits" }, { name = "unstructured" }, + { name = "unstructured-client" }, ] [package.optional-dependencies] @@ -2953,7 +2942,8 @@ requires-dist = [ { name = "gcloud-aio-storage", marker = "extra == 'gcs'", specifier = "~=9.3.0" }, { name = "numpy", specifier = "~=1.24.0" }, { name = "ragbits", editable = "packages/ragbits-core" }, - { name = "unstructured", specifier = ">=0.15.12" }, + { name = "unstructured", specifier = ">=0.15.13" }, + { name = "unstructured-client", specifier = ">=0.26.0" }, ] [package.metadata.requires-dev] @@ -3665,7 +3655,7 @@ name = "triton" version = "2.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "python_full_version < '3.13'" }, + { name = "filelock", marker = "python_full_version < '3.12'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/95/05/ed974ce87fe8c8843855daa2136b3409ee1c126707ab54a8b72815c08b49/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5", size = 167900779 }, @@ -3753,33 +3743,23 @@ wheels = [ [[package]] name = "unstructured-client" -version = "0.25.9" +version = "0.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "certifi" }, - { name = "charset-normalizer" }, { name = "cryptography" }, - { name = "dataclasses-json" }, - { name = "deepdiff" }, + { name = "eval-type-backport" }, { name = "httpx" }, - { name = "idna" }, { name = "jsonpath-python" }, - { name = "marshmallow" }, - { name = "mypy-extensions" }, { name = "nest-asyncio" }, - { name = "packaging" }, + { name = "pydantic" }, { name = "pypdf" }, { name = "python-dateutil" }, - { name = "requests" }, { name = "requests-toolbelt" }, - { name = "six" }, - { name = "typing-extensions" }, { name = "typing-inspect" }, - { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/b2/1214a9391951754a770d6be81a67093e827a198f610dbaa971fea2b65a3a/unstructured-client-0.25.9.tar.gz", hash = "sha256:fcc461623f58fefb0e22508e28bf653a8f6934b9779cb4a90dd68d77a39fb5b2", size = 39986 } +sdist = { url = "https://files.pythonhosted.org/packages/aa/d7/7c3a2d484c08d6cee284e808114a5ae0c86fcb61c0542840cfdc4ac6def9/unstructured_client-0.26.0.tar.gz", hash = "sha256:d52ffde62ef06464a50b947e83b537b9a89a6118857442ed901315515b7d9918", size = 45751 } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/71/f0e594858f251ee2ac2edfe532714fd47afcc4e9294a3862a7c19ec13cf6/unstructured_client-0.25.9-py3-none-any.whl", hash = "sha256:c984c01878c8fc243be7c842467d1113a194d885ab6396ae74258ee42717c5b5", size = 45296 }, + { url = "https://files.pythonhosted.org/packages/f6/50/04e6712de68500220e81a496cd8b480015cbd041206cd08358882f400a78/unstructured_client-0.26.0-py3-none-any.whl", hash = "sha256:399b69441b5473ee4cdee38a0208573a4b646c02566e01eab1108066381a2914", size = 59673 }, ] [[package]]