Skip to content

Commit

Permalink
feat(document-search): batch ingestion from sources (#112)
Browse files Browse the repository at this point in the history
Co-authored-by: Alan Konarski <[email protected]>
  • Loading branch information
ludwiktrammer and akonarski-ds authored Oct 17, 2024
1 parent bf7d2c6 commit daf4926
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 27 deletions.
3 changes: 1 addition & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ documents = [
async def main():
document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore())

for document in documents:
await document_search.ingest_document(document)
await document_search.ingest(documents)

return await document_search.search("I'm boiling my water and I need a joke")

Expand Down
2 changes: 1 addition & 1 deletion examples/apps/documents_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _prepare_document_search(self, database_path: str, index_name: str) -> None:

async def _create_database(self, document_paths: list[str]) -> str:
for path in document_paths:
await self.document_search.ingest_document(DocumentMeta.from_local_path(Path(path)))
await self.document_search.ingest([DocumentMeta.from_local_path(Path(path))])
self._documents_ingested = True
return self.DATABASE_CREATED_MESSAGE + self._database_path

Expand Down
3 changes: 1 addition & 2 deletions examples/document-search/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ async def main():

document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore())

for document in documents:
await document_search.ingest_document(document)
await document_search.ingest(documents)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)
Expand Down
3 changes: 1 addition & 2 deletions examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ async def main():
)
document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store)

for document in documents:
await document_search.ingest_document(document)
await document_search.ingest(documents)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)
Expand Down
3 changes: 1 addition & 2 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ async def main():

document_search = DocumentSearch.from_config(config)

for document in documents:
await document_search.ingest_document(document)
await document_search.ingest(documents)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Optional, Sequence, Union

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -104,20 +104,20 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())

return self.reranker.rerank(elements)

async def ingest_document(
async def _process_document(
self,
document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
document_processor: Optional[BaseProvider] = None,
) -> None:
) -> list[Element]:
"""
Ingest a document.
Process a document and return the elements.
Args:
document: The document or metadata of the document to ingest.
document_processor: The document processor to use. If not provided, the document processor will be
determined based on the document metadata.
"""
document: The document to process.
Returns:
The elements.
"""
if isinstance(document, Source):
document_meta = await DocumentMeta.from_source(document)
elif isinstance(document, DocumentMeta):
Expand All @@ -128,7 +128,27 @@ async def ingest_document(
if document_processor is None:
document_processor = self.document_processor_router.get_provider(document_meta)

elements = await document_processor.process(document_meta)
document_processor = self.document_processor_router.get_provider(document_meta)
return await document_processor.process(document_meta)

async def ingest(
    self,
    documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]],
    document_processor: Optional[BaseProvider] = None,
) -> None:
    """
    Ingest multiple documents.

    Every document is processed into elements and all resulting elements are
    inserted into the vector store in one batch.

    Args:
        documents: The documents or metadata of the documents to ingest.
        document_processor: The document processor to use. If not provided, the document processor will be
            determined based on the document metadata.
    """
    import asyncio  # local import: avoids touching the module's import block

    # Process all documents concurrently (resolves the previous
    # "TODO: Parallelize"). asyncio.gather preserves input order, so the
    # flattened element list matches what the sequential loop produced.
    per_document = await asyncio.gather(
        *(self._process_document(document, document_processor) for document in documents)
    )
    elements = [element for document_elements in per_document for element in document_elements]
    await self.insert_elements(elements)

async def insert_elements(self, elements: list[Element]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ async def fetch(self) -> Path:
raise SourceNotFoundError(source_id=self.id)
return self.path

@classmethod
def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]:
    """
    List all sources in the given directory, matching the file pattern.

    Args:
        path: The path to the directory.
        file_pattern: The file pattern to match.

    Returns:
        List of source objects, one per matching regular file.
    """
    # Skip directories and other non-file matches: with a broad pattern such
    # as "*", path.glob() also yields subdirectories, which would produce
    # sources that cannot be fetched or processed.
    return [cls(path=file_path) for file_path in path.glob(file_pattern) if file_path.is_file()]


class GCSSource(Source):
"""
Expand Down Expand Up @@ -127,6 +141,29 @@ async def fetch(self) -> Path:

