Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): add gcs source to the DocumentMeta #25

Merged
merged 8 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ dependencies = [
"ragbits"
]

[project.optional-dependencies]
gcs = [
"gcloud-aio-storage~=9.3.0"
]

[tool.uv]
dev-dependencies = [
"pre-commit~=3.8.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, LocalFileSource


class DocumentType(str, Enum):
Expand All @@ -21,7 +21,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource] = Field(..., discriminator="source_type")
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal

from pydantic import BaseModel

try:
from gcloud.aio.storage import Storage

HAS_GCLOUD_AIO = True
except ImportError:
HAS_GCLOUD_AIO = False

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"


class Source(BaseModel, ABC):
"""
Expand Down Expand Up @@ -54,3 +65,59 @@ async def fetch(self) -> Path:
The local path to the object fetched from the source.
"""
return self.path


class GCSSource(Source):
"""
An object representing a GCS file source.
"""

source_type: Literal["gcs"] = "gcs"

bucket: str
object_name: str

def get_id(self) -> str:
"""
Get unique identifier of the object in the source.

Returns:
Unique identifier.
"""
return f"gcs:gs://{self.bucket}/{self.object_name}"

async def fetch(self) -> Path:
"""
Fetch the file from Google Cloud Storage and store it locally.

The file is downloaded to a local directory specified by `local_dir`. If the file already exists locally,
it will not be downloaded again. If the file doesn't exist locally, it will be fetched from GCS.
The local directory is determined by the environment variable `LOCAL_STORAGE_DIR_ENV`. If this environment
variable is not set, a temporary directory is used.

Returns:
Path: The local path to the downloaded file.

Raises:
ImportError: If the required 'gcloud' package is not installed for Google Cloud Storage source.
"""

if not HAS_GCLOUD_AIO:
raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None:
mhordynski marked this conversation as resolved.
Show resolved Hide resolved
local_dir = Path(tempfile.gettempdir())
else:
local_dir = Path(local_dir_env)

bucket_local_dir = local_dir / self.bucket
bucket_local_dir.mkdir(parents=True, exist_ok=True)
path = bucket_local_dir / self.object_name

if not path.is_file():
mhordynski marked this conversation as resolved.
Show resolved Hide resolved
async with Storage() as client:
content = await client.download(self.bucket, self.object_name)
with open(path, mode="wb+") as file_object:
file_object.write(content)

return path
22 changes: 22 additions & 0 deletions packages/ragbits-document-search/tests/unit/test_gcs_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from pathlib import Path

import aiohttp
import pytest

from ragbits.document_search.documents.sources import GCSSource

TEST_FILE_PATH = Path(__file__)

os.environ["LOCAL_STORAGE_DIR_ENV"] = TEST_FILE_PATH.parent.as_posix()


async def test_gcs_source_fetch():
source = GCSSource(bucket="", object_name="test_gcs_source.py")

path = await source.fetch()
assert path == TEST_FILE_PATH

source = GCSSource(bucket="", object_name="not_found_file.py")
with pytest.raises(aiohttp.ClientConnectorError):
await source.fetch()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"ragbits[litellm,local]",
"ragbits-dev-kit",
"ragbits-document-search",
"ragbits-document-search[gcs]",
"ragbits-cli"
]

Expand Down
Loading