Skip to content

Commit

Permalink
WIP: Splits download mmif function for reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
mrharpo committed Apr 2, 2024
1 parent 9925b5e commit 0aaea77
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 63 deletions.
59 changes: 59 additions & 0 deletions chowda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import insert
from starlette.requests import Request
from starlette.responses import FileResponse, StreamingResponse


def adapt_url(url):
Expand Down Expand Up @@ -148,3 +149,61 @@ def yes() -> str:
from random import choice

return choice(YES)


async def download_mmif(pks: list[str]) -> StreamingResponse | FileResponse:
    """Download MMIF files from S3 and return them to the client.

    Looks up the ``mmif_location`` of every MMIF row whose id is in *pks*,
    downloads each object from the configured S3 bucket into a temporary
    directory, and returns either:

    * a ``FileResponse`` for the file itself, when exactly one MMIF was
      requested and its download succeeded; or
    * a ``StreamingResponse`` carrying an in-memory zip archive of all
      files that downloaded successfully.

    Per-object download failures are collected in ``download_errors`` but
    are not yet surfaced to the caller (see TODO below).

    Args:
        pks: Primary keys of the MMIF rows to download.

    Returns:
        StreamingResponse | FileResponse: the file or zip archive.
    """
    import io
    import zipfile
    from datetime import datetime
    from tempfile import TemporaryDirectory

    import boto3
    from sqlmodel import Session, select

    from chowda.config import MMIF_S3_BUCKET_NAME
    from chowda.db import engine
    from chowda.models import MMIF

    s3 = boto3.client('s3')
    downloaded_mmif_files = []
    # NOTE(review): tmp_dir is a local; once this function returns, the
    # TemporaryDirectory may be garbage-collected (removing its contents)
    # before a lazily-streamed FileResponse reads the file — TODO confirm
    # and, if needed, keep a reference alive for the response lifetime.
    tmp_dir = TemporaryDirectory()
    download_errors = {}
    with Session(engine) as db:
        mmifs = db.exec(select(MMIF).where(MMIF.id.in_(pks)))

        for mmif in mmifs:
            # Use the S3 key's basename as the local filename.
            mmif_tmp_location = f'{tmp_dir.name}/{mmif.mmif_location.split("/")[-1]}'
            try:
                s3.download_file(
                    MMIF_S3_BUCKET_NAME, mmif.mmif_location, mmif_tmp_location
                )
                downloaded_mmif_files.append(mmif_tmp_location)
            except Exception as ex:
                # TODO: log errors and notify user of them
                download_errors[mmif.mmif_location] = ex

    if len(pks) == 1 and downloaded_mmif_files:
        # A single requested MMIF that downloaded successfully is returned
        # directly rather than wrapped in a zip archive. (Guarding on
        # downloaded_mmif_files avoids the IndexError previously raised
        # when the lone download failed.)
        return FileResponse(downloaded_mmif_files[0])

    # Create zip archive in memory.
    current_datetime = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    # TODO: include batch count, or names in the download file name?
    zip_filename = f'chowda_mmif_download.{current_datetime}.zip'
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as archive:
        for downloaded_mmif_file in downloaded_mmif_files:
            # Store each entry under its basename, not its temp-dir path.
            filename = downloaded_mmif_file.split('/')[-1]
            archive.write(downloaded_mmif_file, arcname=filename)

    # Reset buffer to beginning of stream
    zip_buffer.seek(0)

    # Send download response
    return StreamingResponse(
        zip_buffer,
        headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'},
        media_type='application/zip',
    )
99 changes: 36 additions & 63 deletions chowda/views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, List, Set

from fastapi import status
from metaflow import Flow
from metaflow.exception import MetaflowNotFound
from metaflow.integrations import ArgoEvent
from multipart.exceptions import MultipartParseError
from sqlmodel import Session, select
from starlette.datastructures import FormData
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from starlette.responses import Response
from starlette.templating import Jinja2Templates
from starlette_admin import CustomView, action, row_action
from starlette_admin._types import RequestAction
Expand All @@ -35,7 +34,7 @@
SuccessfulField,
)
from chowda.models import MMIF, Batch, Collection, MediaFile
from chowda.utils import get_duplicates, validate_media_file_guids, yes
from chowda.utils import download_mmif, get_duplicates, validate_media_file_guids, yes
from templates import filters # noqa: F401


