diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 398172d97..12db5d33d 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -1,9 +1,12 @@ import io from furl import furl +import requests +from loguru import logger from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatmap_parallel +from daras_ai_v2.exceptions import raise_for_status def is_gdrive_url(f: furl) -> bool: @@ -60,7 +63,7 @@ def gdrive_list_urls_of_files_in_folder(f: furl, max_depth: int = 4) -> list[str return filter(None, urls) -def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: +def gdrive_download(f: furl, mime_type: str, export_links: dict) -> tuple[bytes, str]: from googleapiclient import discovery # get drive file id @@ -70,7 +73,7 @@ def gdrive_download(f: furl, mime_type: str) -> tuple[bytes, str]: request, mime_type = service_request(service, file_id, f, mime_type) file_bytes, mime_type = download_blob_file_content( - service, request, file_id, f, mime_type + service, request, file_id, f, mime_type, export_links ) return file_bytes, mime_type @@ -96,7 +99,7 @@ def service_request( def download_blob_file_content( - service, request, file_id: str, f: furl, mime_type: str + service, request, file_id: str, f: furl, mime_type: str, export_links: dict ) -> tuple[bytes, str]: from googleapiclient.http import MediaIoBaseDownload from googleapiclient.errors import HttpError @@ -105,38 +108,45 @@ def download_blob_file_content( file = io.BytesIO() downloader = MediaIoBaseDownload(file, request) - done = False - try: - while done is False: - _, done = downloader.next_chunk() - # print(f"Download {int(status.progress() * 100)}%") - except HttpError as error: - # retry if error exporting google docs format files e.g .pptx/.docx files uploaded to docs.google.com - if "presentation" in f.path.segments: - # update mime_type to download the file directly - mime_type = "application/vnd.openxmlformats-officedocument.presentationml.presentation" - request, _ = service_request( - service, file_id, f, mime_type, retried_request=True - ) - downloader = MediaIoBaseDownload(file, request) - done = False - while done is False: - _, done = downloader.next_chunk() + if ( + mime_type + == "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ): + + f_url_export = export_links.get(mime_type, None) + if f_url_export: + f_bytes, mime_type = download_from_exportlinks(f_url_export, mime_type) + else: - elif "document" in f.path.segments: - mime_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - request, _ = service_request( - service, file_id, f, mime_type, retried_request=True + request = service.files().get_media( + fileId=file_id, + supportsAllDrives=True, ) downloader = MediaIoBaseDownload(file, request) done = False while done is False: _, done = downloader.next_chunk() + # print(f"Download {int(status.progress() * 100)}%") + f_bytes = file.getvalue() + + else: + done = False + while done is False: + _, done = downloader.next_chunk() + # print(f"Download {int(status.progress() * 100)}%") + f_bytes = file.getvalue() + + return f_bytes, mime_type - else: - raise error - f_bytes = file.getvalue() +def download_from_exportlinks(f: furl, mime_type: str) -> tuple[bytes, str]: + + try: + r = requests.get(f) + f_bytes = r.content + + except requests.RequestException as e: + raise_for_status(e) return f_bytes, mime_type @@ -157,8 +167,10 @@ def docs_export_mimetype(f: furl) -> tuple[str, str]: mime_type = "text/csv" ext = ".csv" elif "presentation" in f.path.segments: - mime_type = "application/pdf" - ext = ".pdf" + mime_type = ( + "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ) + ext = ".pptx" elif "drawings" in f.path.segments: mime_type = "application/pdf" ext = ".pdf" @@ -176,7 +188,7 @@ def gdrive_metadata(file_id: str) -> dict: .get( supportsAllDrives=True, fileId=file_id, - fields="name,md5Checksum,modifiedTime,mimeType,size", + fields="name,md5Checksum,modifiedTime,mimeType,size,exportLinks", ) .execute() ) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index f78c39260..8c5afbc4e 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -310,6 +310,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: etag = meta.get("md5Checksum") or meta.get("modifiedTime") mime_type = meta["mimeType"] total_bytes = int(meta.get("size") or 0) + export_links = meta.get("exportLinks") else: try: if is_user_uploaded_url(f_url): @@ -347,9 +348,12 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: # guess mimetype from name as a fallback if not mime_type: mime_type = mimetypes.guess_type(name)[0] - return FileMetadata( + + file_metadata = FileMetadata( name=name, etag=etag, mime_type=mime_type or "", total_bytes=total_bytes ) + file_metadata.export_links = export_links or {} + return file_metadata def yt_dlp_get_video_entries(url: str) -> list[dict]: @@ -650,7 +654,7 @@ def doc_url_to_text_pages( Download document from url and convert to text pages. """ f_bytes, mime_type = download_content_bytes( - f_url=f_url, mime_type=file_meta.mime_type, is_user_url=is_user_url + f_url=f_url, mime_type=file_meta.mime_type, is_user_url=is_user_url,export_links=file_meta.export_links ) if not f_bytes: return [] @@ -664,14 +668,14 @@ def doc_url_to_text_pages( def download_content_bytes( - *, f_url: str, mime_type: str, is_user_url: bool = True + *, f_url: str, mime_type: str, is_user_url: bool = True, export_links:dict = None ) -> tuple[bytes, str]: if is_yt_dlp_able_url(f_url): return download_youtube_to_wav(f_url), "audio/wav" f = furl(f_url) if is_gdrive_url(f): # download from google drive - return gdrive_download(f, mime_type) + return gdrive_download(f, mime_type,export_links) try: # download from url if is_user_uploaded_url(f_url): diff --git a/files/models.py b/files/models.py index 12af91a7a..3f7d3e678 100644 --- a/files/models.py +++ b/files/models.py @@ -1,5 +1,6 @@ from django.db import models from django.template.defaultfilters import filesizeformat +from collections import defaultdict class FileMetadata(models.Model): @@ -7,6 +8,10 @@ class FileMetadata(models.Model): etag = models.CharField(max_length=255, null=True) mime_type = models.CharField(max_length=255, default="", blank=True) total_bytes = models.PositiveIntegerField(default=0, blank=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.export_links = kwargs.get('export_links', defaultdict(dict)) def __str__(self): ret = f"{self.name or 'Unnamed'} - {self.mime_type}"