Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sources): add hf data source #106

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# pylint: disable=missing-function-docstring,missing-return-doc

import asyncio
from functools import wraps
from importlib.util import find_spec
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def requires_dependencies(
micpst marked this conversation as resolved.
Show resolved Hide resolved
dependencies: str | list[str],
extras: str | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
Decorator to check if the dependencies are installed before running the function.

Args:
dependencies: The dependencies to check.
extras: The extras to install.

Returns:
The decorated function.
"""
if isinstance(dependencies, str):
dependencies = [dependencies]

def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
def run_check() -> None:
missing_dependencies = [dependency for dependency in dependencies if not find_spec(dependency)]
if len(missing_dependencies) > 0:
missing_deps = ", ".join(missing_dependencies)
install_cmd = (
f"pip install 'ragbits[{extras}]'" if extras else f"pip install {' '.join(missing_dependencies)}"
)
raise ImportError(
f"Following dependencies are missing: {missing_deps}. Please install them using `{install_cmd}`."
)

@wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return func(*args, **kwargs)

@wraps(func)
async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return await func(*args, **kwargs) # type: ignore

if asyncio.iscoroutinefunction(func):
return wrapper_async # type: ignore
return wrapper

return decorator
5 changes: 4 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ dependencies = [
"numpy~=1.24.0",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
"ragbits-core==0.1.0"
"ragbits-core==0.1.0",
]

[project.optional-dependencies]
gcs = [
"gcloud-aio-storage~=9.3.0"
]
huggingface = [
"datasets~=3.0.1",
]

[tool.uv]
dev-dependencies = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import tempfile
from enum import Enum
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import GCSSource, LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource


class DocumentType(str, Enum):
Expand Down Expand Up @@ -39,7 +38,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")
source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand All @@ -49,7 +48,7 @@ def id(self) -> str:
Returns:
The document ID.
"""
return self.source.get_id()
return self.source.id

async def fetch(self) -> "Document":
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from pydantic import BaseModel

try:
from datasets import load_dataset
from gcloud.aio.storage import Storage

HAS_GCLOUD_AIO = True
except ImportError:
HAS_GCLOUD_AIO = False
pass

from ragbits.core.utils.decorators import requires_dependencies

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"
LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"


class Source(BaseModel, ABC):
"""
An object representing a source.
"""

@property
@abstractmethod
def get_id(self) -> str:
def id(self) -> str:
"""
Get the source ID.

Expand All @@ -48,7 +50,8 @@ class LocalFileSource(Source):
source_type: Literal["local_file"] = "local_file"
path: Path

def get_id(self) -> str:
@property
def id(self) -> str:
"""
Get unique identifier of the object in the source.

Expand All @@ -73,11 +76,11 @@ class GCSSource(Source):
"""

source_type: Literal["gcs"] = "gcs"

bucket: str
object_name: str

def get_id(self) -> str:
@property
def id(self) -> str:
"""
Get unique identifier of the object in the source.

Expand All @@ -86,39 +89,94 @@ def get_id(self) -> str:
"""
return f"gcs:gs://{self.bucket}/{self.object_name}"

@requires_dependencies(["gcloud.aio.storage"], "gcs")
Copy link
Collaborator Author

@micpst micpst Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm testing a new way of handling optional dependencies. Since managing error messages from extra deps is starting to get tedious, I decided to wrap them in @requires_dependencies decorator that handles them automatically.

If you like it, we can add a new issue to refactor the handling of optional deps, there're several other places where this decorator should be applied.

async def fetch(self) -> Path:
"""
Fetch the file from Google Cloud Storage and store it locally.

The file is downloaded to a local directory specified by `local_dir`. If the file already exists locally,
it will not be downloaded again. If the file doesn't exist locally, it will be fetched from GCS.
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR_ENV`. If this environment
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment
variable is not set, a temporary directory is used.

Returns:
Path: The local path to the downloaded file.

Raises:
ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source.
ImportError: If the required 'gcloud' package is not installed.
"""

if not HAS_GCLOUD_AIO:
raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None:
local_dir = Path(tempfile.gettempdir()) / "ragbits"
else:
local_dir = Path(local_dir_env)

local_dir = get_local_storage_dir()
bucket_local_dir = local_dir / self.bucket
bucket_local_dir.mkdir(parents=True, exist_ok=True)
path = bucket_local_dir / self.object_name

if not path.is_file():
async with Storage() as client:
async with Storage() as client: # type: ignore
content = await client.download(self.bucket, self.object_name)
Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
with open(path, mode="wb+") as file_object:
file_object.write(content)

return path


class HuggingFaceSource(Source):
"""
An object representing a Hugging Face dataset source.
"""

source_type: Literal["huggingface"] = "huggingface"
hf_path: str

@property
def id(self) -> str:
"""
Get unique identifier of the object in the source.

Returns:
Unique identifier.
"""
return f"huggingface:{self.hf_path}"

@requires_dependencies(["datasets"], "huggingface")
async def fetch(self) -> Path:
"""
Fetch the file from Hugging Face and store it locally.

Returns:
Path: The local path to the downloaded file.

Raises:
ImportError: If the required 'datasets' package is not installed.
"""
hf_path, row = self.hf_path.split("?row=")
dataset = load_dataset(path=hf_path, split="train") # type: ignore
data = dataset[int(row)] # type: ignore

storage_dir = get_local_storage_dir()
source_dir = storage_dir / Path(data["source"]).parent # type: ignore
source_dir.mkdir(parents=True, exist_ok=True)
path = storage_dir / data["source"] # type: ignore

if not path.is_file():
with open(path, mode="w", encoding="utf-8") as file:
file.write(data["content"]) # type: ignore

return path
akonarski-ds marked this conversation as resolved.
Show resolved Hide resolved


def get_local_storage_dir() -> Path:
"""
Get the local storage directory.

The local storage directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment
variable is not set, a temporary directory is used.

Returns:
The local storage directory.
"""
return (
Path(local_dir_env)
if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is not None
else Path(tempfile.gettempdir()) / "ragbits"
)
31 changes: 31 additions & 0 deletions packages/ragbits-document-search/tests/integration/test_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
from pathlib import Path

import pytest

from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, HuggingFaceSource

from ..helpers import env_vars_not_set

os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix()

HF_TOKEN_ENV = "HF_TOKEN" # nosec
HF_DATASET_PATH = "micpst/hf-docs?row=0"


@pytest.mark.skipif(
env_vars_not_set([HF_TOKEN_ENV]),
reason="Hugging Face environment variables not set",
)
async def test_huggingface_source_fetch() -> None:
source = HuggingFaceSource(hf_path=HF_DATASET_PATH)
path = await source.fetch()

assert path.is_file()
assert path.name == "README.md"
assert (
path.read_text()
== " `tokenizers-linux-x64-musl`\n\nThis is the **x86_64-unknown-linux-musl** binary for `tokenizers`\n"
)

path.unlink()
22 changes: 0 additions & 22 deletions packages/ragbits-document-search/tests/unit/test_gcs_source.py

This file was deleted.

39 changes: 39 additions & 0 deletions packages/ragbits-document-search/tests/unit/test_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from pathlib import Path
from unittest.mock import patch

from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, GCSSource, HuggingFaceSource

os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix()


async def test_gcs_source_fetch() -> None:
data = b"This is the content of the file."
source = GCSSource(bucket="", object_name="doc.md")

with patch("ragbits.document_search.documents.sources.Storage.download", return_value=data):
path = await source.fetch()

assert source.id == "gcs:gs:///doc.md"
assert path.name == "doc.md"
assert path.read_text() == "This is the content of the file."

path.unlink()


async def test_huggingface_source_fetch() -> None:
dataset = [
{"content": "This is the first document.", "source": "first_document.txt"},
{"content": "This is the second document.", "source": "second_document.txt"},
{"content": "This is the third document.", "source": "third_document.txt"},
]
source = HuggingFaceSource(hf_path="org/docs?row=1")

with patch("ragbits.document_search.documents.sources.load_dataset", return_value=dataset):
path = await source.fetch()

assert source.id == "huggingface:org/docs?row=1"
assert path.name == "second_document.txt"
assert path.read_text() == "This is the second document."

path.unlink()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ragbits-core[litellm,local,lab,chromadb]",
"ragbits-document-search[gcs]",
"ragbits-document-search[gcs, huggingface]",
"ragbits-cli"
]

Expand Down
Loading
Loading