Skip to content

Commit

Permalink
Merge branch 'main' into 93-enabler-document-batch-ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 17, 2024
2 parents dd89398 + bf7d2c6 commit 889b5d1
Show file tree
Hide file tree
Showing 19 changed files with 496 additions and 57 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 5 additions & 2 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ def main() -> None:

cli_enabled_modules = [
module
for i, module in enumerate(pkgutil.iter_modules(ragbits.__path__))
if module.ispkg and module.name != "cli" and (Path(ragbits.__path__[i]) / module.name / "cli.py").exists()
for module in pkgutil.iter_modules(ragbits.__path__)
if module.ispkg
and module.name != "cli"
and (Path(module.module_finder.path) / module.name / "cli.py").exists() # type: ignore
]

for module in cli_enabled_modules:
register_func = importlib.import_module(f"ragbits.{module.name}.cli").register
register_func(app, help_only)
Expand Down
55 changes: 55 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# pylint: disable=missing-function-docstring,missing-return-doc

import asyncio
from functools import wraps
from importlib.util import find_spec
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def requires_dependencies(
dependencies: str | list[str],
extras: str | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""
Decorator to check if the dependencies are installed before running the function.
Args:
dependencies: The dependencies to check.
extras: The extras to install.
Returns:
The decorated function.
"""
if isinstance(dependencies, str):
dependencies = [dependencies]

def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
def run_check() -> None:
missing_dependencies = [dependency for dependency in dependencies if not find_spec(dependency)]
if len(missing_dependencies) > 0:
missing_deps = ", ".join(missing_dependencies)
install_cmd = (
f"pip install 'ragbits[{extras}]'" if extras else f"pip install {' '.join(missing_dependencies)}"
)
raise ImportError(
f"Following dependencies are missing: {missing_deps}. Please install them using `{install_cmd}`."
)

@wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return func(*args, **kwargs)

@wraps(func)
async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _T:
run_check()
return await func(*args, **kwargs) # type: ignore

if asyncio.iscoroutinefunction(func):
return wrapper_async # type: ignore
return wrapper

return decorator
47 changes: 47 additions & 0 deletions packages/ragbits-core/tests/unit/utils/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from ragbits.core.utils.decorators import requires_dependencies


def test_single_dependency_installed() -> None:
@requires_dependencies("pytest")
def some_function() -> str:
return "success"

assert some_function() == "success"


def test_single_dependency_missing() -> None:
@requires_dependencies("nonexistent_dependency")
def some_function() -> str:
return "success"

with pytest.raises(ImportError) as exc:
some_function()

assert (
str(exc.value)
== "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
)


def test_multiple_dependencies_installed() -> None:
@requires_dependencies(["pytest", "asyncio"])
def some_function() -> str:
return "success"

assert some_function() == "success"


def test_multiple_dependencies_some_missing() -> None:
@requires_dependencies(["pytest", "nonexistent_dependency"])
def some_function() -> str:
return "success"

with pytest.raises(ImportError) as exc:
some_function()

assert (
str(exc.value)
== "Following dependencies are missing: nonexistent_dependency. Please install them using `pip install nonexistent_dependency`."
)
5 changes: 4 additions & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ dependencies = [
"numpy~=1.24.0",
"unstructured>=0.15.13",
"unstructured-client>=0.26.0",
"ragbits-core==0.1.0"
"ragbits-core==0.1.0",
]

[project.optional-dependencies]
gcs = [
"gcloud-aio-storage~=9.3.0"
]
huggingface = [
"datasets~=3.0.1",
]

[tool.uv]
dev-dependencies = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import tempfile
from enum import Enum
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import GCSSource, LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, HuggingFaceSource, LocalFileSource


class DocumentType(str, Enum):
Expand Down Expand Up @@ -39,7 +38,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")
source: LocalFileSource | GCSSource | HuggingFaceSource = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand All @@ -49,7 +48,7 @@ def id(self) -> str:
Returns:
The document ID.
"""
return self.source.get_id()
return self.source.id

async def fetch(self) -> "Document":
"""
Expand Down Expand Up @@ -98,7 +97,7 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
)

@classmethod
async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
async def from_source(cls, source: LocalFileSource | GCSSource | HuggingFaceSource) -> "DocumentMeta":
"""
Create a document metadata from a source.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class SourceError(Exception):
"""
Class for all exceptions raised by the document source.
"""

def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


class SourceConnectionError(SourceError):
"""
Raised when there is an error connecting to the document source.
"""

def __init__(self) -> None:
super().__init__("Connection error.")


class SourceNotFoundError(SourceError):
"""
Raised when the document is not found.
"""

def __init__(self, source_id: str) -> None:
super().__init__(f"Source with ID {source_id} not found.")
self.source_id = source_id
Loading

0 comments on commit 889b5d1

Please sign in to comment.