Skip to content

Commit

Permalink
feat(document-search): add chromadb support
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Sep 25, 2024
2 parents 0610aa1 + 584fd93 commit 6a942a1
Show file tree
Hide file tree
Showing 21 changed files with 1,237 additions and 107 deletions.
3 changes: 2 additions & 1 deletion .libraries-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pkg_resources
tiktoken
tiktoken
chardet
7 changes: 5 additions & 2 deletions packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ classifiers = [
]
dependencies = [
"numpy~=1.24.0",
"ragbits"
"ragbits",
"unstructured>=0.15.12",
]


[project.optional-dependencies]
chromadb = [
"chromadb~=0.4.24",
]
gcs = [
"gcloud-aio-storage~=9.3.0"
]

[tool.uv]
dev-dependencies = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers.base import Reranker
Expand Down Expand Up @@ -70,7 +71,7 @@ async def ingest_document(self, document: DocumentMeta) -> None:
"""
# TODO: This is a placeholder implementation. It should be replaced with a real implementation.

document_processor = DocumentProcessor()
document_processor = DocumentProcessor.from_config({DocumentType.TXT: DummyProvider()})
elements = await document_processor.process(document)
vectors = await self.embedder.embed_text([element.get_key() for element in elements])
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,32 @@

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, LocalFileSource


class DocumentType(str, Enum):
    """
    Types of documents that can be stored.

    Apart from UNKNOWN, every value is a lowercase file extension without the leading dot.
    """

    MD = "md"
    TXT = "txt"
    PDF = "pdf"
    CSV = "csv"
    DOC = "doc"
    DOCX = "docx"
    HTML = "html"
    EPUB = "epub"
    XLSX = "xlsx"
    XLS = "xls"
    ORG = "org"
    ODT = "odt"
    PPT = "ppt"
    PPTX = "pptx"
    RST = "rst"
    RTF = "rtf"
    TSV = "tsv"
    XML = "xml"

    # NOTE(review): presumably the fallback for documents whose type cannot be determined - confirm.
    UNKNOWN = "unknown"


class DocumentMeta(BaseModel):
Expand All @@ -21,7 +39,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource] = Field(..., discriminator="source_type")
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand Down Expand Up @@ -63,6 +81,22 @@ def create_text_document_from_literal(cls, content: str) -> "DocumentMeta":
source=LocalFileSource(path=Path(temp_file.name)),
)

@classmethod
def from_local_path(cls, local_path: Path) -> "DocumentMeta":
    """
    Create a document metadata from a local path.

    The document type is derived from the file extension. The extension is compared
    case-insensitively, so ``report.PDF`` and ``report.pdf`` resolve to the same type.

    Args:
        local_path: The local path to the document.

    Returns:
        The document metadata.

    Raises:
        ValueError: If the file extension does not match any DocumentType value.
    """
    # DocumentType values are all lowercase; normalise the suffix so that upper- or
    # mixed-case extensions are not rejected.
    extension = local_path.suffix[1:].lower()
    return cls(
        document_type=DocumentType(extension),
        source=LocalFileSource(path=local_path),
    )


class Document(BaseModel):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Element(BaseModel, ABC):
"""

element_type: str
document: DocumentMeta
document_meta: DocumentMeta

_elements_registry: ClassVar[dict[str, type["Element"]]] = {}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal

from pydantic import BaseModel

try:
from gcloud.aio.storage import Storage

HAS_GCLOUD_AIO = True
except ImportError:
HAS_GCLOUD_AIO = False

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"


class Source(BaseModel, ABC):
"""
Expand Down Expand Up @@ -54,3 +65,59 @@ async def fetch(self) -> Path:
The local path to the object fetched from the source.
"""
return self.path


class GCSSource(Source):
    """
    An object representing a GCS file source.
    """

    source_type: Literal["gcs"] = "gcs"

    bucket: str
    object_name: str

    def get_id(self) -> str:
        """
        Get unique identifier of the object in the source.

        Returns:
            Unique identifier.
        """
        return f"gcs:gs://{self.bucket}/{self.object_name}"

    async def fetch(self) -> Path:
        """
        Fetch the file from Google Cloud Storage and store it locally.

        The file is stored under `<local dir>/<bucket>/<object_name>`. If a file already exists at
        that location it is returned as-is and nothing is downloaded; otherwise the object is
        downloaded from GCS and written there.

        The local directory is read from the environment variable named by `LOCAL_STORAGE_DIR_ENV`.
        If that variable is not set, the system temporary directory is used.

        Returns:
            Path: The local path to the downloaded file.

        Raises:
            ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source.
        """
        if not HAS_GCLOUD_AIO:
            raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

        local_dir_env = os.getenv(LOCAL_STORAGE_DIR_ENV)
        local_dir = Path(tempfile.gettempdir()) if local_dir_env is None else Path(local_dir_env)

        path = local_dir / self.bucket / self.object_name
        # Object names may contain "/" (e.g. "reports/2024/a.pdf"), so creating only the bucket
        # directory is not enough: make every intermediate directory of the target file.
        path.parent.mkdir(parents=True, exist_ok=True)

        if not path.is_file():
            async with Storage() as client:
                content = await client.download(self.bucket, self.object_name)
                path.write_bytes(content)

        return path
Original file line number Diff line number Diff line change
@@ -1,35 +1,84 @@
"""
TODO: This module is mocked. To be deleted and replaced with a real implementation.
"""
import copy
from typing import Optional

