Skip to content

Commit

Permalink
feat(document-search): batch ingestion from sources
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 16, 2024
1 parent 17d7b3d commit f7f79a9
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -104,20 +104,20 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())

return self.reranker.rerank(elements)

async def ingest_document(
async def _process_document(
self,
document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
document_processor: Optional[BaseProvider] = None,
) -> None:
) -> list[Element]:
"""
Ingest a document.
Process a document and return the elements.
Args:
document: The document or metadata of the document to ingest.
document_processor: The document processor to use. If not provided, the document processor will be
determined based on the document metadata.
"""
document: The document to process.
Returns:
The elements.
"""
if isinstance(document, Source):
document_meta = await DocumentMeta.from_source(document)
elif isinstance(document, DocumentMeta):
Expand All @@ -128,7 +128,44 @@ async def ingest_document(
if document_processor is None:
document_processor = self.document_processor_router.get_provider(document_meta)

elements = await document_processor.process(document_meta)
document_processor = self.document_processor_router.get_provider(document_meta)
return await document_processor.process(document_meta)

async def ingest_document(
    self,
    document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
    document_processor: Optional[BaseProvider] = None,
) -> None:
    """
    Ingest a single document into the search index.

    The document is first processed into elements, which are then inserted
    into the underlying vector store.

    Args:
        document: The document or metadata of the document to ingest.
        document_processor: The document processor to use. If not provided, the document processor will be
            determined based on the document metadata.
    """
    processed_elements = await self._process_document(document, document_processor)
    await self.insert_elements(processed_elements)

async def ingest_documents(
    self,
    documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]],
    document_processor: Optional[BaseProvider] = None,
) -> None:
    """
    Ingest multiple documents into the search index.

    Documents are processed concurrently; the resulting elements are inserted
    in a single batch.

    Args:
        documents: The documents or metadata of the documents to ingest.
        document_processor: The document processor to use. If not provided, the document processor will be
            determined based on the document metadata.
    """
    import asyncio  # local import: keeps this change self-contained in the method

    # Process all documents concurrently (resolves the previous sequential
    # "TODO: Parallelize"). gather() preserves input order in its results,
    # so the inserted elements keep the same order as the sequential loop.
    per_document_elements = await asyncio.gather(
        *(self._process_document(document, document_processor) for document in documents)
    )
    elements = [element for group in per_document_elements for element in group]
    await self.insert_elements(elements)

async def insert_elements(self, elements: list[Element]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ async def fetch(self) -> Path:
"""
return self.path

@classmethod
def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]:
    """
    List all sources in the given directory, matching the file pattern.

    Args:
        path: The path to the directory.
        file_pattern: The file pattern to match (glob syntax, e.g. "*.md").

    Returns:
        List of source objects, one per matching file, in ``Path.glob`` order.
    """
    # Comprehension replaces the manual append loop; iteration order is
    # unchanged (same order Path.glob yields).
    return [cls(path=file_path) for file_path in path.glob(file_pattern)]


class GCSSource(Source):
"""
Expand Down Expand Up @@ -122,3 +139,28 @@ async def fetch(self) -> Path:
file_object.write(content)

return path

@classmethod
async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
    """
    List all sources in the given GCS bucket, matching the prefix.

    Args:
        bucket: The GCS bucket.
        prefix: The object-name prefix to match.

    Returns:
        List of source objects, one per matching object.

    Raises:
        ImportError: If the required 'gcloud-aio-storage' package is not installed
    """
    if not HAS_GCLOUD_AIO:
        raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

    async with Storage() as client:
        objects = await client.list_objects(bucket, params={"prefix": prefix})
        # An empty listing omits the "items" key entirely; fall back to an
        # empty list instead of raising KeyError.
        # NOTE(review): only the first page of results is read — paginated
        # listings (nextPageToken) are not followed; confirm whether large
        # buckets need pagination here.
        return [cls(bucket=bucket, object_name=obj["name"]) for obj in objects.get("items", [])]
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
class DummyProvider(BaseProvider):
"""This is a mock provider that returns a TextElement with the content of the document.
It should be used for testing purposes only.
TODO: Remove this provider after the implementation of the real providers.
"""

SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bar
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
lorem
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,17 @@ async def test_document_search_with_search_config():

assert len(results) == 1
assert results[0].content == "Name of Peppa's brother is George"


async def test_document_search_insert_documents():
    """Batch-ingest every markdown example file and verify both are searchable."""
    example_files_dir = Path(__file__).parent / "example_files"
    document_search = DocumentSearch.from_config(CONFIG)

    markdown_sources = LocalFileSource.list_sources(example_files_dir, file_pattern="*.md")
    await document_search.ingest_documents(markdown_sources, document_processor=DummyProvider())

    results = await document_search.search("foo")
    assert len(results) == 2
    assert {result.content for result in results} == {"foo", "bar"}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import os
from pathlib import Path

import aiohttp
import pytest

from ragbits.document_search.documents.sources import GCSSource

TEST_FILE_PATH = Path(__file__)
Expand All @@ -16,7 +13,3 @@ async def test_gcs_source_fetch():

path = await source.fetch()
assert path == TEST_FILE_PATH

source = GCSSource(bucket="", object_name="not_found_file.py")
with pytest.raises(aiohttp.ClientConnectionError):
await source.fetch()
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pathlib import Path

from ragbits.document_search.documents.sources import LocalFileSource

TEST_FILE_PATH = Path(__file__)


async def test_local_source_fetch():
    """Fetching a local file source resolves to the original file path."""
    source = LocalFileSource(path=TEST_FILE_PATH)

    fetched_path = await source.fetch()

    assert fetched_path == TEST_FILE_PATH


async def test_local_source_list_sources():
    """Listing *.md sources in example_files yields exactly the two markdown files."""
    sources_dir = TEST_FILE_PATH.parent / "example_files"

    found = LocalFileSource.list_sources(sources_dir, file_pattern="*.md")

    assert len(found) == 2
    for source in found:
        assert isinstance(source, LocalFileSource)
        assert source.path.suffix == ".md"
4 changes: 4 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f7f79a9

Please sign in to comment.