From f7f79a973205766c9eb67d18b0ac07f7b3285413 Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Wed, 16 Oct 2024 15:51:06 +0200
Subject: [PATCH] feat(document-search): batch ingestion from sources

---
Review notes (ignored by `git am`; not part of the commit message):
- _process_document no longer overwrites an explicitly passed
  document_processor with the router's choice; the router is only the
  fallback when no processor is given.
- GCSSource.list_sources tolerates a listing response without "items"
  (nothing matched the prefix) and returns an empty list.
- NOTE(review): test_document_search_insert_documents now really runs
  DummyProvider on the *.md fixtures - confirm the provider accepts them.
- The "index" blob ids below are those of the original commit.

 .../src/ragbits/document_search/_main.py      | 52 ++++++++++++++++---
 .../document_search/documents/sources.py      | 42 ++++++++++++++
 .../ingestion/providers/dummy.py              |  2 -
 .../tests/unit/example_files/bar.md           |  1 +
 .../tests/unit/example_files/foo.md           |  1 +
 .../tests/unit/example_files/lorem.txt        |  1 +
 .../tests/unit/test_document_search.py        | 14 +++++
 .../tests/unit/test_gcs_source.py             |  7 ---
 .../tests/unit/test_local_file_source.py      | 21 +++++++
 uv.lock                                       |  4 ++
 10 files changed, 129 insertions(+), 16 deletions(-)
 create mode 100644 packages/ragbits-document-search/tests/unit/example_files/bar.md
 create mode 100644 packages/ragbits-document-search/tests/unit/example_files/foo.md
 create mode 100644 packages/ragbits-document-search/tests/unit/example_files/lorem.txt
 create mode 100644 packages/ragbits-document-search/tests/unit/test_local_file_source.py

diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
index dd92caeee..fb9cb4f5e 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Sequence, Union
 
 from pydantic import BaseModel, Field
 
@@ -104,20 +104,22 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
 
         return self.reranker.rerank(elements)
 
-    async def ingest_document(
+    async def _process_document(
         self,
         document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
         document_processor: Optional[BaseProvider] = None,
-    ) -> None:
+    ) -> list[Element]:
         """
-        Ingest a document.
+        Process a document and return the elements.
 
         Args:
-            document: The document or metadata of the document to ingest.
+            document: The document to process.
             document_processor: The document processor to use. If not provided, the document processor will be
                 determined based on the document metadata.
-        """
 
+        Returns:
+            The elements.
+        """
         if isinstance(document, Source):
             document_meta = await DocumentMeta.from_source(document)
         elif isinstance(document, DocumentMeta):
@@ -128,7 +130,43 @@ async def ingest_document(
         if document_processor is None:
             document_processor = self.document_processor_router.get_provider(document_meta)
 
-        elements = await document_processor.process(document_meta)
+        return await document_processor.process(document_meta)
+
+    async def ingest_document(
+        self,
+        document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
+        document_processor: Optional[BaseProvider] = None,
+    ) -> None:
+        """
+        Ingest a document.
+
+        Args:
+            document: The document or metadata of the document to ingest.
+            document_processor: The document processor to use. If not provided, the document processor will be
+                determined based on the document metadata.
+        """
+
+        elements = await self._process_document(document, document_processor)
+        await self.insert_elements(elements)
+
+    async def ingest_documents(
+        self,
+        documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]],
+        document_processor: Optional[BaseProvider] = None,
+    ) -> None:
+        """
+        Ingest multiple documents.
+
+        Args:
+            documents: The documents or metadata of the documents to ingest.
+            document_processor: The document processor to use. If not provided, the document processor will be
+                determined based on the document metadata.
+        """
+
+        elements = []
+        # TODO: Parallelize
+        for document in documents:
+            elements.extend(await self._process_document(document, document_processor))
         await self.insert_elements(elements)
 
     async def insert_elements(self, elements: list[Element]) -> None:
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
index fc5a93a83..73e50df80 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
@@ -66,6 +66,23 @@ async def fetch(self) -> Path:
         """
         return self.path
 
+    @classmethod
+    def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]:
+        """
+        List all sources in the given directory, matching the file pattern.
+
+        Args:
+            path: The path to the directory.
+            file_pattern: The file pattern to match.
+
+        Returns:
+            List of source objects.
+        """
+        sources = []
+        for file_path in path.glob(file_pattern):
+            sources.append(cls(path=file_path))
+        return sources
+
 
 class GCSSource(Source):
     """
@@ -122,3 +139,28 @@ async def fetch(self) -> Path:
                     file_object.write(content)
 
         return path
+
+    @classmethod
+    async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
+        """
+        List all sources in the given GCS bucket, matching the prefix.
+
+        Args:
+            bucket: The GCS bucket.
+            prefix: The prefix to match.
+
+        Returns:
+            List of source objects.
+
+        Raises:
+            ImportError: If the required 'gcloud-aio-storage' package is not installed
+        """
+        if not HAS_GCLOUD_AIO:
+            raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")
+
+        async with Storage() as client:
+            objects = await client.list_objects(bucket, params={"prefix": prefix})
+            sources = []
+            for obj in objects.get("items", []):
+                sources.append(cls(bucket=bucket, object_name=obj["name"]))
+            return sources
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py
index 712f31a09..5867be9f9 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py
@@ -7,7 +7,5 @@ class DummyProvider(BaseProvider):
     """This is a mock provider that returns a TextElement with the content of the document.
     It should be used for testing purposes only.
