Skip to content

Commit

Permalink
feat(document-search): determine document type automatically (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla authored Oct 16, 2024
1 parent f8771b3 commit d61725f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ragbits.core.vector_store import VectorStore, get_vector_store
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.documents.sources import GCSSource, LocalFileSource, Source
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.retrieval.rephrasers import get_rephraser
Expand Down Expand Up @@ -104,7 +105,9 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
return self.reranker.rerank(elements)

async def ingest_document(
self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None
self,
document: Union[DocumentMeta, Document, Union[LocalFileSource, GCSSource]],
document_processor: Optional[BaseProvider] = None,
) -> None:
"""
Ingest a document.
Expand All @@ -114,7 +117,14 @@ async def ingest_document(
document_processor: The document processor to use. If not provided, the document processor will be
determined based on the document metadata.
"""
document_meta = document if isinstance(document, DocumentMeta) else document.metadata

if isinstance(document, Source):
document_meta = await DocumentMeta.from_source(document)
elif isinstance(document, DocumentMeta):
document_meta = document
else:
document_meta = document.metadata

if document_processor is None:
document_processor = self.document_processor_router.get_provider(document_meta)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
source=LocalFileSource(path=local_path),
)

@classmethod
async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
    """
    Build document metadata for a document located at the given source.

    The source is fetched first, and the document type is then looked up from
    the suffix of the fetched path.

    Args:
        source: The source from which the document is fetched.

    Returns:
        The document metadata.
    """
    fetched_path = await source.fetch()

    # The suffix carries one leading character (the dot) that is not part of
    # the document type value, so it is sliced off before the lookup.
    extension = fetched_path.suffix[1:]
    detected_type = DocumentType(extension)

    return cls(document_type=detected_type, source=source)


class Document(BaseModel):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
from pathlib import Path
from typing import Union
from unittest.mock import AsyncMock
Expand All @@ -7,8 +8,10 @@
from ragbits.core.vector_store.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search._main import SearchConfig
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyProvider

CONFIG = {
Expand Down Expand Up @@ -46,6 +49,33 @@ async def test_document_search_from_config(document, expected):
assert first_result.content == expected


async def test_document_search_ingest_document_from_source():
    """Ingesting a bare ``LocalFileSource`` makes the file's text searchable."""
    expected_text = "Name of Peppa's brother is George"

    # Embedder stub: every embed_text call yields the same single vector.
    embedder = AsyncMock()
    embedder.embed_text.return_value = [[0.1, 0.1]]

    # Route TXT documents to the dummy provider.
    processor_router = DocumentProcessorRouter.from_config({DocumentType.TXT: DummyProvider()})

    search_engine = DocumentSearch(
        embedder=embedder,
        vector_store=InMemoryVectorStore(),
        document_processor_router=processor_router,
    )

    with tempfile.NamedTemporaryFile(suffix=".txt") as tmp_file:
        tmp_file.write(expected_text.encode())
        tmp_file.seek(0)

        # Only the source is passed in; no DocumentMeta is built by the test itself.
        await search_engine.ingest_document(LocalFileSource(path=Path(tmp_file.name)))

        hits = await search_engine.search("Peppa's brother")
        top_hit = hits[0]

        assert isinstance(top_hit, TextElement)
        assert top_hit.content == expected_text


@pytest.mark.parametrize(
"document",
[
Expand Down

0 comments on commit d61725f

Please sign in to comment.