Skip to content

Commit

Permalink
chore: Allow indexer to use filenames when downloading files from Google Drive.
Browse files Browse the repository at this point in the history
  • Loading branch information
osala-eng committed Oct 28, 2023
1 parent 99016ec commit f4d7d6a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion source/docq/data_source/support/opendal_reader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def download_from_gdrive(files: List[dict], temp_dir: str, service: Any,) -> Lis
if suffix not in DEFAULT_FILE_READER_CLS:
continue

file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
file_path = f"{temp_dir}/{file['name']}" # type: ignore
indexed_on = datetime.timestamp(datetime.now().utcnow())
services.google_drive.download_file(service, file["id"], file_path)
downloaded_files.append((file["webViewLink"], file_path, int(indexed_on), int(file["size"])))
Expand Down
18 changes: 13 additions & 5 deletions source/docq/manage_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import unicodedata
from datetime import datetime
from mimetypes import guess_type
from typing import Optional

from llama_index.schema import NodeWithScore
from streamlit import runtime
Expand Down Expand Up @@ -42,10 +43,17 @@ def delete_all(space: SpaceKey) -> None:

reindex(space)

def _is_web_address(uri: str) -> bool:
"""Return true if the uri is a web address."""
return uri.startswith("http://") or uri.startswith("https://")


def _get_download_link(filename: str, path: str) -> str:
"""Return the download link for the file if runtime exists, otherwise return an empty string."""
if runtime.exists() and os.path.isfile(path):
if _is_web_address(path):
return path

elif runtime.exists() and os.path.isfile(path):
return runtime.get_instance().media_file_mgr.add(
path_or_data=path,
mimetype=guess_type(path)[0] or "application/octet-stream",
Expand All @@ -68,16 +76,16 @@ def _parse_metadata(metadata: dict) -> tuple:
s_type = metadata.get(str(DocumentMetadata.DATA_SOURCE_TYPE.name).lower())
uri = metadata.get(str(DocumentMetadata.SOURCE_URI.name).lower())
if s_type == "SpaceDataSourceWebBased":
website = _remove_ascii_control_characters(metadata.get("source_website"))
page_title = _remove_ascii_control_characters(metadata.get("page_title"))
website = _remove_ascii_control_characters(metadata.get("source_website", ""))
page_title = _remove_ascii_control_characters(metadata.get("page_title", ""))
return website, page_title, uri, s_type
else:
file_name = metadata.get("file_name")
page_label = metadata.get("page_label")
return file_name, page_label, uri, s_type


def _classify_file_sources(name: str, uri: str, page: str, sources: dict = None) -> str:
def _classify_file_sources(name: str, uri: str, page: str, sources: Optional[dict] = None) -> dict:
"""Classify file sources for easy grouping."""
if sources is None:
sources = {}
Expand All @@ -88,7 +96,7 @@ def _classify_file_sources(name: str, uri: str, page: str, sources: dict = None)
return sources


def _classify_web_sources(website: str, uri: str, page_title: str, sources: dict = None) -> str:
def _classify_web_sources(website: str, uri: str, page_title: str, sources: Optional[dict] = None) -> dict:
"""Classify web sources for easy grouping."""
if sources is None:
sources = {}
Expand Down

0 comments on commit f4d7d6a

Please sign in to comment.