Expand Down Expand Up @@ -432,76 +431,25 @@ async def combine_batches(self, request: Request, pks: List[Any]) -> str:
async def download_mmif(
self, request: Request, pks: list[int | str] | int | str
) -> str:
"""Create a new batch from the selected batch"""
import zipfile

if not isinstance(pks, list):
pks = [pks]
from tempfile import TemporaryDirectory

import boto3

from chowda.config import MMIF_S3_BUCKET_NAME

with Session(engine) as db:
# Get all of the MMIF S3 locations from the database.
batches = db.exec(select(Batch).where(Batch.id.in_(pks))).all()
all_mmif_pks = [mmif.id for batch in batches for mmif in batch.output_mmifs]
try:
with Session(engine) as db:
# Get all of the MMIF S3 locations from the database.
batches = db.exec(select(Batch).where(Batch.id.in_(pks))).all()
all_mmif_locations = [
mmif.mmif_location
for batch in batches
for mmif in batch.output_mmifs
]

# Download files from S3
s3 = boto3.client('s3')
downloaded_mmif_files = []
tmp_dir = TemporaryDirectory()
download_errors = {}
for mmif_location in all_mmif_locations:
mmif_tmp_location = f'{tmp_dir.name}/{mmif_location.split("/")[-1]}'
try:
s3.download_file(
MMIF_S3_BUCKET_NAME, mmif_location, mmif_tmp_location
)
downloaded_mmif_files.append(mmif_tmp_location)
except Exception as ex:
# TODO: log errors and notify user of them
download_errors[mmif_location] = ex

# Create zip archive
import io
from datetime import datetime

from starlette.responses import StreamingResponse

current_datetime = datetime.now().strftime('%Y-%m-%d_%H%M%S')
# TODO: include batch count, or names in the download file name?
zip_filename = f'chowda_mmif_download.{current_datetime}.zip'
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w') as zip:
for downloaded_mmif_file in downloaded_mmif_files:
filename = downloaded_mmif_file.split('/')[-1]
zip.write(downloaded_mmif_file, arcname=filename)

# Reset buffer to beginning of stream
zip_buffer.seek(0)

# Send download response
return StreamingResponse(
zip_buffer,
headers={
'Content-Disposition': f'attachment; filename="{zip_filename}"'
},
media_type='application/zip',
)
return await download_mmif(all_mmif_pks)
except Exception as error:
# TODO: pop 'error' out of session and display with javascript
                # dangerAlert() when admin/batch/list renders.
# See statics/js/alerts.js from starlette-admin.
from fastapi import status
from starlette.responses import RedirectResponse

request.session['error'] = f'{error!s}'
return RedirectResponse(
request.url_for('admin:list', identity='batch'),
request.url_for('admin:list', identity=request.path_params['identity']),
status_code=status.HTTP_303_SEE_OTHER,
)

Expand Down Expand Up @@ -816,3 +764,28 @@ async def add_to_existing_batch(
return f'Added {len(pks)} MMIFs to Batch {batch.id}'
except Exception as error:
raise ActionFailed(f'{error!s}') from error

@action(
name='download mmif',
text='Download MMIF',
confirmation='Download MMIF JSON for this MMIF?',
icon_class='fa fa-download',
submit_btn_text=yes(),
submit_btn_class='btn-outline-primary',
custom_response=True,
)
@row_action(
name='download mmif',
text='Download MMIF',
confirmation='Download MMIF JSON for this MMIF?',
icon_class='fa fa-download',
submit_btn_text=yes(),
submit_btn_class='btn-outline-primary',
custom_response=True,
)
async def download_mmif(
self, request: Request, pks: list[int | str] | int | str
) -> str:
if not isinstance(pks, list):
pks = [pks]
return await download_mmif(pks)

0 comments on commit 0aaea77

Please sign in to comment.