switch to async streaming download:
- remove unused sync functions
- use async methods from stream-zip
- note that stream-zip still does a sync->async conversion under the hood
- follow-up to #1933 for streaming download improvements
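
For context, the sync->async conversion mentioned in the notes above could look roughly like the following. This is an illustrative sketch, not stream-zip's actual implementation: the synchronous zip generator is driven from a worker thread so that each blocking next() call stays off the event loop.

    import asyncio
    from typing import AsyncIterator, Iterator

    async def iter_in_thread(sync_iter: Iterator[bytes]) -> AsyncIterator[bytes]:
        """Adapt a blocking byte iterator into an async one via the thread pool."""
        loop = asyncio.get_running_loop()
        it = iter(sync_iter)
        while True:
            # next() may block, so run it in the default executor;
            # None serves as the exhaustion sentinel (safe for bytes chunks)
            chunk = await loop.run_in_executor(None, next, it, None)
            if chunk is None:
                break
            yield chunk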
ikreymer committed Jul 30, 2024
1 parent 0c29008 commit 630255c
Showing 1 changed file with 20 additions and 62 deletions.

backend/btrixcloud/storages.py

@@ -9,9 +9,9 @@
     List,
     Dict,
     AsyncIterator,
+    AsyncIterable,
     TYPE_CHECKING,
     Any,
-    cast,
 )
 from urllib.parse import urlsplit
 from contextlib import asynccontextmanager
@@ -27,13 +27,12 @@
 from zipfile import ZipInfo
 
 from fastapi import Depends, HTTPException
-from stream_zip import stream_zip, NO_COMPRESSION_64, Method
+from stream_zip import async_stream_zip, NO_COMPRESSION_64, Method
 from remotezip import RemoteZip
 
 import aiobotocore.session
-import boto3
+import aiohttp
 
-from mypy_boto3_s3.client import S3Client
 from mypy_boto3_s3.type_defs import CompletedPartTypeDef
 from types_aiobotocore_s3 import S3Client as AIOS3Client
 
@@ -289,35 +288,6 @@ async def get_s3_client(
         ) as client:
             yield client, bucket, key
 
-    @asynccontextmanager
-    async def get_sync_client(
-        self, org: Organization
-    ) -> AsyncIterator[tuple[S3Client, str, str]]:
-        """context manager for s3 client"""
-        storage = self.get_org_primary_storage(org)
-
-        endpoint_url = storage.endpoint_url
-
-        if not endpoint_url.endswith("/"):
-            endpoint_url += "/"
-
-        parts = urlsplit(endpoint_url)
-        bucket, key = parts.path[1:].split("/", 1)
-
-        endpoint_url = parts.scheme + "://" + parts.netloc
-
-        try:
-            client = boto3.client(
-                "s3",
-                region_name=storage.region,
-                endpoint_url=endpoint_url,
-                aws_access_key_id=storage.access_key,
-                aws_secret_access_key=storage.secret_key,
-            )
-            yield client, bucket, key
-        finally:
-            client.close()
-
     async def verify_storage_upload(self, storage: S3Storage, filename: str) -> None:
         """Test credentials and storage endpoint by uploading an empty test file"""
 
Expand Down Expand Up @@ -682,25 +652,28 @@ def _sync_get_filestream(self, wacz_url: str, filename: str) -> Iterator[bytes]:
with remote_zip.open(filename) as file_stream:
yield from file_stream

def _sync_dl(
self, all_files: List[CrawlFileOut], client: S3Client, bucket: str, key: str
) -> Iterator[bytes]:
async def download_streaming_wacz(
self, _: Organization, all_files: List[CrawlFileOut]
) -> AsyncIterable[bytes]:
"""generate streaming zip as sync"""
for file_ in all_files:
file_.path = file_.name

datapackage = {
"profile": "multi-wacz-package",
"resources": [file_.dict() for file_ in all_files],
}
datapackage_bytes = json.dumps(datapackage).encode("utf-8")

def get_file(name) -> Iterator[bytes]:
response = client.get_object(Bucket=bucket, Key=key + name)
return response["Body"].iter_chunks(chunk_size=CHUNK_SIZE)
async def get_datapackage() -> AsyncIterable[bytes]:
yield datapackage_bytes

async def get_file(path: str) -> AsyncIterable[bytes]:
path = self.resolve_internal_access_path(path)
async with aiohttp.ClientSession() as session:
async with session.get(path) as response:
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
yield chunk

def member_files() -> (
Iterable[tuple[str, datetime, int, Method, Iterable[bytes]]]
async def member_files() -> (
AsyncIterable[tuple[str, datetime, int, Method, AsyncIterable[bytes]]]
):
modified_at = datetime(year=1980, month=1, day=1)
perms = 0o664
@@ -710,7 +683,7 @@ def member_files() -> (
                     modified_at,
                     perms,
                     NO_COMPRESSION_64(file_.size, 0),
-                    get_file(file_.name),
+                    get_file(file_.path),
                 )
 
             yield (
@@ -720,25 +693,10 @@ def member_files() -> (
                 NO_COMPRESSION_64(
                     len(datapackage_bytes), zlib.crc32(datapackage_bytes)
                 ),
-                (datapackage_bytes,),
-            )
-
-        # stream_zip() is an Iterator but defined as an Iterable, can cast
-        return cast(Iterator[bytes], stream_zip(member_files(), chunk_size=CHUNK_SIZE))
-
-    async def download_streaming_wacz(
-        self, org: Organization, files: List[CrawlFileOut]
-    ) -> Iterator[bytes]:
-        """return an iter for downloading a stream nested wacz file
-        from list of files"""
-        async with self.get_sync_client(org) as (client, bucket, key):
-            loop = asyncio.get_event_loop()
-
-            resp = await loop.run_in_executor(
-                None, self._sync_dl, files, client, bucket, key
+                get_datapackage(),
             )
 
-            return resp
+        return async_stream_zip(member_files(), chunk_size=CHUNK_SIZE)
 
 
 # ============================================================================
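
As a usage sketch (all names other than download_streaming_wacz and FastAPI's StreamingResponse are illustrative placeholders), the async iterable returned above can be streamed straight out of a FastAPI endpoint, since StreamingResponse accepts async iterables of bytes:

    from fastapi.responses import StreamingResponse

    @app.get("/orgs/{oid}/all-crawls/download")
    async def download_all_crawls(org: Organization = Depends(org_viewer_dep)):
        # org_viewer_dep and get_wacz_files are hypothetical helpers
        files = await get_wacz_files(org)
        return StreamingResponse(
            storage_ops.download_streaming_wacz(org, files),
            media_type="application/wacz",
            headers={"Content-Disposition": 'attachment; filename="archive.wacz"'},
        )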