-
-    TODO: Remove this provider after the implementation of the real providers.
     """
 
     SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}
diff --git a/packages/ragbits-document-search/tests/unit/example_files/bar.md b/packages/ragbits-document-search/tests/unit/example_files/bar.md
new file mode 100644
index 000000000..ba0e162e1
--- /dev/null
+++ b/packages/ragbits-document-search/tests/unit/example_files/bar.md
@@ -0,0 +1 @@
+bar
\ No newline at end of file
diff --git a/packages/ragbits-document-search/tests/unit/example_files/foo.md b/packages/ragbits-document-search/tests/unit/example_files/foo.md
new file mode 100644
index 000000000..191028156
--- /dev/null
+++ b/packages/ragbits-document-search/tests/unit/example_files/foo.md
@@ -0,0 +1 @@
+foo
\ No newline at end of file
diff --git a/packages/ragbits-document-search/tests/unit/example_files/lorem.txt b/packages/ragbits-document-search/tests/unit/example_files/lorem.txt
new file mode 100644
index 000000000..0db1a7d5c
--- /dev/null
+++ b/packages/ragbits-document-search/tests/unit/example_files/lorem.txt
@@ -0,0 +1 @@
+lorem
\ No newline at end of file
diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py
index ba8695b11..b796df6a4 100644
--- a/packages/ragbits-document-search/tests/unit/test_document_search.py
+++ b/packages/ragbits-document-search/tests/unit/test_document_search.py
@@ -147,3 +147,17 @@ async def test_document_search_with_search_config():
 
     assert len(results) == 1
     assert results[0].content == "Name of Peppa's brother is George"
+
+
+async def test_document_search_insert_documents():
+    document_search = DocumentSearch.from_config(CONFIG)
+    examples_files = Path(__file__).parent / "example_files"
+
+    await document_search.ingest_documents(
+        LocalFileSource.list_sources(examples_files, file_pattern="*.md"),
+        document_processor=DummyProvider(),
+    )
+
+    results = await document_search.search("foo")
+    assert len(results) == 2
+    assert {result.content for result in results} == {"foo", "bar"}
diff --git a/packages/ragbits-document-search/tests/unit/test_gcs_source.py b/packages/ragbits-document-search/tests/unit/test_gcs_source.py
index da32b5a9f..6ffbe72e2 100644
--- a/packages/ragbits-document-search/tests/unit/test_gcs_source.py
+++ b/packages/ragbits-document-search/tests/unit/test_gcs_source.py
@@ -1,9 +1,6 @@
 import os
 from pathlib import Path
 
-import aiohttp
-import pytest
-
 from ragbits.document_search.documents.sources import GCSSource
 
 TEST_FILE_PATH = Path(__file__)
@@ -16,7 +13,3 @@ async def test_gcs_source_fetch():
 
     path = await source.fetch()
     assert path == TEST_FILE_PATH
-
-    source = GCSSource(bucket="", object_name="not_found_file.py")
-    with pytest.raises(aiohttp.ClientConnectionError):
-        await source.fetch()
diff --git a/packages/ragbits-document-search/tests/unit/test_local_file_source.py b/packages/ragbits-document-search/tests/unit/test_local_file_source.py
new file mode 100644
index 000000000..a2cfe6cb5
--- /dev/null
+++ b/packages/ragbits-document-search/tests/unit/test_local_file_source.py
@@ -0,0 +1,21 @@
+from pathlib import Path
+
+from ragbits.document_search.documents.sources import LocalFileSource
+
+TEST_FILE_PATH = Path(__file__)
+
+
+async def test_local_source_fetch():
+    source = LocalFileSource(path=TEST_FILE_PATH)
+
+    path = await source.fetch()
+    assert path == TEST_FILE_PATH
+
+
+async def test_local_source_list_sources():
+    example_files = TEST_FILE_PATH.parent / "example_files"
+    sources = LocalFileSource.list_sources(example_files, file_pattern="*.md")
+
+    assert len(sources) == 2
+    assert all(isinstance(source, LocalFileSource) for source in sources)
+    assert all(source.path.suffix == ".md" for source in sources)
diff --git a/uv.lock b/uv.lock
index 6156c1afe..e2904e2dd 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3119,6 +3119,9 @@ local = [
     { name = "torch" },
     { name = "transformers" },
 ]
+promptfoo = [
+    { name = "pyyaml" },
+]
 
 [package.dev-dependencies]
 dev = [
@@ -3137,6 +3140,7 @@ requires-dist = [
     { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" },
     { name = "numpy", marker = "extra == 'local'", specifier = "~=1.24.0" },
     { name = "pydantic", specifier = ">=2.9.1" },
+    { name = "pyyaml", marker = "extra == 'promptfoo'", specifier = "~=6.0.2" },
     { name = "tomli", specifier = "~=2.0.2" },
     { name = "torch", marker = "extra == 'local'", specifier = "~=2.2.1" },
     { name = "transformers", marker = "extra == 'local'", specifier = "~=4.44.2" },