feat: document search sources string resolver (#264)
Co-authored-by: Ludwik Trammer <[email protected]>
Co-authored-by: kdziedzic68 <[email protected]>
3 people authored Jan 22, 2025
1 parent 66abadd commit 839271e
Showing 6 changed files with 635 additions and 45 deletions.
@@ -16,6 +16,7 @@
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.documents.source_resolver import SourceResolver
from ragbits.document_search.documents.sources import Source
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.processor_strategies import (
@@ -197,19 +198,27 @@ async def search(self, query: str, config: SearchConfig | None = None) -> Sequen
@traceable
async def ingest(
self,
documents: Sequence[DocumentMeta | Document | Source],
documents: str | Sequence[DocumentMeta | Document | Source],
document_processor: BaseProvider | None = None,
) -> None:
"""
Ingest multiple documents.
"""Ingest documents into the search index.
Args:
documents: The documents or metadata of the documents to ingest.
documents: Either:
- A sequence of `Document`, `DocumentMetadata`, or `Source` objects
- A source-specific URI string (e.g., "gcs://bucket/*") to specify source location(s), for example:
- "file:///path/to/files/*.txt"
- "gcs://bucket/folder/*"
- "huggingface://dataset/split/row"
document_processor: The document processor to use. If not provided, the document processor will be
determined based on the document metadata.
"""
if isinstance(documents, str):
sources: Sequence[DocumentMeta | Document | Source] = await SourceResolver.resolve(documents)
else:
sources = documents
elements = await self.processing_strategy.process_documents(
documents, self.document_processor_router, document_processor
sources, self.document_processor_router, document_processor
)
await self._remove_entries_with_same_sources(elements)
await self.insert_elements(elements)
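
For orientation, a minimal hedged sketch of the new call path; `document_search` is assumed to be an already-configured instance of the class patched above, and the URIs are illustrative:

async def ingest_by_uri(document_search) -> None:
    # A sequence of Document / DocumentMeta / Source objects still works as before.
    # New in this change: a source URI string, resolved to Source objects by SourceResolver.
    await document_search.ingest("file:///data/docs/**/*.md")          # local files, full glob support
    await document_search.ingest("gcs://my-bucket/reports/*")          # GCS, trailing '*' prefix match only
    await document_search.ingest("huggingface://my_dataset/train/0")   # a single dataset row
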
@@ -0,0 +1,59 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
from ragbits.document_search.documents.sources import Source


class SourceResolver:
"""Registry for source URI protocols and their handlers.
This class provides a mechanism to register and resolve different source protocols (like 'file://', 'gcs://', etc.)
to their corresponding Source implementations.
Example:
>>> SourceResolver.register_protocol("gcs", GCSSource)
>>> sources = await SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
"""

_protocol_handlers: ClassVar[dict[str, type["Source"]]] = {}

@classmethod
def register_protocol(cls, protocol: str, source_class: type["Source"]) -> None:
"""Register a source class for a specific protocol.
Args:
protocol: The protocol identifier (e.g., 'file', 'gcs', 's3')
source_class: The Source subclass that handles this protocol
"""
cls._protocol_handlers[protocol] = source_class

@classmethod
async def resolve(cls, uri: str) -> Sequence["Source"]:
"""Resolve a URI into a sequence of Source objects.
The URI format should be: protocol://path
For example:
- file:///path/to/files/*
- gcs://bucket/prefix/*
Args:
uri: The URI to resolve
Returns:
A sequence of Source objects
Raises:
ValueError: If the URI format is invalid or the protocol is not supported
"""
try:
protocol, path = uri.split("://", 1)
except ValueError as err:
raise ValueError(f"Invalid URI format: {uri}. Expected format: protocol://path") from err

if protocol not in cls._protocol_handlers:
supported = ", ".join(sorted(cls._protocol_handlers.keys()))
raise ValueError(f"Unsupported protocol: {protocol}. Supported protocols are: {supported}")

handler_class = cls._protocol_handlers[protocol]
return await handler_class.from_uri(path)
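
A short, hedged sketch of calling the resolver directly, covering the happy path and both error cases raised above (the paths are illustrative):

import asyncio

from ragbits.document_search.documents.source_resolver import SourceResolver
from ragbits.document_search.documents.sources import LocalFileSource  # noqa: F401 - importing sources registers the built-in protocols


async def main() -> None:
    # Happy path: "file" is registered, so LocalFileSource.from_uri does the globbing.
    sources = await SourceResolver.resolve("file:///tmp/ragbits-demo/*.txt")
    print([type(source).__name__ for source in sources])

    # Error paths raised above:
    #   "no-scheme-here"  -> ValueError: Invalid URI format ...
    #   "ftp://host/file" -> ValueError: Unsupported protocol ...
    try:
        await SourceResolver.resolve("ftp://host/file")
    except ValueError as err:
        print(err)


asyncio.run(main())
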
@@ -2,6 +2,7 @@
import re
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import suppress
from pathlib import Path
from typing import Any, ClassVar
@@ -11,14 +12,16 @@
from pydantic_core import CoreSchema, core_schema

with suppress(ImportError):
from gcloud.aio.storage import Storage
from gcloud.aio.storage import Storage as StorageClient

with suppress(ImportError):
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError


from ragbits.core.utils.decorators import requires_dependencies
from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError
from ragbits.document_search.documents.source_resolver import SourceResolver

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"

@@ -30,6 +33,7 @@ class Source(BaseModel, ABC):

# Registry of all subclasses by their unique identifier
_registry: ClassVar[dict[str, type["Source"]]] = {}
protocol: ClassVar[str | None] = None

@classmethod
def class_identifier(cls) -> str:
@@ -64,10 +68,34 @@ async def fetch(self) -> Path:
The path to the source.
"""

@classmethod
@abstractmethod
async def from_uri(cls, path: str) -> Sequence["Source"]:
"""Create Source instances from a URI path.
The path can contain glob patterns (asterisks) to match multiple sources, but pattern support
varies by source type. Each source implementation defines which patterns it supports:
- LocalFileSource: Supports full glob patterns ('*', '**', etc.) via Path.glob
- GCSSource: Supports simple prefix matching with '*' at the end of path
- HuggingFaceSource: Does not support glob patterns
Args:
path: The path part of the URI (after protocol://). Pattern support depends on source type.
Returns:
A sequence of Source objects matching the path pattern
Raises:
ValueError: If the path contains an unsupported pattern for this source type
"""

@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401
Source._registry[cls.class_identifier()] = cls
super().__init_subclass__(**kwargs)
Source._registry[cls.class_identifier()] = cls
if cls.protocol is not None:
SourceResolver.register_protocol(cls.protocol, cls)
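
To make the new hook concrete, a hedged sketch of a custom subclass; `NotesSource` and its "notes" scheme are invented for illustration and are not part of this change:

from collections.abc import Sequence
from pathlib import Path
from typing import ClassVar

from ragbits.document_search.documents.sources import Source


class NotesSource(Source):
    """Hypothetical source that serves plain-text notes from a fixed local directory."""

    name: str
    protocol: ClassVar[str] = "notes"  # non-None, so the hook above auto-registers it with SourceResolver

    @property
    def id(self) -> str:
        return f"notes:{self.name}"

    async def fetch(self) -> Path:
        return Path("/var/notes") / self.name  # illustrative location

    @classmethod
    async def from_uri(cls, path: str) -> Sequence["NotesSource"]:
        # This toy scheme has no pattern support: the whole path is one note name.
        return [cls(name=path)]


# After this class body runs, SourceResolver.resolve("notes://standup.md")
# returns [NotesSource(name="standup.md")].
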


class SourceDiscriminator:
@@ -112,6 +140,7 @@ class LocalFileSource(Source):
"""

path: Path
protocol: ClassVar[str] = "file"

@property
def id(self) -> str:
@@ -151,14 +180,71 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSou
"""
return [cls(path=file_path) for file_path in path.glob(file_pattern)]

@classmethod
async def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
"""Create LocalFileSource instances from a URI path.
Supports full glob patterns via Path.glob:
- "**/*.txt" - all .txt files in any subdirectory
- "*.py" - all Python files in the current directory
- "**/*" - all files in any subdirectory
- '?' matches exactly one character
Args:
path: The path part of the URI (after file://). Pattern support depends on source type.
Returns:
A sequence of LocalFileSource objects
"""
path_obj: Path = Path(path)
base_path, pattern = cls._split_path_and_pattern(path=path_obj)
if base_path.is_file():
return [cls(path=base_path)]
if not pattern:
return []
return [cls(path=f) for f in base_path.glob(pattern) if f.is_file()]

@staticmethod
def _split_path_and_pattern(path: Path) -> tuple[Path, str]:
parts = path.parts
# Find the first part containing '*' or '?'
for i, part in enumerate(parts):
if "*" in part or "?" in part:
base_path = Path(*parts[:i])
pattern = str(Path(*parts[i:]))
return base_path, pattern
return path, ""
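
A small, hedged sketch of how the split above interprets a few illustrative local paths (the directories are made up):

import asyncio

from ragbits.document_search.documents.sources import LocalFileSource


async def main() -> None:
    # "/data/docs/**/*.md"    -> base /data/docs, pattern "**/*.md"  (recursive glob)
    # "/data/docs/report.txt" -> no wildcard: returned as a single source if the file exists
    # "/data/*.csv"           -> base /data, pattern "*.csv"         (single directory level)
    sources = await LocalFileSource.from_uri("/data/docs/**/*.md")
    for source in sources:
        print(source.path)


asyncio.run(main())
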


class GCSSource(Source):
"""
An object representing a GCS file source.
"""
"""An object representing a GCS file source."""

bucket: str
object_name: str
protocol: ClassVar[str] = "gcs"
_storage: "StorageClient | None" = None # Storage client for dependency injection

@classmethod
def set_storage(cls, storage: "StorageClient | None") -> None:
"""Set the storage client for all instances.
Args:
storage: The `gcloud-aio-storage` `Storage` object to use as the storage client.
By default, the object will be created automatically.
"""
cls._storage = storage

@classmethod
@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def _get_storage(cls) -> "StorageClient":
"""Get the storage client.
Returns:
The storage client to use. If none was injected, creates a new one.
"""
if cls._storage is None:
cls._storage = StorageClient()
return cls._storage
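
A hedged sketch of the new injection hook, e.g. for explicit service-account credentials or a test double; the credentials path below is an assumption:

from gcloud.aio.storage import Storage

from ragbits.document_search.documents.sources import GCSSource


def configure_gcs_client(service_file: str) -> None:
    # Inject one pre-configured client; GCSSource.fetch() and GCSSource.list_sources()
    # reuse it instead of lazily creating a default Storage() on first use.
    # Passing None restores the lazy default.
    GCSSource.set_storage(Storage(service_file=service_file))


# configure_gcs_client("/secrets/gcs-service-account.json")  # path is illustrative
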

@property
def id(self) -> str:
@@ -192,8 +278,8 @@ async def fetch(self) -> Path:
path = bucket_local_dir / self.object_name

if not path.is_file():
async with Storage() as client: # type: ignore
# TODO: Add error handling for download
storage = await self._get_storage()
async with storage as client:
content = await client.download(self.bucket, self.object_name)
Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
with open(path, mode="wb+") as file_object:
@@ -204,8 +290,7 @@
@classmethod
@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
"""
List all sources in the given GCS bucket, matching the prefix.
"""List all sources in the given GCS bucket, matching the prefix.
Args:
bucket: The GCS bucket.
@@ -217,12 +302,47 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
Raises:
ImportError: If the required 'gcloud-aio-storage' package is not installed
"""
async with Storage() as client:
objects = await client.list_objects(bucket, params={"prefix": prefix})
sources = []
for obj in objects["items"]:
sources.append(cls(bucket=bucket, object_name=obj["name"]))
return sources
async with await cls._get_storage() as storage:
result = await storage.list_objects(bucket, params={"prefix": prefix})
items = result.get("items", [])
return [cls(bucket=bucket, object_name=item["name"]) for item in items]

@classmethod
async def from_uri(cls, path: str) -> Sequence["GCSSource"]:
"""Create GCSSource instances from a URI path.
Supports simple prefix matching with '*' at the end of path.
For example:
- "bucket/folder/*" - matches all files in the folder
- "bucket/folder/prefix*" - matches all files starting with prefix
More complex patterns like '**' or '?' are not supported.
Args:
path: The path part of the URI (after gcs://). Can end with '*' for pattern matching.
Returns:
A sequence of GCSSource objects matching the pattern
Raises:
ValueError: If an unsupported pattern is used
"""
if "**" in path or "?" in path:
raise ValueError(
"GCSSource only supports '*' at the end of path. Patterns like '**' or '?' are not supported."
)

# Split into bucket and prefix
bucket, prefix = path.split("/", 1) if "/" in path else (path, "")

if "*" in prefix:
if not prefix.endswith("*"):
raise ValueError(f"GCSSource only supports '*' at the end of path. Invalid pattern: {prefix}")
# Remove the trailing * for GCS prefix listing
prefix = prefix[:-1]
return await cls.list_sources(bucket=bucket, prefix=prefix)

return [cls(bucket=bucket, object_name=prefix)]
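
For reference, a hedged sketch of the pattern rules above; the bucket and object names are illustrative, and actually running it needs the `gcs` extra plus credentials:

import asyncio

from ragbits.document_search.documents.sources import GCSSource


async def main() -> None:
    # "my-bucket/reports/*"        -> every object under the "reports/" prefix
    # "my-bucket/reports/2024-*"   -> objects whose names start with "reports/2024-"
    # "my-bucket/reports/q4.pdf"   -> exactly one GCSSource, no listing call
    # "my-bucket/**/*.pdf", "my-bucket/q?.pdf" -> ValueError (only a trailing '*' is allowed)
    sources = await GCSSource.from_uri("my-bucket/reports/*")
    print([source.object_name for source in sources])


asyncio.run(main())
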


class HuggingFaceSource(Source):
@@ -233,6 +353,7 @@ class HuggingFaceSource(Source):
path: str
split: str = "train"
row: int
protocol: ClassVar[str] = "huggingface"

@property
def id(self) -> str:
@@ -280,6 +401,33 @@ async def fetch(self) -> Path:

return path

@classmethod
async def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
"""Create HuggingFaceSource instances from a URI path.
Pattern matching is not supported. The path must be in the format:
huggingface://dataset_path/split/row
Args:
path: The path part of the URI (after huggingface://)
Returns:
A sequence containing a single HuggingFaceSource
Raises:
ValueError: If the path contains patterns or has invalid format
"""
if "*" in path or "?" in path:
raise ValueError(
"HuggingFaceSource does not support patterns. Path must be in format: dataset_path/split/row"
)

try:
dataset_path, split, row = path.split("/")
return [cls(path=dataset_path, split=split, row=int(row))]
except ValueError as err:
raise ValueError("Invalid HuggingFace path format. Expected: dataset_path/split/row") from err
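
Finally, a hedged sketch of the pattern-free Hugging Face form; the dataset path is illustrative:

import asyncio

from ragbits.document_search.documents.sources import HuggingFaceSource


async def main() -> None:
    # Exactly three segments are expected: dataset_path/split/row.
    sources = await HuggingFaceSource.from_uri("my_dataset/train/0")
    print(sources[0].path, sources[0].split, sources[0].row)  # my_dataset train 0

    # Both of these raise ValueError:
    #   "my_dataset/train/*"      -- wildcards are rejected outright
    #   "org/my_dataset/train/0"  -- four segments do not unpack into dataset_path/split/row


asyncio.run(main())
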

@classmethod
async def list_sources(cls, path: str, split: str) -> list["HuggingFaceSource"]:
"""