refactor(document-search): change the type in from_source method to Source (#156)
ludwiktrammer authored Oct 29, 2024
1 parent f07ae21 commit a60cdfe
Showing 6 changed files with 138 additions and 14 deletions.
@@ -8,7 +8,7 @@
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.documents.sources import GCSSource, LocalFileSource, Source
from ragbits.document_search.documents.sources import Source
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.retrieval.rephrasers import get_rephraser
@@ -109,7 +109,7 @@ async def search(self, query: str, config: SearchConfig | None = None) -> list[E

async def _process_document(
self,
document: DocumentMeta | Document | LocalFileSource | GCSSource,
document: DocumentMeta | Document | Source,
document_processor: BaseProvider | None = None,
) -> list[Element]:
"""
@@ -138,7 +138,7 @@ async def _process_document(

async def ingest(
self,
documents: Sequence[DocumentMeta | Document | LocalFileSource | GCSSource],
documents: Sequence[DocumentMeta | Document | Source],
document_processor: BaseProvider | None = None,
) -> None:
"""
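
Since _process_document and ingest are now typed against the Source base class instead of a hard-coded LocalFileSource | GCSSource union, any Source subclass can be passed to DocumentSearch without further changes to this module. A minimal sketch of a call under that assumption (the document_search instance and the dataset values are placeholders, not part of this commit):

from ragbits.document_search.documents.sources import HuggingFaceSource

async def ingest_one_row(document_search) -> None:
    # `document_search` is assumed to be an already-configured DocumentSearch
    # instance; any Source subclass is now a valid ingest() argument.
    source = HuggingFaceSource(path="some/dataset", split="train", row=0)
    await document_search.ingest([source])
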
@@ -1,10 +1,11 @@
import tempfile
from enum import Enum
from pathlib import Path
from typing import Annotated

from pydantic import BaseModel, Field
from pydantic import BaseModel

from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource
from ragbits.document_search.documents.sources import LocalFileSource, Source, SourceDiscriminator


class DocumentType(str, Enum):
@@ -42,7 +43,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")
source: Annotated[Source, SourceDiscriminator()]

@property
def id(self) -> str:
@@ -101,7 +102,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
)

@classmethod
async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta":
async def from_source(cls, source: Source) -> "DocumentMeta":
"""
Create a document metadata from a source.
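
With the source field annotated as Annotated[Source, SourceDiscriminator()] rather than a Field(..., discriminator="source_type") over an explicit union, DocumentMeta can now be validated from data describing any registered source type. A rough sketch of what that enables (the values mirror the test data further down in this diff):

from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.sources import LocalFileSource

# The nested dict is resolved to the matching Source subclass via its
# "source_type" tag.
meta = DocumentMeta.model_validate(
    {
        "document_type": "txt",
        "source": {"source_type": "local_file_source", "path": "/example/path"},
    }
)
assert isinstance(meta.source, LocalFileSource)
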
@@ -2,9 +2,11 @@
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal
from typing import Any, ClassVar

from pydantic import BaseModel
from pydantic import BaseModel, GetCoreSchemaHandler, computed_field
from pydantic.alias_generators import to_snake
from pydantic_core import CoreSchema, core_schema

try:
from datasets import load_dataset
@@ -24,6 +26,23 @@ class Source(BaseModel, ABC):
An object representing a source.
"""

# Registry of all subclasses by their unique identifier
_registry: ClassVar[dict[str, type["Source"]]] = {}

@classmethod
def class_identifier(cls) -> str:
"""
Get an identifier for the source type.
"""
return to_snake(cls.__name__)

@computed_field
def source_type(self) -> str:
"""
Pydantic field based on the class identifier.
"""
return self.class_identifier()

@property
@abstractmethod
def id(self) -> str:
@@ -43,13 +62,47 @@ async def fetch(self) -> Path:
The path to the source.
"""

@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401
Source._registry[cls.class_identifier()] = cls
super().__init_subclass__(**kwargs)


class SourceDiscriminator:
"""
Pydantic type annotation that automatically creates the correct subclass of Source based on the source_type field.
"""

@staticmethod
def _create_instance(fields: dict[str, Any]) -> Source:
source_type = fields.get("source_type")
if source_type is None:
raise ValueError("source_type is required to create a Source instance")

source_subclass = Source._registry.get(source_type)
if source_subclass is None:
raise ValueError(f"Unknown source type: {source_type}")
return source_subclass(**fields)

def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: # noqa: ANN401
create_instance_validator = core_schema.no_info_plain_validator_function(self._create_instance)

return core_schema.json_or_python_schema(
json_schema=create_instance_validator,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(Source),
create_instance_validator,
]
),
)


class LocalFileSource(Source):
"""
An object representing a local file source.
"""

source_type: Literal["local_file"] = "local_file"
path: Path

@property
@@ -96,7 +149,6 @@ class GCSSource(Source):
An object representing a GCS file source.
"""

source_type: Literal["gcs"] = "gcs"
bucket: str
object_name: str

@@ -170,7 +222,6 @@ class HuggingFaceSource(Source):
An object representing a Hugging Face dataset source.
"""

source_type: Literal["huggingface"] = "huggingface"
path: str
split: str = "train"
row: int
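
Taken together: __init_subclass__ keeps a registry keyed by class_identifier() (the snake_case class name), the source_type computed field serializes that identifier, and SourceDiscriminator uses it to pick the right subclass during validation, which is why the per-class Literal tags above are removed. A hedged sketch of defining a new source type outside the library under that scheme (HttpSource, its url field, and its fetch behaviour are made up for illustration):

from pathlib import Path

from ragbits.document_search.documents.sources import Source

class HttpSource(Source):
    """Hypothetical source; __init_subclass__ registers it automatically."""

    url: str

    @property
    def id(self) -> str:
        return f"http:{self.url}"

    async def fetch(self) -> Path:
        # Placeholder: a real implementation would download the file and
        # return the local path it was saved to.
        raise NotImplementedError

# to_snake("HttpSource") -> "http_source", which becomes both the registry key
# and the serialized source_type tag.
assert HttpSource(url="https://example.com/doc.txt").source_type == "http_source"
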
Empty file.
4 changes: 2 additions & 2 deletions packages/ragbits-document-search/tests/unit/test_elements.py
@@ -20,7 +20,7 @@ def get_key(self) -> str:
"foo": "bar",
"document_meta": {
"document_type": "txt",
"source": {"source_type": "local_file", "path": "/example/path"},
"source": {"source_type": "local_file_source", "path": "/example/path"},
},
},
)
@@ -30,4 +30,4 @@ def get_key(self) -> str:
assert element.foo == "bar"
assert element.get_key() == "barbar"
assert element.document_meta.document_type == DocumentType.TXT
assert element.document_meta.source.source_type == "local_file"
assert element.document_meta.source.source_type == "local_file_source"
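
The expected tag changes from "local_file" to "local_file_source" because the identifier is no longer a hand-written Literal but is derived from the class name via pydantic's to_snake, as in class_identifier() above:

from pydantic.alias_generators import to_snake

# The snake_case class name doubles as the registry key and source_type tag.
assert to_snake("LocalFileSource") == "local_file_source"
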
@@ -0,0 +1,72 @@
from pathlib import Path
from typing import Annotated

import pydantic
import pytest

from ragbits.document_search.documents.sources import LocalFileSource, Source, SourceDiscriminator


class ModelWithSource(pydantic.BaseModel):
source: Annotated[Source, SourceDiscriminator()]
foo: str


def test_serialization():
source = LocalFileSource(path=Path("test"))
model = ModelWithSource(source=source, foo="bar")
assert model.model_dump() == {
"source": {
"source_type": "local_file_source",
"path": Path("test"),
},
"foo": "bar",
}
assert model.model_dump_json() == '{"source":{"path":"test","source_type":"local_file_source"},"foo":"bar"}'


def test_deserialization_from_json():
json = '{"source":{"path":"test","source_type":"local_file_source"},"foo":"bar"}'
model = ModelWithSource.model_validate_json(json)
assert isinstance(model.source, LocalFileSource)
assert model.source.path == Path("test")
assert model.foo == "bar"


def test_deserialization_from_dict():
dict = {
"source": {
"source_type": "local_file_source",
"path": Path("test"),
},
"foo": "bar",
}
model = ModelWithSource.model_validate(dict)
assert isinstance(model.source, LocalFileSource)
assert model.source.path == Path("test")
assert model.foo == "bar"


def test_deserialization_from_dict_with_invalid_source():
dict = {
"source": {
"source_type": "invalid_source",
"path": Path("test"),
},
"foo": "bar",
}
with pytest.raises(pydantic.ValidationError) as e:
ModelWithSource.model_validate(dict)
assert e.match("source")


def test_deserialization_from_dict_with_missing_source_type():
dict = {
"source": {
"path": Path("test"),
},
"foo": "bar",
}
with pytest.raises(pydantic.ValidationError) as e:
ModelWithSource.model_validate(dict)
assert e.match("source")
