diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 9af75f08..6827312e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -8,7 +8,7 @@ from ragbits.core.vector_stores.base import VectorStoreOptions from ragbits.document_search.documents.document import Document, DocumentMeta from ragbits.document_search.documents.element import Element -from ragbits.document_search.documents.sources import GCSSource, LocalFileSource, Source +from ragbits.document_search.documents.sources import Source from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.base import BaseProvider from ragbits.document_search.retrieval.rephrasers import get_rephraser @@ -109,7 +109,7 @@ async def search(self, query: str, config: SearchConfig | None = None) -> list[E async def _process_document( self, - document: DocumentMeta | Document | LocalFileSource | GCSSource, + document: DocumentMeta | Document | Source, document_processor: BaseProvider | None = None, ) -> list[Element]: """ @@ -138,7 +138,7 @@ async def _process_document( async def ingest( self, - documents: Sequence[DocumentMeta | Document | LocalFileSource | GCSSource], + documents: Sequence[DocumentMeta | Document | Source], document_processor: BaseProvider | None = None, ) -> None: """ diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py index 6c9c58ea..e824cc03 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py @@ -1,10 +1,11 @@ import tempfile from enum import Enum from pathlib import Path +from typing import Annotated -from pydantic import BaseModel, Field +from pydantic import BaseModel -from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource +from ragbits.document_search.documents.sources import LocalFileSource, Source, SourceDiscriminator class DocumentType(str, Enum): @@ -42,7 +43,7 @@ class DocumentMeta(BaseModel): """ document_type: DocumentType - source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type") + source: Annotated[Source, SourceDiscriminator()] @property def id(self) -> str: @@ -101,7 +102,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta": ) @classmethod - async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta": + async def from_source(cls, source: Source) -> "DocumentMeta": """ Create a document metadata from a source. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py index 0f043e0b..4798adb3 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py @@ -2,9 +2,11 @@ import tempfile from abc import ABC, abstractmethod from pathlib import Path -from typing import Literal +from typing import Any, ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, GetCoreSchemaHandler, computed_field +from pydantic.alias_generators import to_snake +from pydantic_core import CoreSchema, core_schema try: from datasets import load_dataset @@ -24,6 +26,23 @@ class Source(BaseModel, ABC): An object representing a source. """ + # Registry of all subclasses by their unique identifier + _registry: ClassVar[dict[str, type["Source"]]] = {} + + @classmethod + def class_identifier(cls) -> str: + """ + Get an identifier for the source type. + """ + return to_snake(cls.__name__) + + @computed_field + def source_type(self) -> str: + """ + Pydantic field based on the class identifier. + """ + return self.class_identifier() + @property @abstractmethod def id(self) -> str: @@ -43,13 +62,47 @@ async def fetch(self) -> Path: The path to the source. """ + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401 + Source._registry[cls.class_identifier()] = cls + super().__init_subclass__(**kwargs) + + +class SourceDiscriminator: + """ + Pydantic type annotation that automatically creates the correct subclass of Source based on the source_type field. + """ + + @staticmethod + def _create_instance(fields: dict[str, Any]) -> Source: + source_type = fields.get("source_type") + if source_type is None: + raise ValueError("source_type is required to create a Source instance") + + source_subclass = Source._registry.get(source_type) + if source_subclass is None: + raise ValueError(f"Unknown source type: {source_type}") + return source_subclass(**fields) + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: # noqa: ANN401 + create_instance_validator = core_schema.no_info_plain_validator_function(self._create_instance) + + return core_schema.json_or_python_schema( + json_schema=create_instance_validator, + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(Source), + create_instance_validator, + ] + ), + ) + class LocalFileSource(Source): """ An object representing a local file source. """ - source_type: Literal["local_file"] = "local_file" path: Path @property @@ -96,7 +149,6 @@ class GCSSource(Source): An object representing a GCS file source. """ - source_type: Literal["gcs"] = "gcs" bucket: str object_name: str @@ -170,7 +222,6 @@ class HuggingFaceSource(Source): An object representing a Hugging Face dataset source. """ - source_type: Literal["huggingface"] = "huggingface" path: str split: str = "train" row: int diff --git a/packages/ragbits-document-search/tests/unit/__init__.py b/packages/ragbits-document-search/tests/unit/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/packages/ragbits-document-search/tests/unit/test_elements.py b/packages/ragbits-document-search/tests/unit/test_elements.py index 7a9c44dc..5519c16f 100644 --- a/packages/ragbits-document-search/tests/unit/test_elements.py +++ b/packages/ragbits-document-search/tests/unit/test_elements.py @@ -20,7 +20,7 @@ def get_key(self) -> str: "foo": "bar", "document_meta": { "document_type": "txt", - "source": {"source_type": "local_file", "path": "/example/path"}, + "source": {"source_type": "local_file_source", "path": "/example/path"}, }, }, ) @@ -30,4 +30,4 @@ def get_key(self) -> str: assert element.foo == "bar" assert element.get_key() == "barbar" assert element.document_meta.document_type == DocumentType.TXT - assert element.document_meta.source.source_type == "local_file" + assert element.document_meta.source.source_type == "local_file_source" diff --git a/packages/ragbits-document-search/tests/unit/test_source_discriminator.py b/packages/ragbits-document-search/tests/unit/test_source_discriminator.py new file mode 100644 index 00000000..8a870331 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_source_discriminator.py @@ -0,0 +1,72 @@ +from pathlib import Path +from typing import Annotated + +import pydantic +import pytest + +from ragbits.document_search.documents.sources import LocalFileSource, Source, SourceDiscriminator + + +class ModelWithSource(pydantic.BaseModel): + source: Annotated[Source, SourceDiscriminator()] + foo: str + + +def test_serialization(): + source = LocalFileSource(path=Path("test")) + model = ModelWithSource(source=source, foo="bar") + assert model.model_dump() == { + "source": { + "source_type": "local_file_source", + "path": Path("test"), + }, + "foo": "bar", + } + assert model.model_dump_json() == '{"source":{"path":"test","source_type":"local_file_source"},"foo":"bar"}' + + +def test_deserialization_from_json(): + json = '{"source":{"path":"test","source_type":"local_file_source"},"foo":"bar"}' + model = ModelWithSource.model_validate_json(json) + assert isinstance(model.source, LocalFileSource) + assert model.source.path == Path("test") + assert model.foo == "bar" + + +def test_deserialization_from_dict(): + dict = { + "source": { + "source_type": "local_file_source", + "path": Path("test"), + }, + "foo": "bar", + } + model = ModelWithSource.model_validate(dict) + assert isinstance(model.source, LocalFileSource) + assert model.source.path == Path("test") + assert model.foo == "bar" + + +def test_deserialization_from_dict_with_invalid_source(): + dict = { + "source": { + "source_type": "invalid_source", + "path": Path("test"), + }, + "foo": "bar", + } + with pytest.raises(pydantic.ValidationError) as e: + ModelWithSource.model_validate(dict) + assert e.match("source") + + +def test_deserialization_from_dict_with_missing_source_type(): + dict = { + "source": { + "path": Path("test"), + }, + "foo": "bar", + } + with pytest.raises(pydantic.ValidationError) as e: + ModelWithSource.model_validate(dict) + assert e.match("source")