Skip to content

Commit

Permalink
feat(document-search): add document processing with unstructured (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
akonarski-ds authored Sep 24, 2024
1 parent 9439868 commit 1d743fb
Show file tree
Hide file tree
Showing 19 changed files with 1,144 additions and 114 deletions.
3 changes: 2 additions & 1 deletion .libraries-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pkg_resources
tiktoken
tiktoken
chardet
3 changes: 2 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ classifiers = [
]
dependencies = [
"numpy~=1.24.0",
"ragbits"
"ragbits",
"unstructured>=0.15.12",
]

[tool.uv]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers.base import Reranker
Expand Down Expand Up @@ -70,7 +71,7 @@ async def ingest_document(self, document: DocumentMeta) -> None:
"""
# TODO: This is a placeholder implementation. It should be replaced with a real implementation.

document_processor = DocumentProcessor()
document_processor = DocumentProcessor.from_config({DocumentType.TXT: DummyProvider()})
elements = await document_processor.process(document)
vectors = await self.embedder.embed_text([element.get_key() for element in elements])
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ class DocumentType(str, Enum):

MD = "md"
TXT = "txt"
PDF = "pdf"
CSV = "csv"
DOC = "doc"
DOCX = "docx"
HTML = "html"
EPUB = "epub"
XLSX = "xlsx"
XLS = "xls"
ORG = "org"
ODT = "odt"
PPT = "ppt"
PPTX = "pptx"
RST = "rst"
RTF = "rtf"
TSV = "tsv"
XML = "xml"

UNKNOWN = "unknown"


class DocumentMeta(BaseModel):
Expand Down Expand Up @@ -63,6 +81,22 @@ def create_text_document_from_literal(cls, content: str) -> "DocumentMeta":
source=LocalFileSource(path=Path(temp_file.name)),
)

@classmethod
def from_local_path(cls, local_path: Path) -> "DocumentMeta":
    """
    Create a document metadata from a local path.

    The document type is taken from the file extension. The extension is compared
    case-insensitively, so that e.g. ``report.PDF`` resolves to the same type as ``report.pdf``.

    Args:
        local_path: The local path to the document.

    Returns:
        The document metadata.

    Raises:
        ValueError: If the file extension does not match any DocumentType value.
    """
    # Drop the leading dot and normalise the case before the enum lookup;
    # the DocumentType values are all lowercase.
    extension = local_path.suffix[1:].lower()
    return cls(
        document_type=DocumentType(extension),
        source=LocalFileSource(path=local_path),
    )


class Document(BaseModel):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Element(BaseModel, ABC):
"""

element_type: str
document: DocumentMeta
document_meta: DocumentMeta

_elements_registry: ClassVar[dict[str, type["Element"]]] = {}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,84 @@
"""
TODO: This module is mocked. To be deleted and replaced with a real implementation.
"""
import copy
from typing import Optional

from typing import List
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider

from ragbits.document_search.documents.document import DocumentMeta, TextDocument
from ragbits.document_search.documents.element import Element, TextElement
ProvidersConfig = dict[DocumentType, BaseProvider]

# Document types that are routed to the Unstructured provider by default.
# The order is kept stable because it determines the order of the resulting mapping.
_DEFAULT_UNSTRUCTURED_TYPES: tuple[DocumentType, ...] = (
    DocumentType.TXT,
    DocumentType.MD,
    DocumentType.PDF,
    DocumentType.DOCX,
    DocumentType.DOC,
    DocumentType.PPTX,
    DocumentType.PPT,
    DocumentType.XLSX,
    DocumentType.XLS,
    DocumentType.CSV,
    DocumentType.HTML,
    DocumentType.EPUB,
    DocumentType.ORG,
    DocumentType.ODT,
    DocumentType.RST,
    DocumentType.RTF,
    DocumentType.TSV,
    DocumentType.XML,
)

# Default mapping of document type to provider; every type gets its own provider instance.
DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
    document_type: UnstructuredProvider() for document_type in _DEFAULT_UNSTRUCTURED_TYPES
}


