
merge main
micpst committed Oct 17, 2024
2 parents 50a6730 + bf7d2c6 commit 1d8dce4
Showing 18 changed files with 334 additions and 60 deletions.
6 files renamed without changes.
55 changes: 55 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/decorators.py
@@ -0,0 +1,55 @@
# pylint: disable=missing-function-docstring,missing-return-doc

import asyncio
from functools import wraps
from importlib.util import find_spec
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def requires_dependencies(
    dependencies: str | list[str],
    extras: str | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Decorator to check if the dependencies are installed before running the function.

    Args:
        dependencies: The dependencies to check.
        extras: The extras to install.

    Returns:
        The decorated function.
    """
    if isinstance(dependencies, str):
        dependencies = [dependencies]

    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        def run_check() -> None:
            missing_dependencies = [dependency for dependency in dependencies if not find_spec(dependency)]
            if len(missing_dependencies) > 0:
                missing_deps = ", ".join(missing_dependencies)
                install_cmd = (
                    f"pip install 'ragbits[{extras}]'" if extras else f"pip install {' '.join(missing_dependencies)}"
                )
                raise ImportError(
                    f"Following dependencies are missing: {missing_deps}. Please install them using `{install_cmd}`."
                )

        @wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            run_check()
            return func(*args, **kwargs)

        @wraps(func)
        async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            run_check()
            return await func(*args, **kwargs)  # type: ignore

        if asyncio.iscoroutinefunction(func):
            return wrapper_async  # type: ignore
        return wrapper

    return decorator
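
For illustration, a minimal usage sketch of the decorator (the function names below are placeholders, not part of this commit). The check runs when the wrapped function is called, and both sync and async callables are supported:

from ragbits.core.utils.decorators import requires_dependencies


@requires_dependencies("numpy")
def vector_norm(values: list[float]) -> float:
    # the decorator has already verified that numpy is importable at this point
    import numpy as np

    return float(np.linalg.norm(values))


@requires_dependencies(["gcloud.aio.storage"], extras="gcs")
async def download_blob(bucket: str, object_name: str) -> bytes:
    # a missing dependency surfaces as an ImportError before the body runs
    ...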
47 changes: 47 additions & 0 deletions packages/ragbits-core/tests/unit/utils/test_decorators.py
@@ -0,0 +1,47 @@
import pytest

from ragbits.core.utils.decorators import requires_dependencies


def test_single_dependency_installed() -> None:
    @requires_dependencies("pytest")
    def some_function() -> str:
        return "success"

    assert some_function() == "success"


def test_single_dependency_missing() -> None:
    @requires_dependencies("nonexistent_dependency")
    def some_function() -> str:
        return "success"

    with pytest.raises(ImportError) as exc:
        some_function()

    assert (
        str(exc.value)
        == "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
    )


def test_multiple_dependencies_installed() -> None:
    @requires_dependencies(["pytest", "asyncio"])
    def some_function() -> str:
        return "success"

    assert some_function() == "success"


def test_multiple_dependencies_some_missing() -> None:
    @requires_dependencies(["pytest", "nonexistent_dependency"])
    def some_function() -> str:
        return "success"

    with pytest.raises(ImportError) as exc:
        some_function()

    assert (
        str(exc.value)
        == "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
    )
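
One case the new tests leave uncovered is the extras-based hint; a hypothetical test for it, assuming the decorator behaviour shown above, could look like this:

def test_single_dependency_missing_with_extras() -> None:
    @requires_dependencies("nonexistent_dependency", extras="gcs")
    def some_function() -> str:
        return "success"

    with pytest.raises(ImportError, match=r"pip install 'ragbits\[gcs\]'"):
        some_function()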
5 changes: 4 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
@@ -34,13 +34,16 @@ dependencies = [
    "numpy~=1.24.0",
    "unstructured>=0.15.13",
    "unstructured-client>=0.26.0",
-    "ragbits-core==0.1.0"
+    "ragbits-core==0.1.0",
]

[project.optional-dependencies]
gcs = [
    "gcloud-aio-storage~=9.3.0"
]
+huggingface = [
+    "datasets~=3.0.1",
+]

[tool.uv]
dev-dependencies = [
@@ -1,11 +1,10 @@
import tempfile
from enum import Enum
from pathlib import Path
-from typing import Union

from pydantic import BaseModel, Field

-from ragbits.document_search.documents.sources import GCSSource, LocalFileSource
+from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource


class DocumentType(str, Enum):
@@ -39,7 +38,7 @@ class DocumentMeta(BaseModel):
    """

    document_type: DocumentType
-    source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")
+    source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")

    @property
    def id(self) -> str:
@@ -49,7 +48,7 @@ def id(self) -> str:
        Returns:
            The document ID.
        """
-        return self.source.get_id()
+        return self.source.id

    async def fetch(self) -> "Document":
        """
@@ -98,7 +97,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
        )

    @classmethod
-    async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
+    async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta":
        """
        Create a document metadata from a source.
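
To illustrate the widened union and the new `id` property, a hedged sketch (the module path and the `DocumentType` member are assumed, since they are not shown in this diff):

from ragbits.document_search.documents.document_meta import DocumentMeta, DocumentType  # module path assumed
from ragbits.document_search.documents.sources import HuggingFaceSource

meta = DocumentMeta(
    document_type=DocumentType.TXT,  # enum member assumed for illustration
    source=HuggingFaceSource(path="user/dataset", split="train", row=5),
)
print(meta.id)  # "huggingface:user/dataset/train/5", now a property rather than get_id()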
@@ -0,0 +1,27 @@
class SourceError(Exception):
    """
    Class for all exceptions raised by the document source.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class SourceConnectionError(SourceError):
    """
    Raised when there is an error connecting to the document source.
    """

    def __init__(self) -> None:
        super().__init__("Connection error.")


class SourceNotFoundError(SourceError):
    """
    Raised when the document is not found.
    """

    def __init__(self, source_id: str) -> None:
        super().__init__(f"Source with ID {source_id} not found.")
        self.source_id = source_id
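
A short sketch of how callers might handle the new exception hierarchy (illustrative only; `source` can be any of the source types defined below):

from ragbits.document_search.documents.exceptions import SourceError, SourceNotFoundError


async def fetch_or_report(source) -> None:
    try:
        await source.fetch()
    except SourceNotFoundError as exc:
        print(f"Missing document: {exc.source_id}")
    except SourceError as exc:
        # SourceConnectionError and any other SourceError subclasses land here
        print(f"Source failure: {exc.message}")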
@@ -7,22 +7,26 @@
from pydantic import BaseModel

try:
+    from datasets import load_dataset
+    from datasets.exceptions import DatasetNotFoundError
    from gcloud.aio.storage import Storage

-    HAS_GCLOUD_AIO = True
except ImportError:
-    HAS_GCLOUD_AIO = False
+    pass

+from ragbits.core.utils.decorators import requires_dependencies
+from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError

-LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"
+LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"


class Source(BaseModel, ABC):
    """
    An object representing a source.
    """

+    @property
    @abstractmethod
-    def get_id(self) -> str:
+    def id(self) -> str:
        """
        Get the source ID.
@@ -48,7 +52,8 @@ class LocalFileSource(Source):
    source_type: Literal["local_file"] = "local_file"
    path: Path

-    def get_id(self) -> str:
+    @property
+    def id(self) -> str:
        """
        Get unique identifier of the object in the source.
@@ -63,7 +68,12 @@ async def fetch(self) -> Path:
        Returns:
            The local path to the object fetched from the source.

+        Raises:
+            SourceNotFoundError: If the source document is not found.
        """
+        if not self.path.is_file():
+            raise SourceNotFoundError(source_id=self.id)
+
        return self.path


@@ -73,11 +83,11 @@ class GCSSource(Source):
    """

    source_type: Literal["gcs"] = "gcs"
-
    bucket: str
    object_name: str

-    def get_id(self) -> str:
+    @property
+    def id(self) -> str:
        """
        Get unique identifier of the object in the source.
@@ -86,39 +96,107 @@ def get_id(self) -> str:
"""
return f"gcs:gs://{self.bucket}/{self.object_name}"

@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def fetch(self) -> Path:
"""
Fetch the file from Google Cloud Storage and store it locally.
The file is downloaded to a local directory specified by `local_dir`. If the file already exists locally,
it will not be downloaded again. If the file doesn't exist locally, it will be fetched from GCS.
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR_ENV`. If this environment
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment
variable is not set, a temporary directory is used.
Returns:
Path: The local path to the downloaded file.
Raises:
ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source.
ImportError: If the 'gcp' extra is not installed.
"""

if not HAS_GCLOUD_AIO:
raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None:
local_dir = Path(tempfile.gettempdir()) / "ragbits"
else:
local_dir = Path(local_dir_env)

local_dir = get_local_storage_dir()
bucket_local_dir = local_dir / self.bucket
bucket_local_dir.mkdir(parents=True, exist_ok=True)
path = bucket_local_dir / self.object_name

if not path.is_file():
async with Storage() as client:
async with Storage() as client: # type: ignore
# TODO: Add error handling for download
content = await client.download(self.bucket, self.object_name)
Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
with open(path, mode="wb+") as file_object:
file_object.write(content)

return path
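
For reference, a hedged sketch of fetching through GCSSource with the new caching path (bucket and object names are placeholders; it assumes the `gcs` extra and valid GCP credentials):

import asyncio

from ragbits.document_search.documents.sources import GCSSource

source = GCSSource(bucket="my-bucket", object_name="docs/report.pdf")
print(source.id)  # "gcs:gs://my-bucket/docs/report.pdf"

local_path = asyncio.run(source.fetch())  # downloaded once, then served from the local cache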


class HuggingFaceSource(Source):
    """
    An object representing a Hugging Face dataset source.
    """

    source_type: Literal["huggingface"] = "huggingface"
    path: str
    split: str = "train"
    row: int

    @property
    def id(self) -> str:
        """
        Get unique identifier of the object in the source.

        Returns:
            Unique identifier.
        """
        return f"huggingface:{self.path}/{self.split}/{self.row}"

    @requires_dependencies(["datasets"], "huggingface")
    async def fetch(self) -> Path:
        """
        Fetch the file from Hugging Face and store it locally.

        Returns:
            Path: The local path to the downloaded file.

        Raises:
            ImportError: If the 'huggingface' extra is not installed.
            SourceConnectionError: If the source connection fails.
            SourceNotFoundError: If the source document is not found.
        """
        try:
            dataset = load_dataset(self.path, split=self.split, streaming=True)  # type: ignore
        except ConnectionError as exc:
            raise SourceConnectionError() from exc
        except DatasetNotFoundError as exc:  # type: ignore
            raise SourceNotFoundError(source_id=self.id) from exc

        try:
            data = next(iter(dataset.skip(self.row).take(1)))  # type: ignore
        except StopIteration as exc:
            raise SourceNotFoundError(source_id=self.id) from exc

        storage_dir = get_local_storage_dir()
        source_dir = storage_dir / Path(data["source"]).parent
        source_dir.mkdir(parents=True, exist_ok=True)
        path = storage_dir / data["source"]

        if not path.is_file():
            with open(path, mode="w", encoding="utf-8") as file:
                file.write(data["content"])

        return path
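
A hedged usage sketch for the new source (the dataset path is a placeholder, and the dataset is assumed to expose `source` and `content` columns, as the `fetch` implementation above expects):

import asyncio

from ragbits.document_search.documents.sources import HuggingFaceSource

source = HuggingFaceSource(path="some-user/some-dataset", split="train", row=0)
print(source.id)  # "huggingface:some-user/some-dataset/train/0"

local_path = asyncio.run(source.fetch())  # streams one row and writes its "content" field to a local file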


def get_local_storage_dir() -> Path:
    """
    Get the local storage directory.

    The local storage directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment
    variable is not set, a temporary directory is used.

    Returns:
        The local storage directory.
    """
    return (
        Path(local_dir_env)
        if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is not None
        else Path(tempfile.gettempdir()) / "ragbits"
    )
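
A small sketch showing how the storage location can be redirected through the renamed environment variable (the directory is a placeholder):

import os

from ragbits.document_search.documents.sources import get_local_storage_dir

os.environ["LOCAL_STORAGE_DIR"] = "/tmp/ragbits-cache"
print(get_local_storage_dir())  # /tmp/ragbits-cache

del os.environ["LOCAL_STORAGE_DIR"]
print(get_local_storage_dir())  # falls back to <system temp dir>/ragbits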
