From f7f79a973205766c9eb67d18b0ac07f7b3285413 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Wed, 16 Oct 2024 15:51:06 +0200 Subject: [PATCH 1/4] feat(document-search): batch ingestion from sources --- .../src/ragbits/document_search/_main.py | 55 ++++++++++++++++--- .../document_search/documents/sources.py | 42 ++++++++++++++ .../ingestion/providers/dummy.py | 2 - .../tests/unit/example_files/bar.md | 1 + .../tests/unit/example_files/foo.md | 1 + .../tests/unit/example_files/lorem.txt | 1 + .../tests/unit/test_document_search.py | 14 +++++ .../tests/unit/test_gcs_source.py | 7 --- .../tests/unit/test_local_file_source.py | 21 +++++++ uv.lock | 4 ++ 10 files changed, 130 insertions(+), 18 deletions(-) create mode 100644 packages/ragbits-document-search/tests/unit/example_files/bar.md create mode 100644 packages/ragbits-document-search/tests/unit/example_files/foo.md create mode 100644 packages/ragbits-document-search/tests/unit/example_files/lorem.txt create mode 100644 packages/ragbits-document-search/tests/unit/test_local_file_source.py diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index dd92caeee..fb9cb4f5e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Optional, Sequence, Union from pydantic import BaseModel, Field @@ -104,20 +104,20 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig()) return self.reranker.rerank(elements) - async def ingest_document( + async def _process_document( self, document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]], document_processor: Optional[BaseProvider] = None, - ) -> None: + ) -> list[Element]: """ - Ingest a document. + Process a document and return the elements. Args: - document: The document or metadata of the document to ingest. - document_processor: The document processor to use. If not provided, the document processor will be - determined based on the document metadata. - """ + document: The document to process. + Returns: + The elements. + """ if isinstance(document, Source): document_meta = await DocumentMeta.from_source(document) elif isinstance(document, DocumentMeta): @@ -128,7 +128,44 @@ async def ingest_document( if document_processor is None: document_processor = self.document_processor_router.get_provider(document_meta) - elements = await document_processor.process(document_meta) + document_processor = self.document_processor_router.get_provider(document_meta) + return await document_processor.process(document_meta) + + async def ingest_document( + self, + document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]], + document_processor: Optional[BaseProvider] = None, + ) -> None: + """ + Ingest a document. + + Args: + document: The document or metadata of the document to ingest. + document_processor: The document processor to use. If not provided, the document processor will be + determined based on the document metadata. + """ + + elements = await self._process_document(document, document_processor) + await self.insert_elements(elements) + + async def ingest_documents( + self, + documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]], + document_processor: Optional[BaseProvider] = None, + ) -> None: + """ + Ingest multiple documents. + + Args: + documents: The documents or metadata of the documents to ingest. + document_processor: The document processor to use. If not provided, the document processor will be + determined based on the document metadata. + """ + + elements = [] + # TODO: Parallelize + for document in documents: + elements.extend(await self._process_document(document, document_processor)) await self.insert_elements(elements) async def insert_elements(self, elements: list[Element]) -> None: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py index fc5a93a83..73e50df80 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py @@ -66,6 +66,23 @@ async def fetch(self) -> Path: """ return self.path + @classmethod + def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]: + """ + List all sources in the given directory, matching the file pattern. + + Args: + path: The path to the directory. + file_pattern: The file pattern to match. + + Returns: + List of source objects. + """ + sources = [] + for file_path in path.glob(file_pattern): + sources.append(cls(path=file_path)) + return sources + class GCSSource(Source): """ @@ -122,3 +139,28 @@ async def fetch(self) -> Path: file_object.write(content) return path + + @classmethod + async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]: + """ + List all sources in the given GCS bucket, matching the prefix. + + Args: + bucket: The GCS bucket. + prefix: The prefix to match. + + Returns: + List of source objects. + + Raises: + ImportError: If the required 'gcloud-aio-storage' package is not installed + """ + if not HAS_GCLOUD_AIO: + raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage") + + async with Storage() as client: + objects = await client.list_objects(bucket, params={"prefix": prefix}) + sources = [] + for obj in objects["items"]: + sources.append(cls(bucket=bucket, object_name=obj["name"])) + return sources diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py index 712f31a09..5867be9f9 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py @@ -6,8 +6,6 @@ class DummyProvider(BaseProvider): """This is a mock provider that returns a TextElement with the content of the document. It should be used for testing purposes only. - - TODO: Remove this provider after the implementation of the real providers. """ SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT} diff --git a/packages/ragbits-document-search/tests/unit/example_files/bar.md b/packages/ragbits-document-search/tests/unit/example_files/bar.md new file mode 100644 index 000000000..ba0e162e1 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/example_files/bar.md @@ -0,0 +1 @@ +bar \ No newline at end of file diff --git a/packages/ragbits-document-search/tests/unit/example_files/foo.md b/packages/ragbits-document-search/tests/unit/example_files/foo.md new file mode 100644 index 000000000..191028156 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/example_files/foo.md @@ -0,0 +1 @@ +foo \ No newline at end of file diff --git a/packages/ragbits-document-search/tests/unit/example_files/lorem.txt b/packages/ragbits-document-search/tests/unit/example_files/lorem.txt new file mode 100644 index 000000000..0db1a7d5c --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/example_files/lorem.txt @@ -0,0 +1 @@ +lorem \ No newline at end of file diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index ba8695b11..b796df6a4 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -147,3 +147,17 @@ async def test_document_search_with_search_config(): assert len(results) == 1 assert results[0].content == "Name of Peppa's brother is George" + + +async def test_document_search_insert_documents(): + document_search = DocumentSearch.from_config(CONFIG) + examples_files = Path(__file__).parent / "example_files" + + await document_search.ingest_documents( + LocalFileSource.list_sources(examples_files, file_pattern="*.md"), + document_processor=DummyProvider(), + ) + + results = await document_search.search("foo") + assert len(results) == 2 + assert {result.content for result in results} == {"foo", "bar"} diff --git a/packages/ragbits-document-search/tests/unit/test_gcs_source.py b/packages/ragbits-document-search/tests/unit/test_gcs_source.py index da32b5a9f..6ffbe72e2 100644 --- a/packages/ragbits-document-search/tests/unit/test_gcs_source.py +++ b/packages/ragbits-document-search/tests/unit/test_gcs_source.py @@ -1,9 +1,6 @@ import os from pathlib import Path -import aiohttp -import pytest - from ragbits.document_search.documents.sources import GCSSource TEST_FILE_PATH = Path(__file__) @@ -16,7 +13,3 @@ async def test_gcs_source_fetch(): path = await source.fetch() assert path == TEST_FILE_PATH - - source = GCSSource(bucket="", object_name="not_found_file.py") - with pytest.raises(aiohttp.ClientConnectionError): - await source.fetch() diff --git a/packages/ragbits-document-search/tests/unit/test_local_file_source.py b/packages/ragbits-document-search/tests/unit/test_local_file_source.py new file mode 100644 index 000000000..a2cfe6cb5 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_local_file_source.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from ragbits.document_search.documents.sources import LocalFileSource + +TEST_FILE_PATH = Path(__file__) + + +async def test_local_source_fetch(): + source = LocalFileSource(path=TEST_FILE_PATH) + + path = await source.fetch() + assert path == TEST_FILE_PATH + + +async def test_local_source_list_sources(): + example_files = TEST_FILE_PATH.parent / "example_files" + sources = LocalFileSource.list_sources(example_files, file_pattern="*.md") + + assert len(sources) == 2 + assert all(isinstance(source, LocalFileSource) for source in sources) + assert all(source.path.suffix == ".md" for source in sources) diff --git a/uv.lock b/uv.lock index 6156c1afe..e2904e2dd 100644 --- a/uv.lock +++ b/uv.lock @@ -3119,6 +3119,9 @@ local = [ { name = "torch" }, { name = "transformers" }, ] +promptfoo = [ + { name = "pyyaml" }, +] [package.dev-dependencies] dev = [ @@ -3137,6 +3140,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" }, { name = "numpy", marker = "extra == 'local'", specifier = "~=1.24.0" }, { name = "pydantic", specifier = ">=2.9.1" }, + { name = "pyyaml", marker = "extra == 'promptfoo'", specifier = "~=6.0.2" }, { name = "tomli", specifier = "~=2.0.2" }, { name = "torch", marker = "extra == 'local'", specifier = "~=2.2.1" }, { name = "transformers", marker = "extra == 'local'", specifier = "~=4.44.2" }, From cfd96f63f1aa2e362e4fbfd8462b9642ec5ae5aa Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 17 Oct 2024 10:16:42 +0200 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Alan Konarski <129968242+akonarski-ds@users.noreply.github.com> --- .../src/ragbits/document_search/documents/sources.py | 5 +---- .../tests/unit/test_local_file_source.py | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py index 73e50df80..a4a78f05f 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py @@ -78,10 +78,7 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSou Returns: List of source objects. """ - sources = [] - for file_path in path.glob(file_pattern): - sources.append(cls(path=file_path)) - return sources + return [cls(path=file_path) for file_path in path.glob(file_pattern)] class GCSSource(Source): diff --git a/packages/ragbits-document-search/tests/unit/test_local_file_source.py b/packages/ragbits-document-search/tests/unit/test_local_file_source.py index a2cfe6cb5..bde7549e9 100644 --- a/packages/ragbits-document-search/tests/unit/test_local_file_source.py +++ b/packages/ragbits-document-search/tests/unit/test_local_file_source.py @@ -9,11 +9,13 @@ async def test_local_source_fetch(): source = LocalFileSource(path=TEST_FILE_PATH) path = await source.fetch() + assert path == TEST_FILE_PATH async def test_local_source_list_sources(): example_files = TEST_FILE_PATH.parent / "example_files" + sources = LocalFileSource.list_sources(example_files, file_pattern="*.md") assert len(sources) == 2 From 89c8d962e08628413a6e9e8dde9b9bc0a2552017 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 17 Oct 2024 10:33:04 +0200 Subject: [PATCH 3/4] Fix formatting --- .../tests/unit/test_local_file_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/tests/unit/test_local_file_source.py b/packages/ragbits-document-search/tests/unit/test_local_file_source.py index bde7549e9..f11b567a0 100644 --- a/packages/ragbits-document-search/tests/unit/test_local_file_source.py +++ b/packages/ragbits-document-search/tests/unit/test_local_file_source.py @@ -9,7 +9,7 @@ async def test_local_source_fetch(): source = LocalFileSource(path=TEST_FILE_PATH) path = await source.fetch() - + assert path == TEST_FILE_PATH From dd8939820f42119944b6caef3fa6e8b7ef5cd1ca Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 17 Oct 2024 11:37:46 +0200 Subject: [PATCH 4/4] Remove `ingest_document` --- docs/index.md | 3 +-- .../ragbits-core/examples/chromadb_example.py | 3 +-- .../examples/documents_chat.py | 2 +- .../examples/from_config_example.py | 3 +-- .../examples/simple_text.py | 3 +-- .../src/ragbits/document_search/_main.py | 19 +------------------ .../tests/unit/test_document_search.py | 18 +++++++++--------- 7 files changed, 15 insertions(+), 36 deletions(-) diff --git a/docs/index.md b/docs/index.md index 16a546000..5a17127ca 100644 --- a/docs/index.md +++ b/docs/index.md @@ -80,8 +80,7 @@ documents = [ async def main(): document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore()) - for document in documents: - await document_search.ingest_document(document) + await document_search.ingest(documents) return await document_search.search("I'm boiling my water and I need a joke") diff --git a/packages/ragbits-core/examples/chromadb_example.py b/packages/ragbits-core/examples/chromadb_example.py index 4abe81823..acd29a7a8 100644 --- a/packages/ragbits-core/examples/chromadb_example.py +++ b/packages/ragbits-core/examples/chromadb_example.py @@ -35,8 +35,7 @@ async def main(): ) document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store) - for document in documents: - await document_search.ingest_document(document) + await document_search.ingest(documents) results = await document_search.search("I'm boiling my water and I need a joke") print(results) diff --git a/packages/ragbits-document-search/examples/documents_chat.py b/packages/ragbits-document-search/examples/documents_chat.py index b5d4121dc..595577990 100644 --- a/packages/ragbits-document-search/examples/documents_chat.py +++ b/packages/ragbits-document-search/examples/documents_chat.py @@ -109,7 +109,7 @@ def _prepare_document_search(self, database_path: str, index_name: str) -> None: async def _create_database(self, document_paths: list[str]) -> str: for path in document_paths: - await self.document_search.ingest_document(DocumentMeta.from_local_path(Path(path))) + await self.document_search.ingest([DocumentMeta.from_local_path(Path(path))]) self._documents_ingested = True return self.DATABASE_CREATED_MESSAGE + self._database_path diff --git a/packages/ragbits-document-search/examples/from_config_example.py b/packages/ragbits-document-search/examples/from_config_example.py index 1599cf84a..89ab7e5b4 100644 --- a/packages/ragbits-document-search/examples/from_config_example.py +++ b/packages/ragbits-document-search/examples/from_config_example.py @@ -40,8 +40,7 @@ async def main(): document_search = DocumentSearch.from_config(config) - for document in documents: - await document_search.ingest_document(document) + await document_search.ingest(documents) results = await document_search.search("I'm boiling my water and I need a joke") print(results) diff --git a/packages/ragbits-document-search/examples/simple_text.py b/packages/ragbits-document-search/examples/simple_text.py index c7cd934a2..973500b4e 100644 --- a/packages/ragbits-document-search/examples/simple_text.py +++ b/packages/ragbits-document-search/examples/simple_text.py @@ -28,8 +28,7 @@ async def main(): document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore()) - for document in documents: - await document_search.ingest_document(document) + await document_search.ingest(documents) results = await document_search.search("I'm boiling my water and I need a joke") print(results) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index fb9cb4f5e..8a72f1746 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -131,24 +131,7 @@ async def _process_document( document_processor = self.document_processor_router.get_provider(document_meta) return await document_processor.process(document_meta) - async def ingest_document( - self, - document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]], - document_processor: Optional[BaseProvider] = None, - ) -> None: - """ - Ingest a document. - - Args: - document: The document or metadata of the document to ingest. - document_processor: The document processor to use. If not provided, the document processor will be - determined based on the document metadata. - """ - - elements = await self._process_document(document, document_processor) - await self.insert_elements(elements) - - async def ingest_documents( + async def ingest( self, documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]], document_processor: Optional[BaseProvider] = None, diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index b796df6a4..173345ab4 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -40,7 +40,7 @@ async def test_document_search_from_config(document, expected): document_search = DocumentSearch.from_config(CONFIG) - await document_search.ingest_document(document) + await document_search.ingest([document]) results = await document_search.search("Peppa's brother") first_result = results[0] @@ -49,7 +49,7 @@ async def test_document_search_from_config(document, expected): assert first_result.content == expected -async def test_document_search_ingest_document_from_source(): +async def test_document_search_ingest_from_source(): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] @@ -66,7 +66,7 @@ async def test_document_search_ingest_document_from_source(): source = LocalFileSource(path=Path(f.name)) - await document_search.ingest_document(source) + await document_search.ingest([source]) results = await document_search.search("Peppa's brother") @@ -85,13 +85,13 @@ async def test_document_search_ingest_document_from_source(): ), ], ) -async def test_document_search_ingest_document(document: Union[DocumentMeta, Document]): +async def test_document_search_ingest(document: Union[DocumentMeta, Document]): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore()) - await document_search.ingest_document(document, document_processor=DummyProvider()) + await document_search.ingest([document], document_processor=DummyProvider()) results = await document_search.search("Peppa's brother") @@ -138,8 +138,8 @@ async def test_document_search_with_search_config(): document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore()) - await document_search.ingest_document( - DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"), + await document_search.ingest( + [DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")], document_processor=DummyProvider(), ) @@ -149,11 +149,11 @@ async def test_document_search_with_search_config(): assert results[0].content == "Name of Peppa's brother is George" -async def test_document_search_insert_documents(): +async def test_document_search_ingest_multiple_from_sources(): document_search = DocumentSearch.from_config(CONFIG) examples_files = Path(__file__).parent / "example_files" - await document_search.ingest_documents( + await document_search.ingest( LocalFileSource.list_sources(examples_files, file_pattern="*.md"), document_processor=DummyProvider(), )