class DocumentProcessor:
"""
A class with an implementation of Document Processor, allowing to process documents.
TODO: probably this one should be replaced with something more generic,
allowing for passing different processors for different document types.
"""

async def process(self, document_meta: DocumentMeta) -> List[Element]:
def __init__(self, providers: dict[DocumentType, BaseProvider]):
    """
    Store the mapping used to pick a provider for each document type.

    Args:
        providers: The providers to use, keyed by the document type they handle.
    """
    self._providers = providers

@classmethod
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessor":
    """
    Build a DocumentProcessor on top of the default providers configuration.

    The defaults are deep-copied first, so the module-level configuration is never modified.
    Entries from `providers_config`, when given, replace the defaults for the same document types;
    all other document types keep their default provider.

    Example of the configuration:
        {
            DocumentType.TXT: YourCustomProviderClass(),
            DocumentType.PDF: UnstructuredProvider(),
        }

    Args:
        providers_config: Optional mapping of document types to provider instances that overrides
            the defaults.

    Returns:
        The DocumentProcessor.
    """
    merged_providers = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
    if providers_config is not None:
        merged_providers.update(providers_config)
    return cls(providers=merged_providers)

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""
Process the document.
Args:
document_meta: The document to process.
Returns:
The processed elements.
"""
document = await document_meta.fetch()
The list of elements extracted from the document.
if isinstance(document, TextDocument):
# for now just return the whole document as a single element
return [TextElement(document=document_meta, content=document.content)]
Raises:
ValueError: If the provider for the document type is not defined in the configuration.
"""
provider = self._providers.get(document_meta.document_type)
if provider is None:
raise ValueError(
f"Provider for {document_meta.document_type} is not defined in the configuration:" f" {self._providers}"
)

return []
return await provider.process(document_meta)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element


class DocumentTypeNotSupportedError(Exception):
    """Signals that a provider was asked to handle a document type it does not support."""

    def __init__(self, provider_name: str, document_type: DocumentType) -> None:
        """Compose the error message.

        Args:
            provider_name: The name of the provider that rejected the document.
            document_type: The document type that was rejected.
        """
        super().__init__(f"Document type {document_type} is not supported by the {provider_name}")


class BaseProvider(ABC):
    """Common interface of the document processing providers."""

    # Document types accepted by the provider; only declared here, concrete providers assign it.
    SUPPORTED_DOCUMENT_TYPES: set[DocumentType]

    @abstractmethod
    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """Turn the document into a list of elements.

        Args:
            document_meta: The document to process.

        Returns:
            The list of elements extracted from the document.
        """

    def validate_document_type(self, document_type: DocumentType) -> None:
        """Ensure that the given document type is one the provider accepts.

        Args:
            document_type: The document type.

        Raises:
            DocumentTypeNotSupportedError: If the document type is not supported.
        """
        if document_type in self.SUPPORTED_DOCUMENT_TYPES:
            return
        # Report the concrete provider class, not the base class, in the error.
        raise DocumentTypeNotSupportedError(provider_name=type(self).__name__, document_type=document_type)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from ragbits.document_search.documents.document import DocumentMeta, DocumentType, TextDocument
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.ingestion.providers.base import BaseProvider


