Extract mpu and staging part logic into mpu.py
yanghua committed Sep 21, 2024
1 parent f44911e commit ac6818d
Showing 2 changed files with 177 additions and 132 deletions.
152 changes: 20 additions & 132 deletions tosfs/core.py
@@ -14,13 +14,11 @@

"""The core module of TOSFS."""
import io
import itertools
import logging
import mimetypes
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from glob import has_magic
from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union

@@ -31,7 +29,6 @@
from tos.exceptions import TosClientError, TosServerError
from tos.models import CommonPrefixInfo
from tos.models2 import (
CreateMultipartUploadOutput,
ListedObject,
ListedObjectVersion,
ListObjectType2Output,
@@ -58,6 +55,7 @@
from tosfs.exceptions import TosfsError
from tosfs.fsspec_utils import glob_translate
from tosfs.retry import retryable_func_executor
from tosfs.mpu import MultipartUploader
from tosfs.utils import find_bucket_key, get_brange

logger = logging.getLogger("tosfs")
@@ -2032,19 +2030,18 @@ def __init__(
self.path = path
self.mode = mode
self.autocommit = autocommit
self.mpu: CreateMultipartUploadOutput = None
self.parts: list = []
self.append_block = False
self.buffer: Optional[io.BytesIO] = io.BytesIO()

self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
self.part_size = fs.multipart_size
self.thread_pool_size = fs.multipart_thread_pool_size
self.staging_buffer_size = fs.multipart_staging_buffer_size
self.multipart_threshold = fs.multipart_threshold
self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
self.staging_files: list[str] = []
self.staging_buffer: io.BytesIO = io.BytesIO()
self.multipart_uploader = MultipartUploader(
fs=fs,
bucket=bucket,
key=key,
part_size=fs.multipart_size,
thread_pool_size=fs.multipart_thread_pool_size,
staging_buffer_size=fs.multipart_staging_buffer_size,
multipart_threshold=fs.multipart_threshold,
)

if "a" in mode and fs.exists(path):
head = retryable_func_executor(
@@ -2097,111 +2094,20 @@ def _upload_chunk(self, final: bool = False) -> bool:
self.autocommit
and not self.append_block
and final
and self.tell() < max(self.blocksize, self.multipart_threshold)
and self.tell()
< max(self.blocksize, self.multipart_uploader.multipart_threshold)
):
# only happens when closing small file, use one-shot PUT
pass
else:
self.parts = []

self.mpu = retryable_func_executor(
lambda: self.fs.tos_client.create_multipart_upload(
self.bucket, self.key
),
max_retry_num=self.fs.max_retry_num,
)

if self.append_block:
# use existing data in key when appending,
# and block is big enough
out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part_copy(
bucket=self.bucket,
key=self.key,
part_number=1,
upload_id=self.mpu.upload_id,
),
max_retry_num=self.fs.max_retry_num,
)

self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})

self._upload_multiple_chunks(bucket, key)
self.multipart_uploader.initiate_upload()
self.multipart_uploader.upload_multiple_chunks(self.buffer)

if self.autocommit and final:
self.commit()

return not final

def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
if self.buffer:
self.buffer.seek(0)
while True:
chunk = self.buffer.read(self.part_size)
if not chunk:
break

self._write_to_staging_buffer(chunk)

def _write_to_staging_buffer(self, chunk: bytes) -> None:
self.staging_buffer.write(chunk)
if self.staging_buffer.tell() >= self.staging_buffer_size:
self._flush_staging_buffer()

def _flush_staging_buffer(self) -> None:
if self.staging_buffer.tell() == 0:
return

self.staging_buffer.seek(0)
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)

self.staging_buffer = io.BytesIO()

def _upload_staged_files(self) -> None:
self._flush_staging_buffer()
futures = []
for i, staging_file in enumerate(self.staging_files):
part_number = i + 1
futures.append(
self.executor.submit(
self._upload_part_from_file, staging_file, part_number
)
)

for future in futures:
part_info = future.result()
self.parts.append(part_info)

self.staging_files = []

def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
with open(staging_file, "rb") as f:
content = f.read()

out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part(
bucket=self.bucket,
key=self.key,
part_number=part_number,
upload_id=self.mpu.upload_id,
content=content,
),
max_retry_num=self.fs.max_retry_num,
)

os.remove(staging_file)
return PartInfo(
part_number=part_number,
etag=out.etag,
part_size=len(content),
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)

def _fetch_range(self, start: int, end: int) -> bytes:
if start == end:
logger.debug(
@@ -2231,9 +2137,9 @@ def commit(self) -> None:
if self.tell() == 0:
if self.buffer is not None:
logger.debug("Empty file committed %s", self)
self._abort_mpu()
self.multipart_uploader.abort_upload()
self.fs.touch(self.path, **self.kwargs)
elif not self.staging_files:
elif not self.multipart_uploader.staging_files:
if self.buffer is not None:
logger.debug("One-shot upload of %s", self)
self.buffer.seek(0)
@@ -2247,17 +2153,9 @@
else:
raise RuntimeError
else:
self._upload_staged_files()
logger.debug("Complete multi-part upload for %s ", self)
write_result = retryable_func_executor(
lambda: self.fs.tos_client.complete_multipart_upload(
self.bucket,
self.key,
upload_id=self.mpu.upload_id,
parts=self.parts,
),
max_retry_num=self.fs.max_retry_num,
)
self.multipart_uploader._upload_staged_files()
write_result = self.multipart_uploader.complete_upload()

if self.fs.version_aware:
self.version_id = write_result.version_id
@@ -2266,15 +2164,5 @@

def discard(self) -> None:
"""Close the file without writing."""
self._abort_mpu()
self.buffer = None # file becomes unusable

def _abort_mpu(self) -> None:
if self.mpu:
retryable_func_executor(
lambda: self.fs.tos_client.abort_multipart_upload(
self.bucket, self.key, self.mpu.upload_id
),
max_retry_num=self.fs.max_retry_num,
)
self.mpu = None
self.multipart_uploader.abort_upload()
self.buffer = None
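
For orientation, the net effect of the core.py changes is that TosFile now delegates all multipart bookkeeping to MultipartUploader. Roughly, the call sequence looks like the following sketch; it is not code from the commit, `fs`, `bucket`, `key`, and `file_buffer` are placeholders for values the file object already holds, and error handling plus the one-shot PUT path are omitted.

# Sketch of the delegation flow after this commit (placeholders: fs, bucket,
# key, file_buffer).
from tosfs.mpu import MultipartUploader

uploader = MultipartUploader(
    fs=fs,
    bucket=bucket,
    key=key,
    part_size=fs.multipart_size,
    thread_pool_size=fs.multipart_thread_pool_size,
    staging_buffer_size=fs.multipart_staging_buffer_size,
    multipart_threshold=fs.multipart_threshold,
)

# _upload_chunk: start the MPU and stage the buffered bytes to disk
uploader.initiate_upload()
uploader.upload_multiple_chunks(file_buffer)

# commit: push the staged parts in parallel, then complete the MPU
uploader._upload_staged_files()
write_result = uploader.complete_upload()

# discard (or an empty file at commit time): abort the MPU instead
# uploader.abort_upload()
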
157 changes: 157 additions & 0 deletions tosfs/mpu.py
@@ -0,0 +1,157 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The module contains the MultipartUploader class for the tosfs package."""

import io
import itertools
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from tos.models2 import (
    CompleteMultipartUploadOutput,
    CreateMultipartUploadOutput,
    PartInfo,
)

from tosfs.retry import retryable_func_executor

if TYPE_CHECKING:
    from tosfs.core import TosFileSystem


class MultipartUploader:
    """A class to upload large files to the object store using multipart upload."""

    def __init__(
        self,
        fs: "TosFileSystem",
        bucket: str,
        key: str,
        part_size: int,
        thread_pool_size: int,
        staging_buffer_size: int,
        multipart_threshold: int,
    ):
        """Instantiate a MultipartUploader object."""
        self.fs = fs
        self.bucket = bucket
        self.key = key
        self.part_size = part_size
        self.thread_pool_size = thread_pool_size
        self.staging_buffer_size = staging_buffer_size
        self.multipart_threshold = multipart_threshold
        self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
        self.staging_files: list[str] = []
        self.staging_buffer: io.BytesIO = io.BytesIO()
        self.parts: list = []
        self.mpu: Optional[CreateMultipartUploadOutput] = None

    def initiate_upload(self) -> None:
        """Initiate the multipart upload."""
        self.mpu = retryable_func_executor(
            lambda: self.fs.tos_client.create_multipart_upload(self.bucket, self.key),
            max_retry_num=self.fs.max_retry_num,
        )

    def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
        """Upload multiple chunks of data to the object store."""
        if buffer:
            buffer.seek(0)
            while True:
                chunk = buffer.read(self.part_size)
                if not chunk:
                    break
                self._write_to_staging_buffer(chunk)

    def _write_to_staging_buffer(self, chunk: bytes) -> None:
        self.staging_buffer.write(chunk)
        if self.staging_buffer.tell() >= self.staging_buffer_size:
            self._flush_staging_buffer()

    def _flush_staging_buffer(self) -> None:
        if self.staging_buffer.tell() == 0:
            return

        self.staging_buffer.seek(0)
        staging_dir = next(self.staging_dirs)
        with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
            tmp.write(self.staging_buffer.read())
            self.staging_files.append(tmp.name)

        self.staging_buffer = io.BytesIO()

    def _upload_staged_files(self) -> None:
        self._flush_staging_buffer()
        futures = []
        for i, staging_file in enumerate(self.staging_files):
            part_number = i + 1
            futures.append(
                self.executor.submit(
                    self._upload_part_from_file, staging_file, part_number
                )
            )

        for future in futures:
            part_info = future.result()
            self.parts.append(part_info)

        self.staging_files = []

    def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
        with open(staging_file, "rb") as f:
            content = f.read()

        out = retryable_func_executor(
            lambda: self.fs.tos_client.upload_part(
                bucket=self.bucket,
                key=self.key,
                part_number=part_number,
                upload_id=self.mpu.upload_id,
                content=content,
            ),
            max_retry_num=self.fs.max_retry_num,
        )

        os.remove(staging_file)
        return PartInfo(
            part_number=part_number,
            etag=out.etag,
            part_size=len(content),
            offset=None,
            hash_crc64_ecma=None,
            is_completed=None,
        )

    def complete_upload(self) -> CompleteMultipartUploadOutput:
        """Complete the multipart upload and return the service response."""
        return retryable_func_executor(
            lambda: self.fs.tos_client.complete_multipart_upload(
                self.bucket,
                self.key,
                upload_id=self.mpu.upload_id,
                parts=self.parts,
            ),
            max_retry_num=self.fs.max_retry_num,
        )

    def abort_upload(self) -> None:
        """Abort the multipart upload."""
        if self.mpu:
            retryable_func_executor(
                lambda: self.fs.tos_client.abort_multipart_upload(
                    self.bucket, self.key, self.mpu.upload_id
                ),
                max_retry_num=self.fs.max_retry_num,
            )
            self.mpu = None
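
The piece worth calling out is the staging mechanism: upload_multiple_chunks slices the in-memory buffer into part_size chunks, accumulates them in staging_buffer, and spills them to a temporary file in a round-robin staging directory whenever staging_buffer_size is reached; _upload_staged_files later uploads those files as parts on the thread pool. The snippet below is a minimal, self-contained sketch of just that spill pattern (it uses a local temp directory in place of fs.multipart_staging_dirs and involves no TOS client).

import io
import itertools
import tempfile

def spill_chunks(chunks, staging_dirs, staging_buffer_size):
    """Accumulate chunks in memory and spill to disk once past the threshold."""
    dirs = itertools.cycle(staging_dirs)
    buf = io.BytesIO()
    staged = []
    for chunk in chunks:
        buf.write(chunk)
        if buf.tell() >= staging_buffer_size:
            buf.seek(0)
            with tempfile.NamedTemporaryFile(delete=False, dir=next(dirs)) as tmp:
                tmp.write(buf.read())
                staged.append(tmp.name)
            buf = io.BytesIO()
    # data below the threshold stays in memory until a final flush
    return staged, buf

# Example: two 5 MiB chunks with a 4 MiB staging buffer yield two staged files.
staged_files, remainder = spill_chunks(
    [b"x" * (5 * 1024 * 1024)] * 2,
    [tempfile.gettempdir()],
    4 * 1024 * 1024,
)
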
