diff --git a/packages/ragbits-document-search/examples/documents_chat.py b/examples/apps/documents_chat.py
similarity index 100%
rename from packages/ragbits-document-search/examples/documents_chat.py
rename to examples/apps/documents_chat.py
diff --git a/packages/ragbits-core/examples/llm_example.py b/examples/core/llm.py
similarity index 100%
rename from packages/ragbits-core/examples/llm_example.py
rename to examples/core/llm.py
diff --git a/packages/ragbits-core/examples/prompt_example.py b/examples/core/prompt.py
similarity index 100%
rename from packages/ragbits-core/examples/prompt_example.py
rename to examples/core/prompt.py
diff --git a/packages/ragbits-document-search/examples/simple_text.py b/examples/document-search/basic.py
similarity index 100%
rename from packages/ragbits-document-search/examples/simple_text.py
rename to examples/document-search/basic.py
diff --git a/packages/ragbits-core/examples/chromadb_example.py b/examples/document-search/chroma.py
similarity index 100%
rename from packages/ragbits-core/examples/chromadb_example.py
rename to examples/document-search/chroma.py
diff --git a/packages/ragbits-document-search/examples/from_config_example.py b/examples/document-search/from_config.py
similarity index 100%
rename from packages/ragbits-document-search/examples/from_config_example.py
rename to examples/document-search/from_config.py
diff --git a/packages/ragbits-core/src/ragbits/core/utils/decorators.py b/packages/ragbits-core/src/ragbits/core/utils/decorators.py
new file mode 100644
index 000000000..a585fe5ef
--- /dev/null
+++ b/packages/ragbits-core/src/ragbits/core/utils/decorators.py
@@ -0,0 +1,55 @@
+# pylint: disable=missing-function-docstring,missing-return-doc
+
+import asyncio
+from functools import wraps
+from importlib.util import find_spec
+from typing import Callable, ParamSpec, TypeVar
+
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
+
+def requires_dependencies(
+    dependencies: str | list[str],
+    extras: str | None = None,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    """
+    Decorator to check if the dependencies are installed before running the function.
+
+    Args:
+        dependencies: The dependencies to check.
+        extras: The extras to install.
+
+    Returns:
+        The decorated function.
+    """
+    if isinstance(dependencies, str):
+        dependencies = [dependencies]
+
+    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
+        def run_check() -> None:
+            missing_dependencies = [dependency for dependency in dependencies if not find_spec(dependency)]
+            if len(missing_dependencies) > 0:
+                missing_deps = ", ".join(missing_dependencies)
+                install_cmd = (
+                    f"pip install 'ragbits[{extras}]'" if extras else f"pip install {' '.join(missing_dependencies)}"
+                )
+                raise ImportError(
+                    f"Following dependencies are missing: {missing_deps}. Please install them using `{install_cmd}`."
+                )
+
+        @wraps(func)
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            run_check()
+            return func(*args, **kwargs)
+
+        @wraps(func)
+        async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            run_check()
+            return await func(*args, **kwargs)  # type: ignore
+
+        if asyncio.iscoroutinefunction(func):
+            return wrapper_async  # type: ignore
+        return wrapper
+
+    return decorator
diff --git a/packages/ragbits-core/tests/unit/utils/test_decorators.py b/packages/ragbits-core/tests/unit/utils/test_decorators.py
new file mode 100644
index 000000000..49752e3d4
--- /dev/null
+++ b/packages/ragbits-core/tests/unit/utils/test_decorators.py
@@ -0,0 +1,47 @@
+import pytest
+
+from ragbits.core.utils.decorators import requires_dependencies
+
+
+def test_single_dependency_installed() -> None:
+    @requires_dependencies("pytest")
+    def some_function() -> str:
+        return "success"
+
+    assert some_function() == "success"
+
+
+def test_single_dependency_missing() -> None:
+    @requires_dependencies("nonexistent_dependency")
+    def some_function() -> str:
+        return "success"
+
+    with pytest.raises(ImportError) as exc:
+        some_function()
+
+    assert (
+        str(exc.value)
+        == "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
+    )
+
+
+def test_multiple_dependencies_installed() -> None:
+    @requires_dependencies(["pytest", "asyncio"])
+    def some_function() -> str:
+        return "success"
+
+    assert some_function() == "success"
+
+
+def test_multiple_dependencies_some_missing() -> None:
+    @requires_dependencies(["pytest", "nonexistent_dependency"])
+    def some_function() -> str:
+        return "success"
+
+    with pytest.raises(ImportError) as exc:
+        some_function()
+
+    assert (
+        str(exc.value)
+        == "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
+    )
diff --git a/packages/ragbits-document-search/pyproject.toml b/packages/ragbits-document-search/pyproject.toml
index 996e6baf7..ecd0063a6 100644
--- a/packages/ragbits-document-search/pyproject.toml
+++ b/packages/ragbits-document-search/pyproject.toml
@@ -34,13 +34,16 @@ dependencies = [
     "numpy~=1.24.0",
     "unstructured>=0.15.13",
     "unstructured-client>=0.26.0",
-    "ragbits-core==0.1.0"
+    "ragbits-core==0.1.0",
 ]

 [project.optional-dependencies]
 gcs = [
     "gcloud-aio-storage~=9.3.0"
 ]
+huggingface = [
+    "datasets~=3.0.1",
+]

 [tool.uv]
 dev-dependencies = [
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
index 0d43df918..581f33006 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
@@ -1,11 +1,10 @@
 import tempfile
 from enum import Enum
 from pathlib import Path
-from typing import Union

 from pydantic import BaseModel, Field

-from ragbits.document_search.documents.sources import GCSSource, LocalFileSource
+from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource


 class DocumentType(str, Enum):
@@ -39,7 +38,7 @@ class DocumentMeta(BaseModel):
     """

     document_type: DocumentType
-    source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")
+    source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")

     @property
     def id(self) -> str:
@@ -49,7 +48,7 @@ def id(self) -> str:
         Returns:
             The document ID.
         """
-        return self.source.get_id()
+        return self.source.id

     async def fetch(self) -> "Document":
         """
@@ -98,7 +97,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
         )

     @classmethod
-    async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
+    async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta":
         """
         Create a document metadata from a source.

diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/exceptions.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/exceptions.py
new file mode 100644
index 000000000..9a41e5fd5
--- /dev/null
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/exceptions.py
@@ -0,0 +1,27 @@
+class SourceError(Exception):
+    """
+    Class for all exceptions raised by the document source.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class SourceConnectionError(SourceError):
+    """
+    Raised when there is an error connecting to the document source.
+    """
+
+    def __init__(self) -> None:
+        super().__init__("Connection error.")
+
+
+class SourceNotFoundError(SourceError):
+    """
+    Raised when the document is not found.
+ """ + + def __init__(self, source_id: str) -> None: + super().__init__(f"Source with ID {source_id} not found.") + self.source_id = source_id diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py index fc5a93a83..8d254d2d5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py @@ -7,13 +7,16 @@ from pydantic import BaseModel try: + from datasets import load_dataset + from datasets.exceptions import DatasetNotFoundError from gcloud.aio.storage import Storage - - HAS_GCLOUD_AIO = True except ImportError: - HAS_GCLOUD_AIO = False + pass + +from ragbits.core.utils.decorators import requires_dependencies +from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError -LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV" +LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR" class Source(BaseModel, ABC): @@ -21,8 +24,9 @@ class Source(BaseModel, ABC): An object representing a source. """ + @property @abstractmethod - def get_id(self) -> str: + def id(self) -> str: """ Get the source ID. @@ -48,7 +52,8 @@ class LocalFileSource(Source): source_type: Literal["local_file"] = "local_file" path: Path - def get_id(self) -> str: + @property + def id(self) -> str: """ Get unique identifier of the object in the source. @@ -63,7 +68,12 @@ async def fetch(self) -> Path: Returns: The local path to the object fetched from the source. + + Raises: + SourceNotFoundError: If the source document is not found. """ + if not self.path.is_file(): + raise SourceNotFoundError(source_id=self.id) return self.path @@ -73,11 +83,11 @@ class GCSSource(Source): """ source_type: Literal["gcs"] = "gcs" - bucket: str object_name: str - def get_id(self) -> str: + @property + def id(self) -> str: """ Get unique identifier of the object in the source. @@ -86,39 +96,107 @@ def get_id(self) -> str: """ return f"gcs:gs://{self.bucket}/{self.object_name}" + @requires_dependencies(["gcloud.aio.storage"], "gcs") async def fetch(self) -> Path: """ Fetch the file from Google Cloud Storage and store it locally. The file is downloaded to a local directory specified by `local_dir`. If the file already exists locally, it will not be downloaded again. If the file doesn't exist locally, it will be fetched from GCS. - The local directory is determined by the environment variable `LOCAL_STORAGE_DIR_ENV`. If this environment + The local directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment variable is not set, a temporary directory is used. Returns: Path: The local path to the downloaded file. Raises: - ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source. + ImportError: If the 'gcp' extra is not installed. 
""" - - if not HAS_GCLOUD_AIO: - raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage") - - if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None: - local_dir = Path(tempfile.gettempdir()) / "ragbits" - else: - local_dir = Path(local_dir_env) - + local_dir = get_local_storage_dir() bucket_local_dir = local_dir / self.bucket bucket_local_dir.mkdir(parents=True, exist_ok=True) path = bucket_local_dir / self.object_name if not path.is_file(): - async with Storage() as client: + async with Storage() as client: # type: ignore + # TODO: Add error handling for download content = await client.download(self.bucket, self.object_name) Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True) with open(path, mode="wb+") as file_object: file_object.write(content) return path + + +class HuggingFaceSource(Source): + """ + An object representing a Hugging Face dataset source. + """ + + source_type: Literal["huggingface"] = "huggingface" + path: str + split: str = "train" + row: int + + @property + def id(self) -> str: + """ + Get unique identifier of the object in the source. + + Returns: + Unique identifier. + """ + return f"huggingface:{self.path}/{self.split}/{self.row}" + + @requires_dependencies(["datasets"], "huggingface") + async def fetch(self) -> Path: + """ + Fetch the file from Hugging Face and store it locally. + + Returns: + Path: The local path to the downloaded file. + + Raises: + ImportError: If the 'huggingface' extra is not installed. + SourceConnectionError: If the source connection fails. + SourceNotFoundError: If the source document is not found. + """ + try: + dataset = load_dataset(self.path, split=self.split, streaming=True) # type: ignore + except ConnectionError as exc: + raise SourceConnectionError() from exc + except DatasetNotFoundError as exc: # type: ignore + raise SourceNotFoundError(source_id=self.id) from exc + + try: + data = next(iter(dataset.skip(self.row).take(1))) # type: ignore + except StopIteration as exc: + raise SourceNotFoundError(source_id=self.id) from exc + + storage_dir = get_local_storage_dir() + source_dir = storage_dir / Path(data["source"]).parent + source_dir.mkdir(parents=True, exist_ok=True) + path = storage_dir / data["source"] + + if not path.is_file(): + with open(path, mode="w", encoding="utf-8") as file: + file.write(data["content"]) + + return path + + +def get_local_storage_dir() -> Path: + """ + Get the local storage directory. + + The local storage directory is determined by the environment variable `LOCAL_STORAGE_DIR`. If this environment + variable is not set, a temporary directory is used. + + Returns: + The local storage directory. 
+ """ + return ( + Path(local_dir_env) + if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is not None + else Path(tempfile.gettempdir()) / "ragbits" + ) diff --git a/packages/ragbits-document-search/tests/integration/test_sources.py b/packages/ragbits-document-search/tests/integration/test_sources.py new file mode 100644 index 000000000..f42ab8a56 --- /dev/null +++ b/packages/ragbits-document-search/tests/integration/test_sources.py @@ -0,0 +1,45 @@ +import os +from pathlib import Path + +import pytest + +from ragbits.document_search.documents.exceptions import SourceNotFoundError +from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, HuggingFaceSource + +from ..helpers import env_vars_not_set + +os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix() + +HF_TOKEN_ENV = "HF_TOKEN" # nosec +HF_DATASET_PATH = "micpst/hf-docs" + + +@pytest.mark.skipif( + env_vars_not_set([HF_TOKEN_ENV]), + reason="Hugging Face environment variables not set", +) +async def test_huggingface_source_fetch() -> None: + source = HuggingFaceSource(path=HF_DATASET_PATH, row=0) + path = await source.fetch() + + assert path.is_file() + assert path.name == "README.md" + assert ( + path.read_text() + == " `tokenizers-linux-x64-musl`\n\nThis is the **x86_64-unknown-linux-musl** binary for `tokenizers`\n" + ) + + path.unlink() + + +@pytest.mark.skipif( + env_vars_not_set([HF_TOKEN_ENV]), + reason="Hugging Face environment variables not set", +) +async def test_huggingface_source_fetch_not_found() -> None: + source = HuggingFaceSource(path=HF_DATASET_PATH, row=1000) + + with pytest.raises(SourceNotFoundError) as exc: + await source.fetch() + + assert str(exc.value) == "Source with ID huggingface:micpst/hf-docs/train/1000 not found." diff --git a/packages/ragbits-document-search/tests/unit/test_gcs_source.py b/packages/ragbits-document-search/tests/unit/test_gcs_source.py deleted file mode 100644 index da32b5a9f..000000000 --- a/packages/ragbits-document-search/tests/unit/test_gcs_source.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from pathlib import Path - -import aiohttp -import pytest - -from ragbits.document_search.documents.sources import GCSSource - -TEST_FILE_PATH = Path(__file__) - -os.environ["LOCAL_STORAGE_DIR_ENV"] = TEST_FILE_PATH.parent.as_posix() - - -async def test_gcs_source_fetch(): - source = GCSSource(bucket="", object_name="test_gcs_source.py") - - path = await source.fetch() - assert path == TEST_FILE_PATH - - source = GCSSource(bucket="", object_name="not_found_file.py") - with pytest.raises(aiohttp.ClientConnectionError): - await source.fetch() diff --git a/packages/ragbits-document-search/tests/unit/test_sources.py b/packages/ragbits-document-search/tests/unit/test_sources.py new file mode 100644 index 000000000..1c90df6e0 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_sources.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, GCSSource, HuggingFaceSource + +os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix() + + +async def test_gcs_source_fetch() -> None: + data = b"This is the content of the file." 
+    source = GCSSource(bucket="", object_name="doc.md")
+
+    with patch("ragbits.document_search.documents.sources.Storage.download", return_value=data):
+        path = await source.fetch()
+
+    assert source.id == "gcs:gs:///doc.md"
+    assert path.name == "doc.md"
+    assert path.read_text() == "This is the content of the file."
+
+    path.unlink()
+
+
+async def test_huggingface_source_fetch() -> None:
+    take = MagicMock(return_value=[{"content": "This is the content of the file.", "source": "doc.md"}])
+    skip = MagicMock(return_value=MagicMock(take=take))
+    data = MagicMock(skip=skip)
+    source = HuggingFaceSource(path="org/docs", split="train", row=1)
+
+    with patch("ragbits.document_search.documents.sources.load_dataset", return_value=data):
+        path = await source.fetch()
+
+    assert source.id == "huggingface:org/docs/train/1"
+    assert path.name == "doc.md"
+    assert path.read_text() == "This is the content of the file."
+
+    path.unlink()
diff --git a/pyproject.toml b/pyproject.toml
index a63771965..171b3fecb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.10"
 dependencies = [
     "ragbits-cli",
     "ragbits-core[litellm,local,lab,chromadb]",
-    "ragbits-document-search[gcs]",
+    "ragbits-document-search[gcs, huggingface]",
     "ragbits-evaluate",
 ]

diff --git a/scripts/create_ragbits_package.py b/scripts/create_ragbits_package.py
index 56b9e0031..f17afcdc6 100644
--- a/scripts/create_ragbits_package.py
+++ b/scripts/create_ragbits_package.py
@@ -36,9 +36,6 @@ def run() -> None:
     src_dir.mkdir(exist_ok=True, parents=True)
     (src_dir / "__init__.py").touch()

-    examples_dir = package_dir / "examples"
-    examples_dir.mkdir(exist_ok=True)
-
     tests_dir = package_dir / "tests"
     tests_dir.mkdir(exist_ok=True)

diff --git a/uv.lock b/uv.lock
index 82205c530..eb65837c6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -260,30 +260,30 @@ wheels = [

 [[package]]
 name = "boto3"
-version = "1.35.39"
+version = "1.35.42"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "botocore" },
     { name = "jmespath" },
     { name = "s3transfer" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/b8/29/10988ceaa300ddc628cb899875d85d9998e3da4803226398e002d95b2741/boto3-1.35.39.tar.gz", hash = "sha256:670f811c65e3c5fe4ed8c8d69be0b44b1d649e992c0fc16de43816d1188f88f1", size = 110975 }
+sdist = { url = "https://files.pythonhosted.org/packages/04/e4/a60d99f727766d9801f4e14dca7e2df0245831844411562d81e0df2cc179/boto3-1.35.42.tar.gz", hash = "sha256:a5b00f8b82dce62870759f04861747944da834d64a64355970120c475efdafc0", size = 111011 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/5a/ad/ba203ea67522d1184aa879d7ac063e5ffc7e6bafe00b4f79124e5fca0128/boto3-1.35.39-py3-none-any.whl", hash = "sha256:5970b62c1ec8177501e02520f0d41839ca5fc549b30bac4e8c0c0882ae776217", size = 139143 },
+    { url = "https://files.pythonhosted.org/packages/ff/7c/88d9031a7a409393da450bcfca2f3597e0afccac4cae2d97fc4e7190f012/boto3-1.35.42-py3-none-any.whl", hash = "sha256:e1f36f8be453505cebcc3da178ea081b2a06c0e5e1cdee774f1067599b8d9c3e", size = 139159 },
 ]

 [[package]]
 name = "botocore"
-version = "1.35.39"
+version = "1.35.42"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "jmespath" },
     { name = "python-dateutil" },
     { name = "urllib3" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/f7/28/d83dbd69d7015892b53ada4fded79a5bc1b7d77259361eb8302f88c2da81/botocore-1.35.39.tar.gz", hash = "sha256:cb7f851933b5ccc2fba4f0a8b846252410aa0efac5bfbe93b82d10801f5f8e90", size = 12826384 }
+sdist = { url = "https://files.pythonhosted.org/packages/d3/0c/2bcd566397ab06661b222b9b5156ba0c40d5a97d3727c88ccaefea275cb4/botocore-1.35.42.tar.gz", hash = "sha256:af348636f73dc24b7e2dc760a34d08c8f2f94366e9b4c78d877307b128abecef", size = 12835012 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/76/1f/296dc3b4c70b84328715fb7ee45f1d422fbed83cbcd464a3d4f29e91d197/botocore-1.35.39-py3-none-any.whl", hash = "sha256:781c547eb6a79c0e4b0bedd87b81fbfed957816b4841d33e20c8f1989c7c19ce", size = 12613407 },
+    { url = "https://files.pythonhosted.org/packages/2e/f5/0e67c7e6a7f5f8c068cf444dc25d03097a22428380587542978d7ad9d86a/botocore-1.35.42-py3-none-any.whl", hash = "sha256:05af0bb8b9cea7ce7bc589c332348d338a21b784e9d088a588fd10ec145007ff", size = 12621471 },
 ]

 [[package]]
@@ -3496,6 +3496,9 @@ local = [
     { name = "torch" },
     { name = "transformers" },
 ]
+promptfoo = [
+    { name = "pyyaml" },
+]

 [package.dev-dependencies]
 dev = [
@@ -3514,6 +3517,7 @@ requires-dist = [
     { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" },
     { name = "numpy", marker = "extra == 'local'", specifier = "~=1.24.0" },
     { name = "pydantic", specifier = ">=2.9.1" },
+    { name = "pyyaml", marker = "extra == 'promptfoo'", specifier = "~=6.0.2" },
     { name = "tomli", specifier = "~=2.0.2" },
     { name = "torch", marker = "extra == 'local'", specifier = "~=2.2.1" },
     { name = "transformers", marker = "extra == 'local'", specifier = "~=4.44.2" },
@@ -3544,6 +3548,9 @@ dependencies = [
 gcs = [
     { name = "gcloud-aio-storage" },
 ]
+huggingface = [
+    { name = "datasets" },
+]

 [package.dev-dependencies]
 dev = [
@@ -3557,6 +3564,7 @@ dev = [

 [package.metadata]
 requires-dist = [
+    { name = "datasets", marker = "extra == 'huggingface'", specifier = "~=3.0.1" },
     { name = "gcloud-aio-storage", marker = "extra == 'gcs'", specifier = "~=9.3.0" },
     { name = "numpy", specifier = "~=1.24.0" },
     { name = "ragbits-core", editable = "packages/ragbits-core" },
@@ -3616,7 +3624,7 @@ source = { virtual = "." }
 dependencies = [
     { name = "ragbits-cli" },
     { name = "ragbits-core", extra = ["chromadb", "lab", "litellm", "local"] },
-    { name = "ragbits-document-search", extra = ["gcs"] },
+    { name = "ragbits-document-search", extra = ["gcs", "huggingface"] },
     { name = "ragbits-evaluate" },
 ]

@@ -3641,7 +3649,7 @@ dev = [
 requires-dist = [
     { name = "ragbits-cli", editable = "packages/ragbits-cli" },
     { name = "ragbits-core", extras = ["litellm", "local", "lab", "chromadb"], editable = "packages/ragbits-core" },
-    { name = "ragbits-document-search", extras = ["gcs"], editable = "packages/ragbits-document-search" },
+    { name = "ragbits-document-search", extras = ["gcs", "huggingface"], editable = "packages/ragbits-document-search" },
     { name = "ragbits-evaluate", editable = "packages/ragbits-evaluate" },
 ]