feat: BI-5951 presigned url generation in file-uploader-api (#756)
* feat: BI-5951 WIP presigned url generation in file-uploader-api

* fix handler, fix mypy

* update comments
KonstantAnxiety authored Dec 20, 2024
1 parent 56c416f commit 413bacc
Showing 15 changed files with 143 additions and 67 deletions.
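
Taken together, the changes let a client ask the file-uploader API for a presigned S3 POST and then send the file straight to the temporary bucket instead of streaming it through the service. A minimal client-side sketch of the first step follows; the host is a placeholder, requests is an assumed client dependency, and the base64-encoded MD5 digest is an assumption based on what S3 expects for a Content-MD5 condition.

import base64
import hashlib

import requests  # assumed client-side dependency, not part of this commit

FILE_UPLOADER_HOST = "https://file-uploader.example.com"  # placeholder host


def request_presigned_post(file_bytes: bytes) -> tuple[dict, str]:
    # Content-MD5 as S3 expects it: base64 of the binary MD5 digest
    # (assumption: the commit does not document the expected encoding).
    content_md5 = base64.b64encode(hashlib.md5(file_bytes).digest()).decode("ascii")

    resp = requests.post(
        f"{FILE_UPLOADER_HOST}/api/v2/make_presigned_url",
        json={"content_md5": content_md5},
    )
    resp.raise_for_status()
    # Response shape per PresignedUrlSchema below: {"url": ..., "fields": {...}}
    return resp.json(), content_md5
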
@@ -136,6 +136,7 @@ def create_app(self, app_version: str) -> web.Application:
app.router.add_route("get", "/api/v2/metrics", MetricsView)

app.router.add_route("post", "/api/v2/files", files_views.FilesView)
app.router.add_route("post", "/api/v2/make_presigned_url", files_views.MakePresignedUrlView)
app.router.add_route("post", "/api/v2/links", files_views.LinksView)
app.router.add_route("post", "/api/v2/documents", files_views.DocumentsView)
app.router.add_route("post", "/api/v2/update_connection_data", files_views.UpdateConnectionDataView)
@@ -72,6 +72,26 @@ class FileUploadResponseSchema(ma.Schema):
title = ma.fields.String()


class MakePresignedUrlRequestSchema(ma.Schema):
content_md5 = ma.fields.String(required=True)


class PresignedUrlSchema(ma.Schema):
class PresignedUrlFields(ma.Schema):
class Meta:
unknown = ma.INCLUDE

key = ma.fields.String()
x_amz_algorithm = ma.fields.String(attribute="x-amz-algorithm", data_key="x-amz-algorithm")
x_amz_credential = ma.fields.String(attribute="x-amz-credential", data_key="x-amz-credential")
x_amz_date = ma.fields.String(attribute="x-amz-date", data_key="x-amz-date")
policy = ma.fields.String()
x_amz_signature = ma.fields.String(attribute="x-amz-signature", data_key="x-amz-signature")

url = ma.fields.String(required=True)
_fields = ma.fields.Nested(PresignedUrlFields, required=True, attribute="fields", data_key="fields")


class FileStatusRequestSchema(BaseRequestSchema):
file_id = ma.fields.String(required=True)

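
For reference, the object this schema serializes mirrors what S3's generate_presigned_post returns: a target URL plus the form fields that must accompany the upload. The values below are illustrative placeholders only, not real credentials.

# Illustrative shape only; every value here is a placeholder.
example_presigned_post = {
    "url": "https://s3.example.net/tmp-bucket",
    "fields": {
        "key": "<user_id>_<uuid4>",
        "x-amz-algorithm": "AWS4-HMAC-SHA256",
        "x-amz-credential": "<access-key-id>/<date>/<region>/s3/aws4_request",
        "x-amz-date": "<YYYYMMDDTHHMMSSZ>",
        "policy": "<base64-encoded policy document>",
        "x-amz-signature": "<signature>",
    },
}
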
@@ -9,6 +9,7 @@
ClassVar,
Optional,
)
import uuid

from aiohttp import web
from aiohttp.multipart import BodyPartReader
@@ -82,13 +83,13 @@ async def post(self) -> web.StreamResponse:
s3 = self.dl_request.get_s3_service()

rmm = self.dl_request.get_redis_model_manager()
df = DataFile(
dfile = DataFile(
manager=rmm,
filename=filename,
file_type=file_type,
status=FileProcessingStatus.in_progress,
)
LOGGER.info(f"Data file id: {df.id}")
LOGGER.info(f"Data file id: {dfile.id}")

async def _chunk_iter(chunk_size: int = 10 * 1024 * 1024) -> AsyncGenerator[bytes, None]:
assert isinstance(field, BodyPartReader)
@@ -104,31 +105,55 @@ async def _chunk_iter(chunk_size: int = 10 * 1024 * 1024) -> AsyncGenerator[byte
data_stream = RawBytesAsyncDataStream(data_iter=_chunk_iter())
async with S3RawFileAsyncDataSink(
s3=s3.client,
s3_key=df.s3_key,
s3_key=dfile.s3_key_old,
bucket_name=s3.tmp_bucket_name,
max_file_size_exc=exc.FileLimitError,
) as data_sink:
await data_sink.dump_data_stream(data_stream)

# df.size = size # TODO

await df.save()
await dfile.save()
LOGGER.info(f'Uploaded file "{filename}".')

task_processor = self.dl_request.get_task_processor()
if file_type == FileType.xlsx:
await task_processor.schedule(ProcessExcelTask(file_id=df.id))
LOGGER.info(f"Scheduled ProcessExcelTask for file_id {df.id}")
await task_processor.schedule(ProcessExcelTask(file_id=dfile.id))
LOGGER.info(f"Scheduled ProcessExcelTask for file_id {dfile.id}")
else:
await task_processor.schedule(ParseFileTask(file_id=df.id))
LOGGER.info(f"Scheduled ParseFileTask for file_id {df.id}")
await task_processor.schedule(ParseFileTask(file_id=dfile.id))
LOGGER.info(f"Scheduled ParseFileTask for file_id {dfile.id}")

return web.json_response(
files_schemas.FileUploadResponseSchema().dump({"file_id": df.id, "title": df.filename}),
files_schemas.FileUploadResponseSchema().dump({"file_id": dfile.id, "title": dfile.filename}),
status=HTTPStatus.CREATED,
)


class MakePresignedUrlView(FileUploaderBaseView):
async def post(self) -> web.StreamResponse:
req_data = await self._load_post_request_schema_data(files_schemas.MakePresignedUrlRequestSchema)
content_md5: str = req_data["content_md5"]

s3 = self.dl_request.get_s3_service()
s3_key = "{}_{}".format(self.dl_request.rci.user_id or "unknown", str(uuid.uuid4()))

url = await s3.client.generate_presigned_post(
Bucket=s3.tmp_bucket_name,
Key=s3_key,
ExpiresIn=60 * 60, # 1 hour
Conditions=[
["content-length-range", 1, 200 * 1024 * 1024], # 1B .. 200MB # TODO use constant from DataSink
{"Content-MD5": content_md5},
],
)

return web.json_response(
files_schemas.PresignedUrlSchema().dump(url),
status=HTTPStatus.OK,
)


class LinksView(FileUploaderBaseView):
REQUIRED_RESOURCES: ClassVar[frozenset[RequiredResource]] = frozenset() # Don't skip CSRF check

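
With the response from MakePresignedUrlView in hand, the client posts the file directly to S3; the policy generated above enforces the one-hour expiry, the 1 B to 200 MB content-length-range, and the exact Content-MD5 value. A hedged upload sketch, reusing the hypothetical request_presigned_post helper from the earlier example:

import requests  # assumed client-side dependency, not part of this commit


def upload_via_presigned_post(presigned: dict, content_md5: str, filename: str, file_bytes: bytes) -> None:
    # S3 matches the policy conditions against the form fields, so Content-MD5
    # must be sent alongside the presigned fields; the file part has to be the
    # last part of the form, which requests guarantees by encoding `data` first.
    form_fields = {**presigned["fields"], "Content-MD5": content_md5}
    resp = requests.post(
        presigned["url"],
        data=form_fields,
        files={"file": (filename, file_bytes)},
    )
    resp.raise_for_status()  # S3 answers 204 No Content when the policy is satisfied
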
@@ -27,6 +27,17 @@ async def test_file_upload_cors(fu_client, s3_tmp_bucket, upload_file_req):
assert cors_header in resp.headers


@pytest.mark.asyncio
async def test_make_presigned_url(fu_client, s3_tmp_bucket):
expected_url_fields = ("key", "x-amz-algorithm", "x-amz-credential", "x-amz-date", "policy", "x-amz-signature")

resp = await fu_client.make_request(ReqBuilder.presigned_url("mymd5"))
assert resp.status == 200
assert "url" in resp.json, resp.json
assert "fields" in resp.json, resp.json
assert all(field in resp.json["fields"] for field in expected_url_fields), resp.json


@pytest.mark.asyncio
async def test_file_upload(
fu_client,
@@ -98,6 +98,17 @@ def upload_documents(
require_ok=require_ok,
)

@classmethod
def presigned_url(cls, content_md5: str, *, require_ok: bool = True) -> Req:
return Req(
method="post",
url="/api/v2/make_presigned_url",
data_json={
"content_md5": content_md5,
},
require_ok=require_ok,
)

@classmethod
def file_status(cls, file_id: str) -> Req:
return Req(
@@ -185,6 +185,7 @@ def is_applicable(self) -> bool:

@attr.s(init=True, kw_only=True)
class DataFile(RedisModelUserIdAuth):
s3_key: Optional[str] = attr.ib(default=None)
filename: Optional[str] = attr.ib()
file_type: Optional[FileType] = attr.ib(default=None)
file_settings: Optional[FileSettings] = attr.ib(default=None)
@@ -195,10 +196,14 @@ class DataFile(RedisModelUserIdAuth):
error: Optional[FileProcessingError] = attr.ib(default=None)

KEY_PREFIX: ClassVar[str] = "df"
DEFAULT_TTL: ClassVar[Optional[int]] = 12 * 60 * 60 # 12 hours
DEFAULT_TTL: ClassVar[Optional[int]] = 3 * 60 * 60 # 3 hours

@property
def s3_key(self) -> str:
def s3_key_old(self) -> str:
# transition from s3_key generated by self.id to stored self.s3_key, to be removed in future releases
# see also: DataFileSchema
if self.s3_key is not None:
return self.s3_key
return self.id

def get_secret_keys(self) -> set[DataKey]:
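
The s3_key_old property keeps records written before this commit (which carry no stored s3_key) readable, while new records store an explicit key. A standalone sketch of the same fallback, using a trimmed-down attrs model in place of the real RedisModelUserIdAuth base:

from typing import Optional
import uuid

import attr


@attr.s(init=True, kw_only=True)
class _DataFileSketch:
    # Stand-in for DataFile with only the pieces needed to show the fallback.
    id: str = attr.ib(factory=lambda: str(uuid.uuid4()))
    s3_key: Optional[str] = attr.ib(default=None)

    @property
    def s3_key_old(self) -> str:
        # Old records have no stored key, so fall back to the id-derived key.
        if self.s3_key is not None:
            return self.s3_key
        return self.id


legacy = _DataFileSketch()                      # pre-commit record: s3_key is None
assert legacy.s3_key_old == legacy.id
fresh = _DataFileSketch(s3_key="user_123_abc")  # new record: explicit key is stored
assert fresh.s3_key_old == "user_123_abc"
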
@@ -224,6 +224,7 @@ class DataFileSchema(BaseModelSchema):
class Meta(BaseModelSchema.Meta):
target = DataFile

s3_key = fields.String(load_default=None, dump_default=None) # TODO remove defaults after transition
filename = fields.String()
file_type = fields.Enum(FileType, allow_none=True)
file_settings = fields.Nested(FileSettingsSchema, allow_none=True)
@@ -154,7 +154,7 @@ async def _chunk_iter(chunk_size: int = 10 * 1024 * 1024) -> AsyncGenerator[byte
data_stream = RawBytesAsyncDataStream(data_iter=_chunk_iter())
async with S3RawFileAsyncDataSink(
s3=s3.client,
s3_key=dfile.s3_key,
s3_key=dfile.s3_key_old,
bucket_name=s3.tmp_bucket_name,
max_file_size_exc=exc.FileLimitError,
) as data_sink:
@@ -78,7 +78,7 @@ async def run(self) -> TaskResult:

s3_resp = await s3.client.get_object(
Bucket=s3.tmp_bucket_name,
Key=dfile.s3_key,
Key=dfile.s3_key_old,
)
file_obj = await s3_resp["Body"].read()

@@ -73,7 +73,7 @@ async def _prepare_file(
dst_source_id: str,
conn_raw_schema: list[SchemaColumn],
) -> str:
src_filename = dfile.s3_key if dfile.file_type == FileType.csv else src_source.s3_key
src_filename = dfile.s3_key_old if dfile.file_type == FileType.csv else src_source.s3_key

tmp_s3_filename = _make_tmp_source_s3_filename(dst_source_id)

@@ -102,7 +102,7 @@ async def _make_sample_text(self) -> str:

s3_resp = await self.s3.client.get_object(
Bucket=self.s3.tmp_bucket_name,
Key=self.dfile.s3_key,
Key=self.dfile.s3_key_old,
Range=f"bytes=0-{self.sample_size}",
)
sample_bytes = await s3_resp["Body"].read()
@@ -165,7 +165,7 @@ async def guess_header_and_schema(
has_header = self.file_settings["first_line_is_header"]
LOGGER.info(f"Overriding `has_header` with user defined: has_header={has_header}")

data_stream = await loop.run_in_executor(self.tpe, self._get_sync_s3_data_stream, self.dfile.s3_key)
data_stream = await loop.run_in_executor(self.tpe, self._get_sync_s3_data_stream, self.dfile.s3_key_old)
has_header, raw_schema = await loop.run_in_executor(
self.tpe,
guess_header_and_schema,
@@ -26,7 +26,7 @@
@pytest.fixture(scope="function")
async def upload_file(s3_tmp_bucket, s3_persistent_bucket, s3_client, redis_model_manager):
async def uploader(csv_data: bytes) -> DataFile:
data_file_desc = DataFile(
dfile = DataFile(
manager=redis_model_manager,
filename="test_file.csv",
file_type=FileType.csv,
@@ -36,12 +36,12 @@ async def uploader(csv_data: bytes) -> DataFile:
await s3_client.put_object(
ACL="private",
Bucket=s3_tmp_bucket,
Key=data_file_desc.s3_key,
Key=dfile.s3_key_old,
Body=csv_data,
)

await data_file_desc.save()
return data_file_desc
await dfile.save()
return dfile

yield uploader

@@ -96,7 +96,7 @@ async def uploaded_file_dt_id(uploaded_file_dt) -> str:
async def uploaded_10mb_file_id(s3_tmp_bucket, s3_persistent_bucket, s3_client, redis_model_manager) -> str:
csv_data = generate_sample_csv_data_str(row_count=10000, str_cols_count=30).encode("utf-8")

data_file_desc = DataFile(
dfile = DataFile(
manager=redis_model_manager,
filename="test_file_10mb.csv",
file_type=FileType.csv,
@@ -106,18 +106,18 @@ async def uploaded_10mb_file_id(s3_tmp_bucket, s3_persistent_bucket, s3_client,
await s3_client.put_object(
ACL="private",
Bucket=s3_tmp_bucket,
Key=data_file_desc.s3_key,
Key=dfile.s3_key_old,
Body=csv_data,
)

await data_file_desc.save()
yield data_file_desc.id
await dfile.save()
yield dfile.id


@pytest.fixture(scope="function")
async def uploaded_excel_file(s3_tmp_bucket, s3_persistent_bucket, s3_client, redis_model_manager):
async def uploader(filename: str) -> DataFile:
data_file_desc = DataFile(
dfile = DataFile(
manager=redis_model_manager,
filename=filename,
file_type=FileType.xlsx,
@@ -131,12 +131,12 @@ async def uploader(filename: str) -> DataFile:
await s3_client.put_object(
ACL="private",
Bucket=s3_tmp_bucket,
Key=data_file_desc.s3_key,
Key=dfile.s3_key_old,
Body=fd.read(),
)

await data_file_desc.save()
return data_file_desc
await dfile.save()
return dfile

yield uploader
