diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index fe3f826c2..dd92caeee 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -6,6 +6,7 @@ from ragbits.core.vector_store import VectorStore, get_vector_store from ragbits.document_search.documents.document import Document, DocumentMeta from ragbits.document_search.documents.element import Element +from ragbits.document_search.documents.sources import GCSSource, LocalFileSource, Source from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.base import BaseProvider from ragbits.document_search.retrieval.rephrasers import get_rephraser @@ -104,7 +105,9 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig()) return self.reranker.rerank(elements) async def ingest_document( - self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None + self, + document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]], + document_processor: Optional[BaseProvider] = None, ) -> None: """ Ingest a document. @@ -114,7 +117,14 @@ async def ingest_document( document_processor: The document processor to use. If not provided, the document processor will be determined based on the document metadata. """ - document_meta = document if isinstance(document, DocumentMeta) else document.metadata + + if isinstance(document, Source): + document_meta = await DocumentMeta.from_source(document) + elif isinstance(document, DocumentMeta): + document_meta = document + else: + document_meta = document.metadata + if document_processor is None: document_processor = self.document_processor_router.get_provider(document_meta) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py index 2ca2ec9a6..0d43df918 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py @@ -97,6 +97,24 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta": source=LocalFileSource(path=local_path), ) + @classmethod + async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta": + """ + Create a document metadata from a source. + + Args: + source: The source from which the document is fetched. + + Returns: + The document metadata. + """ + path = await source.fetch() + + return cls( + document_type=DocumentType(path.suffix[1:]), + source=source, + ) + class Document(BaseModel): """ diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 76743c890..ba8695b11 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,3 +1,4 @@ +import tempfile from pathlib import Path from typing import Union from unittest.mock import AsyncMock @@ -7,8 +8,10 @@ from ragbits.core.vector_store.in_memory import InMemoryVectorStore from ragbits.document_search import DocumentSearch from ragbits.document_search._main import SearchConfig -from ragbits.document_search.documents.document import Document, DocumentMeta +from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType from ragbits.document_search.documents.element import TextElement +from ragbits.document_search.documents.sources import LocalFileSource +from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.dummy import DummyProvider CONFIG = { @@ -46,6 +49,33 @@ async def test_document_search_from_config(document, expected): assert first_result.content == expected +async def test_document_search_ingest_document_from_source(): + embeddings_mock = AsyncMock() + embeddings_mock.embed_text.return_value = [[0.1, 0.1]] + + providers = {DocumentType.TXT: DummyProvider()} + router = DocumentProcessorRouter.from_config(providers) + + document_search = DocumentSearch( + embedder=embeddings_mock, vector_store=InMemoryVectorStore(), document_processor_router=router + ) + + with tempfile.NamedTemporaryFile(suffix=".txt") as f: + f.write(b"Name of Peppa's brother is George") + f.seek(0) + + source = LocalFileSource(path=Path(f.name)) + + await document_search.ingest_document(source) + + results = await document_search.search("Peppa's brother") + + first_result = results[0] + + assert isinstance(first_result, TextElement) + assert first_result.content == "Name of Peppa's brother is George" + + @pytest.mark.parametrize( "document", [