Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sources): add hf data source #106

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# pylint: disable=missing-function-docstring,missing-return-doc

import asyncio
from functools import wraps
from importlib.util import find_spec
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def requires_dependencies(
micpst marked this conversation as resolved.
Show resolved Hide resolved
dependencies: str | list[str],
extras: str | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
Decorator to check if the dependencies are installed before running the function.

Args:
dependencies: The dependencies to check.
extras: The extras to install.

Returns:
The decorated function.
"""
if isinstance(dependencies, str):
dependencies = [dependencies]

def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
def run_check() -> None:
missing_dependencies = [dependency for dependency in dependencies if not find_spec(dependency)]
if len(missing_dependencies) > 0:
missing_deps = ", ".join(missing_dependencies)
install_cmd = (
f"pip install 'ragbits[{extras}]'" if extras else f"pip install {' '.join(missing_dependencies)}"
)
raise ImportError(
f"Following dependencies are missing: {missing_deps}. Please install them using `{install_cmd}`."
)

@wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return func(*args, **kwargs)

@wraps(func)
async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return await func(*args, **kwargs) # type: ignore

if asyncio.iscoroutinefunction(func):
return wrapper_async # type: ignore
return wrapper

return decorator
5 changes: 4 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ dependencies = [
"numpy~=1.24.0",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
"ragbits-core==0.1.0"
"ragbits-core==0.1.0",
]

[project.optional-dependencies]
gcs = [
"gcloud-aio-storage~=9.3.0"
]
huggingface = [
"datasets~=3.0.1",
]

[tool.uv]
dev-dependencies = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import tempfile
from enum import Enum
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import GCSSource, LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource


class DocumentType(str, Enum):
Expand Down Expand Up @@ -39,7 +38,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")
source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand All @@ -49,7 +48,7 @@ def id(self) -> str:
Returns:
The document ID.
"""
return self.source.get_id()
return self.source.id

async def fetch(self) -> "Document":
"""
Expand Down Expand Up @@ -98,7 +97,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
)

@classmethod
async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta":
"""
Create a document metadata from a source.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class SourceError(Exception):
    """
    Class for all exceptions raised by the document source.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class SourceConnectionError(SourceError):
    """
    Raised when there is an error connecting to the document source.
    """

    def __init__(self) -> None:
        super().__init__("Connection error.")

    def __reduce__(self) -> tuple[object, ...]:
        # The default exception reduction re-creates the object as `cls(*self.args)`, i.e. with the
        # formatted message, which this zero-argument constructor rejects. Rebuild it without
        # arguments so that copying and pickling the exception works.
        return (type(self), (), self.__dict__)


class SourceNotFoundError(SourceError):
    """
    Raised when the document is not found.
    """

    def __init__(self, source_id: str) -> None:
        super().__init__(f"Source with ID {source_id} not found.")
        self.source_id = source_id

    def __reduce__(self) -> tuple[object, ...]:
        # Rebuild from the original source ID rather than from `self.args`; otherwise the already
        # formatted message would be passed as `source_id` and wrapped a second time.
        return (type(self), (self.source_id,), self.__dict__)
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,26 @@
from pydantic import BaseModel

try:
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError
from gcloud.aio.storage import Storage

HAS_GCLOUD_AIO = True
except ImportError:
HAS_GCLOUD_AIO = False
pass

from ragbits.core.utils.decorators import requires_dependencies
from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"
LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"


class Source(BaseModel, ABC):
"""
An object representing a source.
"""

@property
@abstractmethod
def get_id(self) -> str:
def id(self) -> str:
"""
Get the source ID.

Expand All @@ -48,7 +52,8 @@ class LocalFileSource(Source):
source_type: Literal["local_file"] = "local_file"
path: Path

def get_id(self) -> str:
@property
def id(self) -> str:
"""
Get unique identifier of the object in the source.

Expand All @@ -63,7 +68,12 @@ async def fetch(self) -> Path:

Returns:
The local path to the object fetched from the source.

Raises:
SourceNotFoundError: If the source document is not found.
"""
if not self.path.is_file():
raise SourceNotFoundError(source_id=self.id)
return self.path


Expand All @@ -73,11 +83,11 @@ class GCSSource(Source):
"""

source_type: Literal["gcs"] = "gcs"

bucket: str
object_name: str

def get_id(self) -> str:
@property
def id(self) -> str:
"""
Get unique identifier of the object in the source.

Expand All @@ -86,39 +96,107 @@ def get_id(self) -> str:
"""
return f"gcs:gs://{self.bucket}/{self.object_name}"

@requires_dependencies(["gcloud.aio.storage"], "gcs")
Copy link
Collaborator Author

@micpst micpst Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm testing a new way of handling optional dependencies. Since managing error messages from extra deps is starting to get tedious, I decided to wrap them in @requires_dependencies decorator that handles them automatically.

If you like it, we can add a new issue to refactor the handling of optional deps, there're several other places where this decorator should be applied.

async def fetch(self) -> Path:
"""
Fetch the file from Google Cloud Storage and store it locally.

The file is downloaded to a local directory specified by `local_dir`. If the file already exists locally,
it will not be downloaded again. If the file doesn't exist locally, it will be fetched from GCS.
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR_ENV`. If this environment
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment
variable is not set, a temporary directory is used.

Returns:
Path: The local path to the downloaded file.

Raises:
ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source.
ImportError: If the 'gcs' extra is not installed.
"""

if not HAS_GCLOUD_AIO:
raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None:
local_dir = Path(tempfile.gettempdir()) / "ragbits"
else:
local_dir = Path(local_dir_env)

local_dir = get_local_storage_dir()
bucket_local_dir = local_dir / self.bucket
bucket_local_dir.mkdir(parents=True, exist_ok=True)
path = bucket_local_dir / self.object_name

if not path.is_file():
async with Storage() as client:
async with Storage() as client: # type: ignore
# TODO: Add error handling for download
content = await client.download(self.bucket, self.object_name)
Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
with open(path, mode="wb+") as file_object:
file_object.write(content)

return path


class HuggingFaceSource(Source):
    """
    An object representing a Hugging Face dataset source.
    """

    source_type: Literal["huggingface"] = "huggingface"
    # Dataset identifier passed to `load_dataset`.
    path: str
    # Dataset split to read from.
    split: str = "train"
    # Number of rows skipped before the row that is fetched (i.e. a zero-based row index).
    row: int

    @property
    def id(self) -> str:
        """
        Get unique identifier of the object in the source.

        Returns:
            Unique identifier.
        """
        return f"huggingface:{self.path}/{self.split}/{self.row}"

    @requires_dependencies(["datasets"], "huggingface")
    async def fetch(self) -> Path:
        """
        Fetch the file from Hugging Face and store it locally.

        The selected dataset row is expected to provide a "source" field (relative file path) and
        a "content" field (file text). The content is written under the local storage directory at
        the "source" path, unless a file already exists there.

        Returns:
            Path: The local path to the downloaded file.

        Raises:
            ImportError: If the 'huggingface' extra is not installed.
            SourceConnectionError: If the source connection fails.
            SourceNotFoundError: If the source document is not found.
            ValueError: If the row's "source" path points outside the local storage directory.
        """
        try:
            dataset = load_dataset(self.path, split=self.split, streaming=True)  # type: ignore
        except ConnectionError as exc:
            raise SourceConnectionError() from exc
        except DatasetNotFoundError as exc:  # type: ignore
            raise SourceNotFoundError(source_id=self.id) from exc

        try:
            data = next(iter(dataset.skip(self.row).take(1)))  # type: ignore
        except StopIteration as exc:
            # The stream ended before reaching the requested row.
            raise SourceNotFoundError(source_id=self.id) from exc

        storage_dir = get_local_storage_dir()
        path = storage_dir / data["source"]

        # The "source" value comes from remote dataset content. Refuse absolute paths and ".."
        # segments that would make the write below land outside the local storage directory.
        if not path.resolve().is_relative_to(storage_dir.resolve()):
            raise ValueError(
                f"Source {self.id} points outside of the local storage directory: {data['source']!r}"
            )

        source_dir = storage_dir / Path(data["source"]).parent
        source_dir.mkdir(parents=True, exist_ok=True)

        if not path.is_file():
            with open(path, mode="w", encoding="utf-8") as file:
                file.write(data["content"])

        return path


def get_local_storage_dir() -> Path:
    """
    Resolve the directory in which fetched files are stored locally.

    The directory is read from the environment variable whose name is held in `LOCAL_STORAGE_DIR_ENV`.
    When that variable is not set, a "ragbits" directory inside the system temporary directory is used.

    Returns:
        The local storage directory.
    """
    configured_dir = os.getenv(LOCAL_STORAGE_DIR_ENV)
    if configured_dir is None:
        return Path(tempfile.gettempdir()) / "ragbits"
    return Path(configured_dir)
45 changes: 45 additions & 0 deletions packages/ragbits-document-search/tests/integration/test_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
from pathlib import Path

import pytest

from ragbits.document_search.documents.exceptions import SourceNotFoundError
from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, HuggingFaceSource

from ..helpers import env_vars_not_set

# Point the local storage directory at the directory containing this test module.
# NOTE: this runs at import time and modifies the environment of the whole test process.
os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix()

# Name of the environment variable that must be set for the tests in this module to run.
HF_TOKEN_ENV = "HF_TOKEN" # nosec
# Hugging Face dataset the tests fetch rows from.
HF_DATASET_PATH = "micpst/hf-docs"


@pytest.mark.skipif(
    env_vars_not_set([HF_TOKEN_ENV]),
    reason="Hugging Face environment variables not set",
)
async def test_huggingface_source_fetch() -> None:
    """Fetch the first row of the test dataset and check the file written to local storage."""
    source = HuggingFaceSource(path=HF_DATASET_PATH, row=0)
    path = await source.fetch()

    try:
        assert path.is_file()
        assert path.name == "README.md"
        assert (
            path.read_text(encoding="utf-8")
            == " `tokenizers-linux-x64-musl`\n\nThis is the **x86_64-unknown-linux-musl** binary for `tokenizers`\n"
        )
    finally:
        # Always remove the fetched file, even when an assertion fails, so that a stale file is
        # not left in the tests directory and reused by later runs.
        path.unlink(missing_ok=True)


@pytest.mark.skipif(
    env_vars_not_set([HF_TOKEN_ENV]),
    reason="Hugging Face environment variables not set",
)
async def test_huggingface_source_fetch_not_found() -> None:
    """Fetching a row the test dataset is not expected to contain raises `SourceNotFoundError`."""
    missing_row = HuggingFaceSource(path=HF_DATASET_PATH, row=1000)
    expected_message = "Source with ID huggingface:micpst/hf-docs/train/1000 not found."

    with pytest.raises(SourceNotFoundError) as raised:
        await missing_row.fetch()

    assert str(raised.value) == expected_message
22 changes: 0 additions & 22 deletions packages/ragbits-document-search/tests/unit/test_gcs_source.py

This file was deleted.

Loading
Loading