switch to async streaming download: #1982

Merged: 9 commits, Oct 3, 2024
Changes from 5 commits
96 changes: 31 additions & 65 deletions backend/btrixcloud/storages.py
@@ -9,9 +9,9 @@
List,
Dict,
AsyncIterator,
AsyncIterable,
TYPE_CHECKING,
Any,
cast,
)
from urllib.parse import urlsplit
from contextlib import asynccontextmanager
@@ -27,15 +27,14 @@
from zipfile import ZipInfo

from fastapi import Depends, HTTPException
from stream_zip import stream_zip, NO_COMPRESSION_64, Method
from stream_zip import async_stream_zip, NO_COMPRESSION_64, Method
from remotezip import RemoteZip

import aiobotocore.session
import boto3
import aiohttp

from mypy_boto3_s3.client import S3Client
from mypy_boto3_s3.type_defs import CompletedPartTypeDef
from types_aiobotocore_s3 import S3Client as AIOS3Client
from types_aiobotocore_s3.type_defs import CompletedPartTypeDef

from .models import (
BaseFile,
@@ -289,35 +288,6 @@ async def get_s3_client(
) as client:
yield client, bucket, key

@asynccontextmanager
async def get_sync_client(
self, org: Organization
) -> AsyncIterator[tuple[S3Client, str, str]]:
"""context manager for s3 client"""
storage = self.get_org_primary_storage(org)

endpoint_url = storage.endpoint_url

if not endpoint_url.endswith("/"):
endpoint_url += "/"

parts = urlsplit(endpoint_url)
bucket, key = parts.path[1:].split("/", 1)

endpoint_url = parts.scheme + "://" + parts.netloc

try:
client = boto3.client(
"s3",
region_name=storage.region,
endpoint_url=endpoint_url,
aws_access_key_id=storage.access_key,
aws_secret_access_key=storage.secret_key,
)
yield client, bucket, key
finally:
client.close()

async def verify_storage_upload(self, storage: S3Storage, filename: str) -> None:
"""Test credentials and storage endpoint by uploading an empty test file"""

@@ -682,25 +652,36 @@ def _sync_get_filestream(self, wacz_url: str, filename: str) -> Iterator[bytes]:
with remote_zip.open(filename) as file_stream:
yield from file_stream

def _sync_dl(
self, all_files: List[CrawlFileOut], client: S3Client, bucket: str, key: str
) -> Iterator[bytes]:
async def download_streaming_wacz(
self, _: Organization, all_files: List[CrawlFileOut]
) -> AsyncIterable[bytes]:
"""generate streaming zip as sync"""
for file_ in all_files:
file_.path = file_.name

datapackage = {
"profile": "multi-wacz-package",
"resources": [file_.dict() for file_ in all_files],
"resources": [
{
"name": file_.name,
"path": file_.name,
"hash": "sha256:" + file_.hash,
"bytes": file_.size,
}
for file_ in all_files
],
}
datapackage_bytes = json.dumps(datapackage).encode("utf-8")
datapackage_bytes = json.dumps(datapackage, indent=2).encode("utf-8")

async def get_datapackage() -> AsyncIterable[bytes]:
yield datapackage_bytes

def get_file(name) -> Iterator[bytes]:
response = client.get_object(Bucket=bucket, Key=key + name)
return response["Body"].iter_chunks(chunk_size=CHUNK_SIZE)
async def get_file(path: str) -> AsyncIterable[bytes]:
path = self.resolve_internal_access_path(path)
async with aiohttp.ClientSession() as session:
async with session.get(path) as response:
async for chunk in response.content.iter_chunked(CHUNK_SIZE):
yield chunk

def member_files() -> (
Iterable[tuple[str, datetime, int, Method, Iterable[bytes]]]
async def member_files() -> (
AsyncIterable[tuple[str, datetime, int, Method, AsyncIterable[bytes]]]
):
modified_at = datetime(year=1980, month=1, day=1)
perms = 0o664
@@ -710,7 +691,7 @@ def member_files() -> (
modified_at,
perms,
NO_COMPRESSION_64(file_.size, 0),
get_file(file_.name),
get_file(file_.path),
)

yield (
@@ -720,25 +701,10 @@ def member_files() -> (
NO_COMPRESSION_64(
len(datapackage_bytes), zlib.crc32(datapackage_bytes)
),
(datapackage_bytes,),
)

# stream_zip() is an Iterator but defined as an Iterable, can cast
return cast(Iterator[bytes], stream_zip(member_files(), chunk_size=CHUNK_SIZE))

async def download_streaming_wacz(
self, org: Organization, files: List[CrawlFileOut]
) -> Iterator[bytes]:
"""return an iter for downloading a stream nested wacz file
from list of files"""
async with self.get_sync_client(org) as (client, bucket, key):
loop = asyncio.get_event_loop()

resp = await loop.run_in_executor(
None, self._sync_dl, files, client, bucket, key
get_datapackage(),
)

return resp
return async_stream_zip(member_files(), chunk_size=CHUNK_SIZE)


# ============================================================================
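
For context (not part of the diff): download_streaming_wacz now returns the async iterable produced by async_stream_zip, so the multi-WACZ zip can be streamed to the client chunk by chunk instead of being assembled in an executor thread. A minimal sketch of how such an iterable can be served, assuming a Starlette/FastAPI StreamingResponse; the function and variable names below are illustrative and not taken from the PR:

# Illustrative sketch only: serve the async zip stream without buffering it in memory.
from fastapi.responses import StreamingResponse

async def stream_multi_wacz(storage_ops, org, wacz_files):
    # storage_ops.download_streaming_wacz() yields zip chunks from async_stream_zip
    return StreamingResponse(
        storage_ops.download_streaming_wacz(org, wacz_files),
        media_type="application/zip",  # placeholder media type
        headers={"Content-Disposition": 'attachment; filename="archive.wacz"'},
    )
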
2 changes: 0 additions & 2 deletions backend/requirements.txt
@@ -18,10 +18,8 @@ humanize
python-multipart
pathvalidate
https://github.com/ikreymer/stream-zip/archive/refs/heads/crc32-optional.zip
boto3
backoff>=2.2.1
python-slugify>=8.0.1
mypy_boto3_s3
types_aiobotocore_s3
types-redis
types-python-slugify
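
The requirements change mirrors the import switch in storages.py above: boto3 and its mypy_boto3_s3 stubs are dropped, and types_aiobotocore_s3 supplies typing for the aiobotocore client that get_s3_client already uses. A minimal sketch of that async client pattern, with hypothetical argument names and a head_object call chosen purely for illustration:

# Illustrative sketch only: the aiobotocore pattern that types_aiobotocore_s3 annotates.
import aiobotocore.session
from types_aiobotocore_s3 import S3Client

async def head_size(endpoint_url: str, access_key: str, secret_key: str,
                    bucket: str, key: str) -> int:
    session = aiobotocore.session.get_session()
    async with session.create_client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    ) as client:
        s3: S3Client = client  # stub-provided type for mypy/IDE support
        resp = await s3.head_object(Bucket=bucket, Key=key)
        return resp["ContentLength"]
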
10 changes: 10 additions & 0 deletions backend/test/test_run_crawl.py
@@ -6,6 +6,7 @@
import re
import csv
import codecs
import json
from tempfile import TemporaryFile
from zipfile import ZipFile, ZIP_STORED

@@ -406,6 +407,15 @@ def test_download_wacz_crawls(
assert filename.endswith(".wacz") or filename == "datapackage.json"
assert zip_file.getinfo(filename).compress_type == ZIP_STORED

if filename == "datapackage.json":
data = zip_file.read(filename).decode("utf-8")
datapackage = json.loads(data)
assert len(datapackage["resources"]) == 1
for resource in datapackage["resources"]:
assert resource["name"] == resource["path"]
assert resource["hash"]
assert resource["bytes"]


def test_update_crawl(
admin_auth_headers,
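
For reference, the new assertions correspond to a parsed datapackage.json of roughly the following shape; the filename, hash, and size below are placeholders, while the name == path invariant comes from the resource dicts built in storages.py above:

# Illustrative shape of the parsed datapackage for a single-WACZ crawl (placeholder values).
expected_datapackage = {
    "profile": "multi-wacz-package",
    "resources": [
        {
            "name": "crawl-example.wacz",   # hypothetical filename
            "path": "crawl-example.wacz",   # equal to name, per storages.py
            "hash": "sha256:<hex digest>",  # placeholder digest
            "bytes": 123456,                # placeholder size
        }
    ],
}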