feat(document-search): batch ingestion from sources #112

Merged (5 commits) on Oct 17, 2024
3 changes: 1 addition & 2 deletions docs/index.md
@@ -80,8 +80,7 @@ documents = [
 async def main():
     document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore())
 
-    for document in documents:
-        await document_search.ingest_document(document)
+    await document_search.ingest(documents)
 
     return await document_search.search("I'm boiling my water and I need a joke")
2 changes: 1 addition & 1 deletion examples/apps/documents_chat.py
@@ -109,7 +109,7 @@ def _prepare_document_search(self, database_path: str, index_name: str) -> None:
 
     async def _create_database(self, document_paths: list[str]) -> str:
         for path in document_paths:
-            await self.document_search.ingest_document(DocumentMeta.from_local_path(Path(path)))
+            await self.document_search.ingest([DocumentMeta.from_local_path(Path(path))])
         self._documents_ingested = True
         return self.DATABASE_CREATED_MESSAGE + self._database_path
3 changes: 1 addition & 2 deletions examples/document-search/basic.py
@@ -28,8 +28,7 @@ async def main():
 
     document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore())
 
-    for document in documents:
-        await document_search.ingest_document(document)
+    await document_search.ingest(documents)
 
     results = await document_search.search("I'm boiling my water and I need a joke")
     print(results)
3 changes: 1 addition & 2 deletions examples/document-search/chroma.py
@@ -35,8 +35,7 @@ async def main():
     )
     document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store)
 
-    for document in documents:
-        await document_search.ingest_document(document)
+    await document_search.ingest(documents)
 
     results = await document_search.search("I'm boiling my water and I need a joke")
     print(results)
3 changes: 1 addition & 2 deletions examples/document-search/from_config.py
@@ -40,8 +40,7 @@ async def main():
 
     document_search = DocumentSearch.from_config(config)
 
-    for document in documents:
-        await document_search.ingest_document(document)
+    await document_search.ingest(documents)
 
     results = await document_search.search("I'm boiling my water and I need a joke")
     print(results)
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Sequence, Union
 
 from pydantic import BaseModel, Field
 
@@ -104,20 +104,20 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
 
         return self.reranker.rerank(elements)
 
-    async def ingest_document(
+    async def _process_document(
         self,
         document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
         document_processor: Optional[BaseProvider] = None,
-    ) -> None:
+    ) -> list[Element]:
         """
-        Ingest a document.
+        Process a document and return the elements.
 
         Args:
-            document: The document or metadata of the document to ingest.
-            document_processor: The document processor to use. If not provided, the document processor will be
-                determined based on the document metadata.
-        """
+            document: The document to process.
+
+        Returns:
+            The elements.
+        """
         if isinstance(document, Source):
             document_meta = await DocumentMeta.from_source(document)
         elif isinstance(document, DocumentMeta):
@@ -128,7 +128,27 @@ async def ingest_document(
         if document_processor is None:
             document_processor = self.document_processor_router.get_provider(document_meta)
 
-        elements = await document_processor.process(document_meta)
+        document_processor = self.document_processor_router.get_provider(document_meta)
+        return await document_processor.process(document_meta)
+
+    async def ingest(
+        self,
+        documents: Sequence[DocumentMeta | Document | Union[LocalFileSource, GCSSource]],
+        document_processor: Optional[BaseProvider] = None,
+    ) -> None:
+        """
+        Ingest multiple documents.
+
+        Args:
+            documents: The documents or metadata of the documents to ingest.
+            document_processor: The document processor to use. If not provided, the document processor will be
+                determined based on the document metadata.
+        """
+
+        elements = []
+        # TODO: Parallelize
+        for document in documents:
+            elements.extend(await self._process_document(document, document_processor))
         await self.insert_elements(elements)
 
     async def insert_elements(self, elements: list[Element]) -> None:
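
With this change, callers pass a sequence to `ingest` instead of looping over `ingest_document`. A minimal sketch of the new call pattern, using the embedder and vector store from this PR's examples (the import paths are assumptions based on the package layout and may need adjusting):

```python
import asyncio

# NOTE: import paths are assumed from the ragbits package layout at the time of this PR.
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta


async def main() -> None:
    documents = [
        DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
        DocumentMeta.create_text_document_from_literal("Water boils at 100 degrees Celsius."),
    ]
    document_search = DocumentSearch(embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore())

    # One call processes every document and inserts all resulting elements
    # into the vector store together (LiteLLM needs API credentials at runtime).
    await document_search.ingest(documents)

    print(await document_search.search("Peppa's brother"))


asyncio.run(main())
```

The optional `document_processor` argument is still accepted and is forwarded to `_process_document` for every document in the batch; the `# TODO: Parallelize` comment signals that documents are currently processed sequentially.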
@@ -76,6 +76,20 @@ async def fetch(self) -> Path:
             raise SourceNotFoundError(source_id=self.id)
         return self.path
 
+    @classmethod
+    def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]:
+        """
+        List all sources in the given directory, matching the file pattern.
+
+        Args:
+            path: The path to the directory.
+            file_pattern: The file pattern to match.
+
+        Returns:
+            List of source objects.
+        """
+        return [cls(path=file_path) for file_path in path.glob(file_pattern)]
+
 
 class GCSSource(Source):
     """
@@ -127,6 +141,29 @@ async def fetch(self) -> Path:
 
         return path
 
+    @requires_dependencies(["gcloud.aio.storage"], "gcs")
+    @classmethod
+    async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
+        """
+        List all sources in the given GCS bucket, matching the prefix.
+
+        Args:
+            bucket: The GCS bucket.
+            prefix: The prefix to match.
+
+        Returns:
+            List of source objects.
+
+        Raises:
+            ImportError: If the required 'gcloud-aio-storage' package is not installed.
+        """
+        async with Storage() as client:
+            objects = await client.list_objects(bucket, params={"prefix": prefix})
+            sources = []
+            for obj in objects["items"]:
+                sources.append(cls(bucket=bucket, object_name=obj["name"]))
+            return sources
+
 
 class HuggingFaceSource(Source):
     """
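
The new `list_sources` classmethods pair naturally with batch `ingest`: they turn a directory or a bucket into the sequence of sources it expects. A short sketch under assumptions (the `docs` directory, bucket name, and prefix are placeholders; `GCSSource` is imported from the same module as `LocalFileSource`):

```python
import asyncio
from pathlib import Path

from ragbits.document_search.documents.sources import GCSSource, LocalFileSource


async def main() -> None:
    # Local listing is synchronous: it is just a glob over the filesystem.
    local_sources = LocalFileSource.list_sources(Path("docs"), file_pattern="*.md")

    # GCS listing is asynchronous and requires the optional
    # gcloud-aio-storage dependency; "my-bucket" is a placeholder.
    gcs_sources = await GCSSource.list_sources(bucket="my-bucket", prefix="docs/")

    print(len(local_sources), len(gcs_sources))


asyncio.run(main())
```

The sync/async split mirrors the corresponding `fetch` implementations: a local glob is cheap, while GCS listing performs network I/O.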
@@ -6,8 +6,6 @@
 class DummyProvider(BaseProvider):
     """This is a mock provider that returns a TextElement with the content of the document.
     It should be used for testing purposes only.
-
-    TODO: Remove this provider after the implementation of the real providers.
     """
 
     SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}
@@ -0,0 +1 @@
+bar
@@ -0,0 +1 @@
+foo
@@ -0,0 +1 @@
+lorem
@@ -40,7 +40,7 @@
 async def test_document_search_from_config(document, expected):
     document_search = DocumentSearch.from_config(CONFIG)
 
-    await document_search.ingest_document(document)
+    await document_search.ingest([document])
     results = await document_search.search("Peppa's brother")
 
     first_result = results[0]
@@ -49,7 +49,7 @@ async def test_document_search_from_config(document, expected):
     assert first_result.content == expected
 
 
-async def test_document_search_ingest_document_from_source():
+async def test_document_search_ingest_from_source():
     embeddings_mock = AsyncMock()
     embeddings_mock.embed_text.return_value = [[0.1, 0.1]]
 
@@ -66,7 +66,7 @@ async def test_document_search_ingest_document_from_source():
 
     source = LocalFileSource(path=Path(f.name))
 
-    await document_search.ingest_document(source)
+    await document_search.ingest([source])
 
     results = await document_search.search("Peppa's brother")
 
@@ -85,13 +85,13 @@ async def test_document_search_ingest_document_from_source():
         ),
     ],
 )
-async def test_document_search_ingest_document(document: Union[DocumentMeta, Document]):
+async def test_document_search_ingest(document: Union[DocumentMeta, Document]):
     embeddings_mock = AsyncMock()
     embeddings_mock.embed_text.return_value = [[0.1, 0.1]]
 
     document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())
 
-    await document_search.ingest_document(document, document_processor=DummyProvider())
+    await document_search.ingest([document], document_processor=DummyProvider())
 
     results = await document_search.search("Peppa's brother")
 
@@ -138,12 +138,26 @@ async def test_document_search_with_search_config():
 
     document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())
 
-    await document_search.ingest_document(
-        DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
+    await document_search.ingest(
+        [DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")],
         document_processor=DummyProvider(),
     )
 
     results = await document_search.search("Peppa's brother", search_config=SearchConfig(vector_store_kwargs={"k": 1}))
 
     assert len(results) == 1
     assert results[0].content == "Name of Peppa's brother is George"
+
+
+async def test_document_search_ingest_multiple_from_sources():
+    document_search = DocumentSearch.from_config(CONFIG)
+    examples_files = Path(__file__).parent / "example_files"
+
+    await document_search.ingest(
+        LocalFileSource.list_sources(examples_files, file_pattern="*.md"),
+        document_processor=DummyProvider(),
+    )
+
+    results = await document_search.search("foo")
+    assert len(results) == 2
+    assert {result.content for result in results} == {"foo", "bar"}
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+from ragbits.document_search.documents.sources import LocalFileSource
+
+TEST_FILE_PATH = Path(__file__)
+
+
+async def test_local_source_fetch():
+    source = LocalFileSource(path=TEST_FILE_PATH)
+
+    path = await source.fetch()
+
+    assert path == TEST_FILE_PATH
+
+
+async def test_local_source_list_sources():
+    example_files = TEST_FILE_PATH.parent / "example_files"
+
+    sources = LocalFileSource.list_sources(example_files, file_pattern="*.md")
+
+    assert len(sources) == 2
+    assert all(isinstance(source, LocalFileSource) for source in sources)
+    assert all(source.path.suffix == ".md" for source in sources)
4 changes: 4 additions & 0 deletions uv.lock

Some generated files are not rendered by default.