class DummyProvider(BaseProvider):
    """Mock provider that wraps the whole content of a text document in a single TextElement.

    Intended for testing purposes only.

    TODO: Remove this provider after the implementation of the real providers.
    """

    SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}

    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """Fetch the document and return its content as one element.

        Args:
            document_meta: The document to process.

        Returns:
            A single-item list with a TextElement holding the document content, or an empty list
            when the fetched document is not a TextDocument.

        Raises:
            DocumentTypeNotSupportedError: If the document type is not supported.
        """
        self.validate_document_type(document_meta.document_type)

        fetched = await document_meta.fetch()
        if not isinstance(fetched, TextDocument):
            return []
        return [TextElement(content=fetched.content, document_meta=document_meta)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.api import partition_via_api

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.ingestion.providers.base import BaseProvider

# Keyword arguments forwarded to `partition_via_api` when the provider is created without
# (or with empty) `partition_kwargs`.
DEFAULT_PARTITION_KWARGS: dict = {
"strategy": "hi_res",
"languages": ["eng"],
"split_pdf_page": True,
"split_pdf_allow_failed": True,
"split_pdf_concurrency_level": 15,
}

# Names of the environment variables from which the provider reads the API key and the API URL.
UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY"
UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL"


class UnstructuredProvider(BaseProvider):
    """
    A provider that uses the Unstructured API to process the documents.
    """

    SUPPORTED_DOCUMENT_TYPES = {
        DocumentType.TXT,
        DocumentType.MD,
        DocumentType.PDF,
        DocumentType.DOCX,
        DocumentType.DOC,
        DocumentType.PPTX,
        DocumentType.PPT,
        DocumentType.XLSX,
        DocumentType.XLS,
        DocumentType.CSV,
        DocumentType.HTML,
        DocumentType.EPUB,
        DocumentType.ORG,
        DocumentType.ODT,
        DocumentType.RST,
        DocumentType.RTF,
        DocumentType.TSV,
        DocumentType.XML,
    }

    def __init__(self, partition_kwargs: Optional[dict] = None):
        """Initialize the UnstructuredProvider.

        Args:
            partition_kwargs: The additional arguments for the partitioning. Refer to the Unstructured API
                documentation for the available options:
                https://docs.unstructured.io/api-reference/api-services/api-parameters
                When omitted (or empty), a copy of DEFAULT_PARTITION_KWARGS is used.
        """
        # Take a (shallow) copy of the defaults: otherwise every default-constructed provider would
        # share the module-level dict, and changing one instance's kwargs would change them all.
        self.partition_kwargs = partition_kwargs or dict(DEFAULT_PARTITION_KWARGS)

    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """Process the document using the Unstructured API.

        Args:
            document_meta: The document to process.

        Returns:
            The list of elements extracted from the document.

        Raises:
            ValueError: If the UNSTRUCTURED_API_KEY or UNSTRUCTURED_API_URL environment variables are not set.
            DocumentTypeNotSupportedError: If the document type is not supported.
        """
        # Imported locally so that this block does not depend on a change to the module imports.
        import asyncio

        self.validate_document_type(document_meta.document_type)
        if (api_key := os.getenv(UNSTRUCTURED_API_KEY_ENV)) is None:
            raise ValueError(f"{UNSTRUCTURED_API_KEY_ENV} environment variable is not set")
        if (api_url := os.getenv(UNSTRUCTURED_API_URL_ENV)) is None:
            raise ValueError(f"{UNSTRUCTURED_API_URL_ENV} environment variable is not set")

        document = await document_meta.fetch()
        local_path = document.local_path

        def _partition() -> list[UnstructuredElement]:
            # Both the file read and the API call are blocking, so they are done together
            # in the worker thread.
            return partition_via_api(
                file=BytesIO(local_path.read_bytes()),
                metadata_filename=local_path.name,
                api_key=api_key,
                api_url=api_url,
                **self.partition_kwargs,
            )

        # Run the blocking work in a separate thread to keep the event loop responsive.
        elements = await asyncio.to_thread(_partition)
        return [_to_text_element(element, document_meta) for element in elements]


def _to_text_element(element: UnstructuredElement, document_meta: DocumentMeta) -> TextElement:
    """Convert an Unstructured element into a TextElement.

    Args:
        element: The element returned by the Unstructured API.
        document_meta: The metadata of the document the element comes from.

    Returns:
        A TextElement carrying the text of the element and the document metadata.
    """
    extracted_text = element.text
    return TextElement(document_meta=document_meta, content=extracted_text)
5 changes: 5 additions & 0 deletions packages/ragbits-document-search/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


def env_vars_not_set(env_vars: list[str]) -> bool:
    """Tell whether none of the given environment variables is set.

    Args:
        env_vars: The names of the environment variables to check.

    Returns:
        True if every listed variable is missing from the environment (also for an empty list),
        False if at least one of them is set.
    """
    for name in env_vars:
        if os.environ.get(name) is not None:
            return False
    return True
Empty file.
Loading

0 comments on commit 1d743fb

Please sign in to comment.