Optimize: Introduce multiple disk write for MPU
yanghua committed Sep 20, 2024
1 parent 93bd8b1 commit 188c8d8
Showing 2 changed files with 146 additions and 95 deletions.
216 changes: 130 additions & 86 deletions tosfs/core.py
@@ -14,10 +14,13 @@

"""The core module of TOSFS."""
import io
import itertools
import logging
import mimetypes
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from glob import has_magic
from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union

@@ -35,7 +38,6 @@
ListObjectVersionsOutput,
PartInfo,
UploadPartCopyOutput,
UploadPartOutput,
)

from tosfs.consts import (
@@ -113,6 +115,11 @@ def __init__(
default_block_size: Optional[int] = None,
default_fill_cache: bool = True,
default_cache_type: str = "readahead",
multipart_staging_dirs: str = tempfile.mkdtemp(),
multipart_size: int = 8 << 20,
multipart_thread_pool_size: int = max(2, os.cpu_count() or 1),
multipart_staging_buffer_size: int = 4 << 10,
multipart_threshold: int = 10 << 20,
**kwargs: Any,
) -> None:
"""Initialise the TosFileSystem.
@@ -157,6 +164,26 @@ def __init__(
Whether to fill the cache (default is True).
default_cache_type : str, optional
The default cache type (default is 'readahead').
multipart_staging_dirs : str, optional
The staging directories for multipart uploads (default is a temporary
directory). Separate multiple staging directory paths with commas.
multipart_size : int, optional
The part size used for multipart uploads to the given object storage
(default is 8 MB).
multipart_thread_pool_size : int, optional
The size of the thread pool used to upload parts in parallel to the
given object storage (default is max(2, os.cpu_count())).
multipart_staging_buffer_size : int, optional
The maximum number of bytes of staging data to buffer in memory before
flushing to a staging file. It dramatically reduces random writes to the
local staging disk when writing many small files (default is 4096).
multipart_threshold : int, optional
The threshold that controls whether multipart upload is enabled when
writing data to the given object storage. If the written data size is
below the threshold, the data is uploaded with a simple put instead of
a multipart upload (default is 10 MB).
kwargs : Any, optional
Additional arguments.
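
For context, a minimal usage sketch of the new options (the staging paths and pool size below are illustrative placeholders; the connection-related keyword arguments are assumed to be supplied separately and are not part of this commit):

from tosfs.core import TosFileSystem

# Connection settings (endpoint, credentials, ...) are assumed to come from
# elsewhere; only the options introduced by this commit are spelled out.
conn_kwargs = {}  # fill in with your TOS connection settings

fs = TosFileSystem(
    multipart_staging_dirs="/data/ssd1/tos-staging,/data/ssd2/tos-staging",
    multipart_size=8 << 20,                 # 8 MB parts
    multipart_thread_pool_size=8,           # up to 8 parts uploaded concurrently
    multipart_staging_buffer_size=4 << 10,  # spill to a staging file every 4 KB
    multipart_threshold=10 << 20,           # below 10 MB, fall back to a single PUT
    **conn_kwargs,
)

With two staging directories, staged part files alternate between them via itertools.cycle, spreading local disk I/O across devices.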
@@ -184,6 +211,14 @@ def __init__(
self.default_cache_type = default_cache_type
self.max_retry_num = max_retry_num

self.multipart_staging_dirs = [
d.strip() for d in multipart_staging_dirs.split(",")
]
self.multipart_size = multipart_size
self.multipart_thread_pool_size = multipart_thread_pool_size
self.multipart_staging_buffer_size = multipart_staging_buffer_size
self.multipart_threshold = multipart_threshold

super().__init__(**kwargs)

def _open(
@@ -1998,10 +2033,19 @@ def __init__(
self.mode = mode
self.autocommit = autocommit
self.mpu: CreateMultipartUploadOutput = None
self.parts: Optional[list] = None
self.parts: list = []
self.append_block = False
self.buffer: Optional[io.BytesIO] = io.BytesIO()

self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
self.part_size = fs.multipart_size
self.thread_pool_size = fs.multipart_thread_pool_size
self.staging_buffer_size = fs.multipart_staging_buffer_size
self.multipart_threshold = fs.multipart_threshold
self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
self.staging_files: list[str] = []
self.staging_buffer: io.BytesIO = io.BytesIO()

if "a" in mode and fs.exists(path):
head = retryable_func_executor(
lambda: self.fs.tos_client.head_object(bucket, key),
@@ -2022,27 +2066,6 @@ def _initiate_upload(self) -> None:
# only happens when closing small file, use one-shot PUT
return
logger.debug("Initiate upload for %s", self)
self.parts = []

self.mpu = retryable_func_executor(
lambda: self.fs.tos_client.create_multipart_upload(self.bucket, self.key),
max_retry_num=self.fs.max_retry_num,
)

if self.append_block:
# use existing data in key when appending,
# and block is big enough
out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part_copy(
bucket=self.bucket,
key=self.key,
part_number=1,
upload_id=self.mpu.upload_id,
),
max_retry_num=self.fs.max_retry_num,
)

self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})

def _upload_chunk(self, final: bool = False) -> bool:
"""Write one part of a multi-block file upload.
@@ -2068,11 +2091,35 @@ def _upload_chunk(self, final: bool = False) -> bool:
self.autocommit
and not self.append_block
and final
and self.tell() < self.blocksize
and self.tell() < max(self.blocksize, self.multipart_threshold)
):
# only happens when closing small file, use one-shot PUT
pass
else:
self.parts = []

self.mpu = retryable_func_executor(
lambda: self.fs.tos_client.create_multipart_upload(
self.bucket, self.key
),
max_retry_num=self.fs.max_retry_num,
)

if self.append_block:
# use existing data in key when appending,
# and block is big enough
out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part_copy(
bucket=self.bucket,
key=self.key,
part_number=1,
upload_id=self.mpu.upload_id,
),
max_retry_num=self.fs.max_retry_num,
)

self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})

self._upload_multiple_chunks(bucket, key)

if self.autocommit and final:
@@ -2083,75 +2130,71 @@ def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
if self.buffer:
self.buffer.seek(0)
current_chunk: Optional[bytes] = self.buffer.read(self.blocksize)
while True:
chunk = self.buffer.read(self.part_size)
if not chunk:
break

while current_chunk:
(previous_chunk, current_chunk) = (
current_chunk,
self.buffer.read(self.blocksize) if self.buffer else None,
)
current_chunk_size = len(current_chunk if current_chunk else b"")

# Define a helper function to handle the remainder logic
def handle_remainder(
previous_chunk: bytes,
current_chunk: Optional[bytes],
blocksize: int,
part_max: int,
) -> Tuple[bytes, Optional[bytes]]:
if current_chunk:
remainder = previous_chunk + current_chunk
else:
remainder = previous_chunk
self._write_to_staging_buffer(chunk)

remainder_size = (
blocksize + len(current_chunk) if current_chunk else blocksize
)
def _write_to_staging_buffer(self, chunk: bytes) -> None:
self.staging_buffer.write(chunk)
if self.staging_buffer.tell() >= self.staging_buffer_size:
self._flush_staging_buffer()

if remainder_size <= part_max:
return remainder, None
else:
partition = remainder_size // 2
return remainder[:partition], remainder[partition:]
def _flush_staging_buffer(self) -> None:
if self.staging_buffer.tell() == 0:
return

# Use the helper function in the main code
if 0 < current_chunk_size < self.blocksize:
previous_chunk, current_chunk = handle_remainder(
previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE
)
self.staging_buffer.seek(0)
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)

self.staging_buffer = io.BytesIO()

def _upload_staged_files(self) -> None:
self._flush_staging_buffer()
futures = []
for i, staging_file in enumerate(self.staging_files):
part_number = i + 1
futures.append(
self.executor.submit(
self._upload_part_from_file, staging_file, part_number
)
)

part = len(self.parts) + 1 if self.parts is not None else 1
logger.debug("Upload chunk %s, %s", self, part)
for future in futures:
part_info = future.result()
self.parts.append(part_info)

def _call_upload_part(
part: int = part, previous_chunk: Optional[bytes] = previous_chunk
) -> UploadPartOutput:
return self.fs.tos_client.upload_part(
bucket=bucket,
key=key,
part_number=part,
upload_id=self.mpu.upload_id,
content=previous_chunk,
)
self.staging_files = []

out = retryable_func_executor(
_call_upload_part, max_retry_num=self.fs.max_retry_num
)
def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
with open(staging_file, "rb") as f:
content = f.read()

(
self.parts.append(
PartInfo(
part_number=part,
etag=out.etag,
part_size=len(previous_chunk),
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)
)
if self.parts is not None
else None
)
out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part(
bucket=self.bucket,
key=self.key,
part_number=part_number,
upload_id=self.mpu.upload_id,
content=content,
),
max_retry_num=self.fs.max_retry_num,
)

os.remove(staging_file)
return PartInfo(
part_number=part_number,
etag=out.etag,
part_size=len(content),
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)

def _fetch_range(self, start: int, end: int) -> bytes:
if start == end:
@@ -2184,7 +2227,7 @@ def commit(self) -> None:
logger.debug("Empty file committed %s", self)
self._abort_mpu()
self.fs.touch(self.path, **self.kwargs)
elif not self.parts:
elif not self.staging_files:
if self.buffer is not None:
logger.debug("One-shot upload of %s", self)
self.buffer.seek(0)
@@ -2198,6 +2241,7 @@ def commit(self) -> None:
else:
raise RuntimeError
else:
self._upload_staged_files()
logger.debug("Complete multi-part upload for %s ", self)
write_result = retryable_func_executor(
lambda: self.fs.tos_client.complete_multipart_upload(
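Taken together, the new write path in core.py stages data on local disk before uploading: bytes are buffered in memory up to multipart_staging_buffer_size, spilled into temporary files in the staging directories (chosen round-robin), and on commit each staged file is uploaded as one part through the thread pool. The following standalone sketch mirrors that flow without the TOS client; upload_part below is a placeholder, not the real SDK call:

import io
import itertools
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

STAGING_BUFFER_SIZE = 4 << 10  # mirrors multipart_staging_buffer_size
staging_dirs = itertools.cycle([tempfile.mkdtemp(), tempfile.mkdtemp()])
staging_files: list = []
buffer = io.BytesIO()

def write_chunk(chunk: bytes) -> None:
    # Buffer small writes in memory, spilling to a staging file when full.
    buffer.write(chunk)
    if buffer.tell() >= STAGING_BUFFER_SIZE:
        flush_staging_buffer()

def flush_staging_buffer() -> None:
    global buffer
    if buffer.tell() == 0:
        return
    buffer.seek(0)
    with tempfile.NamedTemporaryFile(delete=False, dir=next(staging_dirs)) as tmp:
        tmp.write(buffer.read())
        staging_files.append(tmp.name)
    buffer = io.BytesIO()

def upload_part(path: str, part_number: int) -> dict:
    # Placeholder for tos_client.upload_part: read a staged file, "upload" it,
    # then delete the local copy, as _upload_part_from_file does.
    with open(path, "rb") as f:
        content = f.read()
    os.remove(path)
    return {"PartNumber": part_number, "Size": len(content)}

# Simulate a stream of small writes, then "commit".
for _ in range(100):
    write_chunk(b"x" * 512)
flush_staging_buffer()

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(upload_part, path, i + 1)
        for i, path in enumerate(staging_files)
    ]
    parts = [f.result() for f in futures]
print(f"{len(parts)} parts staged and uploaded")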
25 changes: 16 additions & 9 deletions tosfs/stability.py
@@ -32,8 +32,10 @@

from tosfs.exceptions import TosfsError

CONFLICT_CODE = "409"

TOS_SERVER_RETRYABLE_STATUS_CODES = {
"409", # CONFLICT
CONFLICT_CODE, # CONFLICT
"429", # TOO_MANY_REQUESTS
"500", # INTERNAL_SERVER_ERROR
}
@@ -93,13 +95,17 @@ def retryable_func_executor(
raise e

if is_retryable_exception(e):
logger.warn("Retry TOS request in the %d times, error: %s", attempt, e)
logger.warning(
"Retry TOS request in the %d times, error: %s", attempt, e
)
try:
time.sleep(min(1.7**attempt * 0.1, 15))
except InterruptedError as ie:
raise TosfsError(f"Request {func} interrupted.") from ie
else:
raise e
# Note: not all retryable exceptions may be wrapped by `TosError`.
# We will pay attention to those cases.
except Exception as e:
raise TosfsError(f"{e}") from e

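The retry backoff grows geometrically and is capped at 15 seconds. A quick worked example of the schedule (the attempt numbering here is illustrative; the loop variable's exact range is not shown in this diff):

# Sleep durations between retries: min(1.7 ** attempt * 0.1, 15)
for attempt in range(1, 12):
    print(attempt, round(min(1.7 ** attempt * 0.1, 15), 2))
# attempt 1 -> 0.17 s, attempt 5 -> ~1.42 s, attempt 9 -> ~11.86 s,
# attempt 10 and later -> capped at 15 s
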
@@ -112,13 +118,14 @@ def is_retryable_exception(e: TosError) -> bool:


def _is_retryable_tos_server_exception(e: TosError) -> bool:
return (
isinstance(e, TosServerError)
and e.status_code in TOS_SERVER_RETRYABLE_STATUS_CODES
# exclude some special error code under 409(conflict) status code
# let it fast fail
and e.code not in TOS_SERVER_NOT_RETRYABLE_CONFLICT_ERROR_CODES
)
if not isinstance(e, TosServerError):
return False

# not all conflict errors are retryable
if e.status_code == CONFLICT_CODE:
return e.code not in TOS_SERVER_NOT_RETRYABLE_CONFLICT_ERROR_CODES

return e.status_code in TOS_SERVER_RETRYABLE_STATUS_CODES


def _is_retryable_tos_client_exception(e: TosError) -> bool:
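The refactored server-side check now handles 409 explicitly: a conflict retries unless its error code is listed as non-retryable, while every other status must appear in the retryable set. A standalone decision sketch (not the library function; "SomeFatalConflict" is a made-up placeholder for the codes actually listed in TOS_SERVER_NOT_RETRYABLE_CONFLICT_ERROR_CODES):

# Illustration only: mirrors the ordering of checks in
# _is_retryable_tos_server_exception with dummy data.
RETRYABLE_STATUS = {"409", "429", "500"}
NON_RETRYABLE_CONFLICTS = {"SomeFatalConflict"}  # placeholder contents

def is_retryable(status_code: str, code: str) -> bool:
    if status_code == "409":
        # Conflicts retry by default; a few error codes fail fast.
        return code not in NON_RETRYABLE_CONFLICTS
    return status_code in RETRYABLE_STATUS

print(is_retryable("409", "SomeFatalConflict"))  # False: fail fast
print(is_retryable("409", "OtherConflict"))      # True: retryable conflict
print(is_retryable("429", ""))                   # True
print(is_retryable("503", ""))                   # False: not in the retryable set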