from typing import List
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.ingestion.providers.unstructured import UnstructuredProvider

from ragbits.document_search.documents.document import DocumentMeta, TextDocument
from ragbits.document_search.documents.element import Element, TextElement
ProvidersConfig = dict[DocumentType, BaseProvider]

# By default each of the document types listed below gets its own UnstructuredProvider instance.
DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
    document_type: UnstructuredProvider()
    for document_type in (
        DocumentType.TXT,
        DocumentType.MD,
        DocumentType.PDF,
        DocumentType.DOCX,
        DocumentType.DOC,
        DocumentType.PPTX,
        DocumentType.PPT,
        DocumentType.XLSX,
        DocumentType.XLS,
        DocumentType.CSV,
        DocumentType.HTML,
        DocumentType.EPUB,
        DocumentType.ORG,
        DocumentType.ODT,
        DocumentType.RST,
        DocumentType.RTF,
        DocumentType.TSV,
        DocumentType.XML,
    )
}


class DocumentProcessor:
"""
A class with an implementation of Document Processor, allowing to process documents.
TODO: probably this one should be replaced with something more generic,
allowing for passing different processors for different document types.
"""

async def process(self, document_meta: DocumentMeta) -> List[Element]:
def __init__(self, providers: dict[DocumentType, BaseProvider]) -> None:
    """
    Create the processor from an explicit mapping of document types to providers.

    Args:
        providers: The provider to use for each document type.
    """
    self._providers = providers

@classmethod
def from_config(cls, providers_config: Optional[ProvidersConfig] = None) -> "DocumentProcessor":
    """
    Build a DocumentProcessor from an optional providers configuration.

    The processor always starts from a deep copy of the default configuration. When a
    configuration is passed, its entries are applied on top, replacing the default provider
    for every document type it defines.

    Example of the configuration:
        {
            DocumentType.TXT: YourCustomProviderClass(),
            DocumentType.PDF: UnstructuredProvider(),
        }

    Args:
        providers_config: The dictionary mapping document types to provider instances.

    Returns:
        The DocumentProcessor.
    """
    # Work on a copy so that the module-level defaults are never mutated.
    merged = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
    if providers_config is not None:
        merged.update(providers_config)

    return cls(providers=merged)

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""
Process the document.
Args:
document_meta: The document to process.
Returns:
The processed elements.
"""
document = await document_meta.fetch()
The list of elements extracted from the document.
if isinstance(document, TextDocument):
# for now just return the whole document as a single element
return [TextElement(document=document_meta, content=document.content)]
Raises:
ValueError: If the provider for the document type is not defined in the configuration.
"""
provider = self._providers.get(document_meta.document_type)
if provider is None:
raise ValueError(
f"Provider for {document_meta.document_type} is not defined in the configuration:" f" {self._providers}"
)

return []
return await provider.process(document_meta)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element


class DocumentTypeNotSupportedError(Exception):
    """
    Raised when the document type is not supported by the provider.

    Attributes:
        provider_name: Name of the provider that rejected the document type.
        document_type: The document type that is not supported.
    """

    def __init__(self, provider_name: str, document_type: "DocumentType") -> None:
        # Format the enum value explicitly. Interpolating a member of an enum with a mixed-in
        # type directly renders differently across Python versions (the value on older
        # interpreters, "ClassName.MEMBER" on 3.11+), which made the message unstable.
        type_label = getattr(document_type, "value", document_type)
        message = f"Document type {type_label} is not supported by the {provider_name}"
        super().__init__(message)
        # Keep the details available to handlers so they do not have to parse the message.
        self.provider_name = provider_name
        self.document_type = document_type


class BaseProvider(ABC):
    """Abstract base for the providers that process documents into elements."""

    # Document types the concrete provider accepts; read by validate_document_type.
    SUPPORTED_DOCUMENT_TYPES: set[DocumentType]

    @abstractmethod
    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """Extract elements from the given document.

        Args:
            document_meta: The document to process.

        Returns:
            The list of elements extracted from the document.
        """

    def validate_document_type(self, document_type: DocumentType) -> None:
        """Ensure that this provider supports the given document type.

        Args:
            document_type: The document type.

        Raises:
            DocumentTypeNotSupportedError: If the document type is not supported.
        """
        if document_type in self.SUPPORTED_DOCUMENT_TYPES:
            return
        raise DocumentTypeNotSupportedError(provider_name=type(self).__name__, document_type=document_type)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from ragbits.document_search.documents.document import DocumentMeta, DocumentType, TextDocument
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.ingestion.providers.base import BaseProvider


class DummyProvider(BaseProvider):
    """Mock provider that returns the content of the document as a single TextElement.

    It should be used for testing purposes only.

    TODO: Remove this provider after the implementation of the real providers.
    """

    SUPPORTED_DOCUMENT_TYPES = {DocumentType.TXT}

    async def process(self, document_meta: DocumentMeta) -> list[Element]:
        """Process the text document.

        Args:
            document_meta: The document to process.

        Returns:
            List with a single TextElement containing the content of the document, or an empty
            list when the fetched document is not a TextDocument.

        Raises:
            DocumentTypeNotSupportedError: If the document type is not supported by this provider.
        """
        self.validate_document_type(document_meta.document_type)

        fetched = await document_meta.fetch()
        if not isinstance(fetched, TextDocument):
            return []
        return [TextElement(content=fetched.content, document_meta=document_meta)]
Loading

0 comments on commit 6a942a1

Please sign in to comment.