Skip to content

Commit

Permalink
Changes after review
Browse files (browse the repository at this point in the history)
  • Loading branch information
akotyla committed Sep 23, 2024
1 parent ac8b000 commit 4281382
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion packages/ragbits-document-search/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
]

[project.optional-dependencies]
google-cloud-storage = [
gcs = [
"gcloud-aio-storage~=9.3.0"
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel, Field

from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.documents.sources import GCSSource, LocalFileSource


class DocumentType(str, Enum):
Expand All @@ -21,7 +21,7 @@ class DocumentMeta(BaseModel):
"""

document_type: DocumentType
source: Union[LocalFileSource] = Field(..., discriminator="source_type")
source: Union[LocalFileSource, GCSSource] = Field(..., discriminator="source_type")

@property
def id(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal
Expand All @@ -11,6 +12,8 @@
except ImportError:
HAS_GCLOUD_AIO = False

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR_ENV"


class Source(BaseModel, ABC):
"""
Expand Down Expand Up @@ -63,18 +66,16 @@ async def fetch(self) -> Path:
return self.path


class GoogleCloudStorageSource(Source):
class GCSSource(Source):
"""
An object representing a GCS file source.
"""

source_type: Literal["google_cloud_storage_file"] = "google_cloud_storage_file"
source_type: Literal["gcs_file"] = "gcs_file"

bucket: str
object_name: str

local_dir: Path = Path("tmp/ragbits/")

def get_id(self) -> str:
"""
Get unique identifier of the object in the source.
Expand All @@ -96,12 +97,17 @@ async def fetch(self) -> Path:
Raises:
ImportError: If the required 'gcloud-aio-storage' package is not installed for Google Cloud Storage source.
ValueError: If LOCAL_STORAGE_DIR_ENV is not set.
"""

if not HAS_GCLOUD_AIO:
raise ImportError("You need to install the 'gcloud' package to use Google Cloud Storage")
raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage")

if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None:
raise ValueError(f"{LOCAL_STORAGE_DIR_ENV} environment variable is not set")

bucket_local_dir = self.local_dir / self.bucket
local_dir: Path = Path(local_dir_env)
bucket_local_dir = local_dir / self.bucket

bucket_local_dir.mkdir(parents=True, exist_ok=True)
path = bucket_local_dir / self.object_name
Expand Down

0 comments on commit 4281382

Please sign in to comment.