feat(document-search): Implement document search public interface (#58)
akonarski-ds authored Oct 7, 2024 · 1 parent 704eef2 · commit abbcdb0
Showing 9 changed files with 566 additions and 424 deletions.
@@ -1,15 +1,29 @@
from typing import Any, Optional, Union

from pydantic import BaseModel, Field

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker
from ragbits.document_search.vector_store.base import VectorStore


class SearchConfig(BaseModel):
"""
Configuration for the search process.
"""

reranker_kwargs: dict[str, Any] = Field(default_factory=dict)
vector_store_kwargs: dict[str, Any] = Field(default_factory=dict)
embedder_kwargs: dict[str, Any] = Field(default_factory=dict)


class DocumentSearch:
"""
A main entrypoint to the DocumentSearch functionality.
@@ -36,18 +50,21 @@ def __init__(
vector_store: VectorStore,
query_rephraser: QueryRephraser | None = None,
reranker: Reranker | None = None,
document_processor_router: DocumentProcessorRouter | None = None,
) -> None:
self.embedder = embedder
self.vector_store = vector_store
self.query_rephraser = query_rephraser or NoopQueryRephraser()
self.reranker = reranker or NoopReranker()
self.document_processor_router = document_processor_router or DocumentProcessorRouter.from_config()

async def search(self, query: str) -> list[Element]:
async def search(self, query: str, search_config: SearchConfig = SearchConfig()) -> list[Element]:
"""
Search for the most relevant chunks for a query.
Args:
query: The query to search for.
search_config: The search configuration.
Returns:
A list of chunks.
@@ -56,23 +73,36 @@ async def search(self, query: str) -> list[Element]:
elements = []
for rephrased_query in queries:
search_vector = await self.embedder.embed_text([rephrased_query])
# TODO: search parameters should be configurable
entries = await self.vector_store.retrieve(search_vector[0], k=1)
entries = await self.vector_store.retrieve(search_vector[0], **search_config.vector_store_kwargs)
elements.extend([Element.from_vector_db_entry(entry) for entry in entries])

return self.reranker.rerank(elements)

async def ingest_document(self, document: DocumentMeta) -> None:
async def ingest_document(
self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None
) -> None:
"""
Ingest a document.
Args:
document: The document to ingest.
document: The document or metadata of the document to ingest.
document_processor: The document processor to use. If not provided, the document processor will be
determined based on the document metadata.
"""
document_meta = document if isinstance(document, DocumentMeta) else document.metadata
if document_processor is None:
document_processor = self.document_processor_router.get_provider(document_meta)

elements = await document_processor.process(document_meta)
await self.insert_elements(elements)

async def insert_elements(self, elements: list[Element]) -> None:
"""
# TODO: This is a placeholder implementation. It should be replaced with a real implementation.
Insert Elements into the vector store.
document_processor = DocumentProcessor.from_config({DocumentType.TXT: DummyProvider()})
elements = await document_processor.process(document)
Args:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.get_key() for element in elements])
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors)]
await self.vector_store.store(entries)
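
As a condensed illustration of the new public interface, the flow below is a minimal sketch adapted from the unit tests in this commit; the AsyncMock embedder and DummyProvider are stand-ins for a real Embeddings implementation and a real document provider, not part of a production setup.

import asyncio
from unittest.mock import AsyncMock

from ragbits.document_search import DocumentSearch
from ragbits.document_search._main import SearchConfig
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore


async def main() -> None:
    # A mocked embedder stands in for a real Embeddings implementation.
    embedder = AsyncMock()
    embedder.embed_text.return_value = [[0.1, 0.1]]

    document_search = DocumentSearch(embedder=embedder, vector_store=InMemoryVectorStore())

    # Ingest a literal text document, explicitly passing a provider instead of
    # letting the DocumentProcessorRouter pick one from the document type.
    await document_search.ingest_document(
        DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
        document_processor=DummyProvider(),
    )

    # Retrieval parameters are now passed via SearchConfig instead of being hard-coded.
    results = await document_search.search(
        "Peppa's brother", search_config=SearchConfig(vector_store_kwargs={"k": 1})
    )
    print(results[0].content)


asyncio.run(main())
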
@@ -2,12 +2,12 @@
from typing import Optional

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider

ProvidersConfig = dict[DocumentType, BaseProvider]


DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
DocumentType.TXT: UnstructuredProvider(),
DocumentType.MD: UnstructuredProvider(),
@@ -30,20 +30,21 @@
}


class DocumentProcessor:
class DocumentProcessorRouter:
"""
A class with an implementation of Document Processor, allowing to process documents.
The DocumentProcessorRouter is responsible for routing the document to the correct provider based on the document
metadata such as the document type.
"""

def __init__(self, providers: dict[DocumentType, BaseProvider]):
self._providers = providers

@classmethod
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessor":
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessorRouter":
"""
Create a DocumentProcessor from a configuration. If the configuration is not provided, the default configuration
will be used. If the configuration is provided, it will be merged with the default configuration, overriding
the default values for the document types that are defined in the configuration.
Create a DocumentProcessorRouter from a configuration. If the configuration is not provided, the default
configuration will be used. If the configuration is provided, it will be merged with the default configuration,
overriding the default values for the document types that are defined in the configuration.
Example of the configuration:
{
DocumentType.TXT: YourCustomProviderClass(),
@@ -55,30 +56,27 @@ def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "Doc
provider class.
Returns:
The DocumentProcessor.
The DocumentProcessorRouter.
"""
config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
config.update(providers_config if providers_config is not None else {})

return cls(providers=config)

async def process(self, document_meta: DocumentMeta) -> list[Element]:
def get_provider(self, document_meta: DocumentMeta) -> BaseProvider:
"""
Process the document.
Get the provider for the document.
Args:
document_meta: The document to process.
document_meta: The document metadata.
Returns:
The list of elements extracted from the document.
The provider for processing the document.
Raises:
ValueError: If the provider for the document type is not defined in the configuration.
ValueError: If no provider is found for the document type.
"""
provider = self._providers.get(document_meta.document_type)
if provider is None:
raise ValueError(
f"Provider for {document_meta.document_type} is not defined in the configuration:" f" {self._providers}"
)

return await provider.process(document_meta)
raise ValueError(f"No provider found for the document type {document_meta.document_type}")
return provider
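
For reference, a minimal sketch of using the router on its own, based on the unit tests in this commit; DummyProvider here is only an illustrative override for TXT documents, standing in for a real provider such as UnstructuredProvider.

import asyncio

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyProvider


async def main() -> None:
    # Override the default provider for TXT documents; other document types keep the defaults.
    router = DocumentProcessorRouter.from_config({DocumentType.TXT: DummyProvider()})

    document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")

    # The router resolves a provider from document_meta.document_type and raises
    # ValueError when no provider is configured for that type.
    provider = router.get_provider(document_meta)
    elements = await provider.process(document_meta)
    print(elements[0].content)


asyncio.run(main())
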
@@ -0,0 +1,3 @@
# Ragbits

Repository for internal experiment with our upcoming LLM framework.
@@ -3,7 +3,7 @@
import pytest

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.unstructured import (
DEFAULT_PARTITION_KWARGS,
UNSTRUCTURED_API_KEY_ENV,
@@ -19,28 +19,28 @@
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
document_processor = DocumentProcessor.from_config()
document_processor = DocumentProcessorRouter.from_config()
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")

elements = await document_processor.process(document_meta)
elements = await document_processor.get_provider(document_meta).process(document_meta)

assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredProvider)
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George"
assert elements[0].content == "Name of Peppa's brother is George."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_API_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
)
async def test_document_processor_processes_md_document_with_unstructured_provider():
document_processor = DocumentProcessor.from_config()
document_meta = DocumentMeta.from_local_path(Path(__file__).parent.parent.parent.parent.parent / "README.md")
document_processor = DocumentProcessorRouter.from_config()
document_meta = DocumentMeta.from_local_path(Path(__file__).parent / "test_file.md")

elements = await document_processor.process(document_meta)
elements = await document_processor.get_provider(document_meta).process(document_meta)

assert len(elements) > 0
assert elements[0].content == "Ragbits"
assert len(elements) == 1
assert elements[0].content == "Ragbits\n\nRepository for internal experiment with our upcoming LLM framework."


@pytest.mark.skipif(
@@ -1,15 +1,29 @@
import pytest

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyProvider


async def test_document_processor_processes_text_document_with_dummy_provider():
providers_config = {DocumentType.TXT: DummyProvider()}
document_processor = DocumentProcessor.from_config(providers_config)
async def test_document_processor_router():
document_processor_router = DocumentProcessorRouter.from_config({DocumentType.TXT: DummyProvider()})

document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")

elements = await document_processor.process(document_meta)
document_processor = document_processor_router.get_provider(document_meta)

assert isinstance(document_processor, DummyProvider)


async def test_document_processor_router_raises_when_no_provider_found():
document_processor_router = DocumentProcessorRouter.from_config()
document_processor_router._providers = {DocumentType.TXT: DummyProvider()}

document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")

document_meta.document_type = DocumentType.PDF

with pytest.raises(ValueError) as err:
_ = document_processor_router.get_provider(document_meta)

assert isinstance(document_processor._providers[DocumentType.TXT], DummyProvider)
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George"
assert str(err.value) == f"No provider found for the document type {DocumentType.PDF}"
@@ -1,19 +1,55 @@
from pathlib import Path
from typing import Union
from unittest.mock import AsyncMock

import pytest

from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search._main import SearchConfig
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore


async def test_document_search():
@pytest.mark.parametrize(
"document",
[
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
Document.from_document_meta(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"), Path("test.txt")
),
],
)
async def test_document_search_ingest_document(document: Union[DocumentMeta, Document]):
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())

await document_search.ingest_document(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George")
await document_search.ingest_document(document, document_processor=DummyProvider())

results = await document_search.search("Peppa's brother")

first_result = results[0]

assert isinstance(first_result, TextElement)
assert first_result.content == "Name of Peppa's brother is George"


async def test_document_search_insert_elements():
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())

await document_search.insert_elements(
[
TextElement(
content="Name of Peppa's brother is George",
document_meta=DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
)
]
)

results = await document_search.search("Peppa's brother")
@@ -22,3 +58,28 @@ async def test_document_search():

assert isinstance(first_result, TextElement)
assert first_result.content == "Name of Peppa's brother is George"


async def test_document_search_with_no_results():
document_search = DocumentSearch(embedder=AsyncMock(), vector_store=InMemoryVectorStore())

results = await document_search.search("Peppa's sister")

assert not results


async def test_document_search_with_search_config():
embeddings_mock = AsyncMock()
embeddings_mock.embed_text.return_value = [[0.1, 0.1]]

document_search = DocumentSearch(embedder=embeddings_mock, vector_store=InMemoryVectorStore())

await document_search.ingest_document(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
document_processor=DummyProvider(),
)

results = await document_search.search("Peppa's brother", search_config=SearchConfig(vector_store_kwargs={"k": 1}))

assert len(results) == 1
assert results[0].content == "Name of Peppa's brother is George"
@@ -18,5 +18,5 @@ async def test_gcs_source_fetch():
assert path == TEST_FILE_PATH

source = GCSSource(bucket="", object_name="not_found_file.py")
with pytest.raises(aiohttp.ClientConnectorError):
with pytest.raises(aiohttp.ClientConnectionError):
await source.fetch()
3 changes: 0 additions & 3 deletions packages/ragbits-document-search/tests/unit/test_providers.py
@@ -1,5 +1,4 @@
import pytest
from dotenv import load_dotenv

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.ingestion.providers.base import DocumentTypeNotSupportedError
@@ -9,8 +8,6 @@
UnstructuredProvider,
)

load_dotenv()


@pytest.mark.parametrize("document_type", UnstructuredProvider.SUPPORTED_DOCUMENT_TYPES)
def test_unsupported_provider_validates_supported_document_types_passes(document_type: DocumentType):