return path

@classmethod
@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
    """
    List all sources in the given GCS bucket, matching the prefix.

    Args:
        bucket: The GCS bucket.
        prefix: The prefix to match.

    Returns:
        List of source objects.

    Raises:
        ImportError: If the required 'gcloud-aio-storage' package is not installed
    """
    # NOTE(review): @classmethod must be the outermost decorator. The previous
    # order applied requires_dependencies around the classmethod object itself,
    # which is not directly callable, so the wrapper would fail at call time.
    async with Storage() as client:
        listing = await client.list_objects(bucket, params={"prefix": prefix})
        # "items" is presumably absent when the listing is empty — treat that
        # as no sources rather than raising KeyError. TODO confirm against the
        # gcloud-aio-storage response shape.
        return [cls(bucket=bucket, object_name=obj["name"]) for obj in listing.get("items", [])]


class HuggingFaceSource(Source):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
class DummyProvider(BaseProvider):
"""This is a mock provider that returns a TextElement with the content of the document.
It should be used for testing purposes only.
TODO: Remove this provider after the implementation of the real providers.
"""

SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bar
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
lorem
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
async def test_document_search_from_config(document, expected):
document_search = DocumentSearch.from_config(CONFIG)

await document_search.ingest_document(document)
await document_search.ingest([document])
results = await document_search.search("Peppa's brother")

first_result = results[0]
Expand All @@ -49,7 +49,7 @@ async def test_document_search_from_config(document, expected):
assert first_result.content == expected


async def test_document_search_ingest_document_from_source():
async def test_document_search_ingest_from_source():
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

Expand All @@ -66,7 +66,7 @@ async def test_document_search_ingest_document_from_source():

source = LocalFileSource(path=Path(f.name))

await document_search.ingest_document(source)
await document_search.ingest([source])

results = await document_search.search("Peppa's brother")

Expand All @@ -85,13 +85,13 @@ async def test_document_search_ingest_document_from_source():
),
],
)
async def test_document_search_ingest_document(document: Union[DocumentMeta, Document]):
async def test_document_search_ingest(document: Union[DocumentMeta, Document]):
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())

await document_search.ingest_document(document, document_processor=DummyProvider())
await document_search.ingest([document], document_processor=DummyProvider())

results = await document_search.search("Peppa's brother")

Expand Down Expand Up @@ -138,12 +138,26 @@ async def test_document_search_with_search_config():

document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())

await document_search.ingest_document(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
await document_search.ingest(
[DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")],
document_processor=DummyProvider(),
)

results = await document_search.search("Peppa's brother", search_config=SearchConfig(vector_store_kwargs={"k": 1}))

assert len(results) == 1
assert results[0].content == "Name of Peppa's brother is George"


async def test_document_search_ingest_multiple_from_sources():
    """Ingesting sources listed from a directory indexes every matching file."""
    search = DocumentSearch.from_config(CONFIG)
    sources = LocalFileSource.list_sources(
        Path(__file__).parent / "example_files", file_pattern="*.md"
    )

    await search.ingest(sources, document_processor=DummyProvider())

    results = await search.search("foo")
    assert len(results) == 2
    assert {result.content for result in results} == {"foo", "bar"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

from ragbits.document_search.documents.sources import LocalFileSource

TEST_FILE_PATH = Path(__file__)


async def test_local_source_fetch():
    """fetch() on a local source resolves to exactly the path it was built with."""
    src = LocalFileSource(path=TEST_FILE_PATH)

    fetched = await src.fetch()

    assert fetched == TEST_FILE_PATH


async def test_local_source_list_sources():
    """list_sources() returns one LocalFileSource per markdown file in example_files."""
    sources = LocalFileSource.list_sources(
        TEST_FILE_PATH.parent / "example_files", file_pattern="*.md"
    )

    assert len(sources) == 2
    for source in sources:
        assert isinstance(source, LocalFileSource)
        assert source.path.suffix == ".md"
4 changes: 4 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit daf4926

Please sign in to comment.