diff --git a/chowda/utils.py b/chowda/utils.py index 843d8b2..08c3202 100644 --- a/chowda/utils.py +++ b/chowda/utils.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from sqlalchemy.dialects.postgresql import insert from starlette.requests import Request +from starlette.responses import FileResponse, StreamingResponse def adapt_url(url): @@ -148,3 +149,61 @@ def yes() -> str: from random import choice return choice(YES) + + +async def download_mmif(pks: list[str]) -> StreamingResponse | FileResponse: + """Download MMIF files from S3 and return a zip archive of them.""" + import zipfile + from tempfile import TemporaryDirectory + + import boto3 + from sqlmodel import Session, select + + from chowda.config import MMIF_S3_BUCKET_NAME + from chowda.db import engine + from chowda.models import MMIF + + s3 = boto3.client('s3') + downloaded_mmif_files = [] + tmp_dir = TemporaryDirectory() + download_errors = {} + with Session(engine) as db: + mmifs = db.exec(select(MMIF).where(MMIF.id.in_(pks))) + + for mmif in mmifs: + mmif_tmp_location = f'{tmp_dir.name}/{mmif.mmif_location.split("/")[-1]}' + try: + s3.download_file( + MMIF_S3_BUCKET_NAME, mmif.mmif_location, mmif_tmp_location + ) + downloaded_mmif_files.append(mmif_tmp_location) + except Exception as ex: + # TODO: log errors and notify user of them + download_errors[mmif.mmif_location] = ex + + if len(pks) == 1: + # If only one batch was downloaded, return the file directly + return FileResponse(downloaded_mmif_files[0]) + + # Create zip archive + import io + from datetime import datetime + + current_datetime = datetime.now().strftime('%Y-%m-%d_%H%M%S') + # TODO: include batch count, or names in the download file name? + zip_filename = f'chowda_mmif_download.{current_datetime}.zip' + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, 'w') as zip: + for downloaded_mmif_file in downloaded_mmif_files: + filename = downloaded_mmif_file.split('/')[-1] + zip.write(downloaded_mmif_file, arcname=filename) + + # Reset buffer to beginning of stream + zip_buffer.seek(0) + + # Send download response + return StreamingResponse( + zip_buffer, + headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}, + media_type='application/zip', + ) diff --git a/chowda/views.py b/chowda/views.py index 101d962..1b81a93 100644 --- a/chowda/views.py +++ b/chowda/views.py @@ -1,7 +1,6 @@ from datetime import datetime, timedelta from typing import Any, ClassVar, Dict, List, Set -from fastapi import status from metaflow import Flow from metaflow.exception import MetaflowNotFound from metaflow.integrations import ArgoEvent @@ -9,7 +8,7 @@ from sqlmodel import Session, select from starlette.datastructures import FormData from starlette.requests import Request -from starlette.responses import RedirectResponse, Response +from starlette.responses import Response from starlette.templating import Jinja2Templates from starlette_admin import CustomView, action, row_action from starlette_admin._types import RequestAction @@ -35,7 +34,7 @@ SuccessfulField, ) from chowda.models import MMIF, Batch, Collection, MediaFile -from chowda.utils import get_duplicates, validate_media_file_guids, yes +from chowda.utils import download_mmif, get_duplicates, validate_media_file_guids, yes from templates import filters # noqa: F401 @@ -432,76 +431,25 @@ async def combine_batches(self, request: Request, pks: List[Any]) -> str: async def download_mmif( self, request: Request, pks: list[int | str] | int | str ) -> str: - """Create a new batch from the selected batch""" - import zipfile - if not isinstance(pks, list): pks = [pks] - from tempfile import TemporaryDirectory - - import boto3 - - from chowda.config import MMIF_S3_BUCKET_NAME + with Session(engine) as db: + # Get all of the MMIF S3 locations from the database. + batches = db.exec(select(Batch).where(Batch.id.in_(pks))).all() + all_mmif_pks = [mmif.id for batch in batches for mmif in batch.output_mmifs] try: - with Session(engine) as db: - # Get all of the MMIF S3 locations from the database. - batches = db.exec(select(Batch).where(Batch.id.in_(pks))).all() - all_mmif_locations = [ - mmif.mmif_location - for batch in batches - for mmif in batch.output_mmifs - ] - - # Download files from S3 - s3 = boto3.client('s3') - downloaded_mmif_files = [] - tmp_dir = TemporaryDirectory() - download_errors = {} - for mmif_location in all_mmif_locations: - mmif_tmp_location = f'{tmp_dir.name}/{mmif_location.split("/")[-1]}' - try: - s3.download_file( - MMIF_S3_BUCKET_NAME, mmif_location, mmif_tmp_location - ) - downloaded_mmif_files.append(mmif_tmp_location) - except Exception as ex: - # TODO: log errors and notify user of them - download_errors[mmif_location] = ex - - # Create zip archive - import io - from datetime import datetime - - from starlette.responses import StreamingResponse - - current_datetime = datetime.now().strftime('%Y-%m-%d_%H%M%S') - # TODO: include batch count, or names in the download file name? - zip_filename = f'chowda_mmif_download.{current_datetime}.zip' - zip_buffer = io.BytesIO() - with zipfile.ZipFile(zip_buffer, 'w') as zip: - for downloaded_mmif_file in downloaded_mmif_files: - filename = downloaded_mmif_file.split('/')[-1] - zip.write(downloaded_mmif_file, arcname=filename) - - # Reset buffer to beginning of stream - zip_buffer.seek(0) - - # Send download response - return StreamingResponse( - zip_buffer, - headers={ - 'Content-Disposition': f'attachment; filename="{zip_filename}"' - }, - media_type='application/zip', - ) + return await download_mmif(all_mmif_pks) except Exception as error: # TODO: pop 'error' out of session and display with javascript # dangrerAlert() when admin/batch/list renders. # See statics/js/alerts.js from starlette-admin. + from fastapi import status + from starlette.responses import RedirectResponse + request.session['error'] = f'{error!s}' return RedirectResponse( - request.url_for('admin:list', identity='batch'), + request.url_for('admin:list', identity=request.path_params['identity']), status_code=status.HTTP_303_SEE_OTHER, ) @@ -816,3 +764,28 @@ async def add_to_existing_batch( return f'Added {len(pks)} MMIFs to Batch {batch.id}' except Exception as error: raise ActionFailed(f'{error!s}') from error + + @action( + name='download mmif', + text='Download MMIF', + confirmation='Download MMIF JSON for this MMIF?', + icon_class='fa fa-download', + submit_btn_text=yes(), + submit_btn_class='btn-outline-primary', + custom_response=True, + ) + @row_action( + name='download mmif', + text='Download MMIF', + confirmation='Download MMIF JSON for this MMIF?', + icon_class='fa fa-download', + submit_btn_text=yes(), + submit_btn_class='btn-outline-primary', + custom_response=True, + ) + async def download_mmif( + self, request: Request, pks: list[int | str] | int | str + ) -> str: + if not isinstance(pks, list): + pks = [pks] + return await download_mmif(pks)