From f24ff078626241dc8567443a8433b8d0a0867075 Mon Sep 17 00:00:00 2001
From: yanghua
Date: Sat, 21 Sep 2024 13:23:31 +0800
Subject: [PATCH] Extract mpu and staging part logic into mpu.py

---
 tosfs/core.py | 152 +++++++-----------------------------------------
 tosfs/mpu.py  | 157 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 177 insertions(+), 132 deletions(-)
 create mode 100644 tosfs/mpu.py

diff --git a/tosfs/core.py b/tosfs/core.py
index a7ced85..3341c33 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -14,13 +14,11 @@
 """The core module of TOSFS."""
 
 import io
-import itertools
 import logging
 import mimetypes
 import os
 import tempfile
 import time
-from concurrent.futures import ThreadPoolExecutor
 from glob import has_magic
 from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union
 
@@ -31,7 +29,6 @@
 from tos.exceptions import TosClientError, TosServerError
 from tos.models import CommonPrefixInfo
 from tos.models2 import (
-    CreateMultipartUploadOutput,
     ListedObject,
     ListedObjectVersion,
     ListObjectType2Output,
@@ -57,6 +54,7 @@
 )
 from tosfs.exceptions import TosfsError
 from tosfs.fsspec_utils import glob_translate
+from tosfs.mpu import MultipartUploader
 from tosfs.stability import retryable_func_executor
 from tosfs.utils import find_bucket_key, get_brange
 
@@ -2032,19 +2030,18 @@ def __init__(
         self.path = path
         self.mode = mode
         self.autocommit = autocommit
-        self.mpu: CreateMultipartUploadOutput = None
-        self.parts: list = []
         self.append_block = False
         self.buffer: Optional[io.BytesIO] = io.BytesIO()
-        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
-        self.part_size = fs.multipart_size
-        self.thread_pool_size = fs.multipart_thread_pool_size
-        self.staging_buffer_size = fs.multipart_staging_buffer_size
-        self.multipart_threshold = fs.multipart_threshold
-        self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
-        self.staging_files: list[str] = []
-        self.staging_buffer: io.BytesIO = io.BytesIO()
+        self.multipart_uploader = MultipartUploader(
+            fs=fs,
+            bucket=bucket,
+            key=key,
+            part_size=fs.multipart_size,
+            thread_pool_size=fs.multipart_thread_pool_size,
+            staging_buffer_size=fs.multipart_staging_buffer_size,
+            multipart_threshold=fs.multipart_threshold,
+        )
 
         if "a" in mode and fs.exists(path):
             head = retryable_func_executor(
@@ -2091,111 +2088,20 @@ def _upload_chunk(self, final: bool = False) -> bool:
             self.autocommit
             and not self.append_block
             and final
-            and self.tell() < max(self.blocksize, self.multipart_threshold)
+            and self.tell()
+            < max(self.blocksize, self.multipart_uploader.multipart_threshold)
         ):
             # only happens when closing small file, use one-shot PUT
             pass
         else:
-            self.parts = []
-
-            self.mpu = retryable_func_executor(
-                lambda: self.fs.tos_client.create_multipart_upload(
-                    self.bucket, self.key
-                ),
-                max_retry_num=self.fs.max_retry_num,
-            )
-
-            if self.append_block:
-                # use existing data in key when appending,
-                # and block is big enough
-                out = retryable_func_executor(
-                    lambda: self.fs.tos_client.upload_part_copy(
-                        bucket=self.bucket,
-                        key=self.key,
-                        part_number=1,
-                        upload_id=self.mpu.upload_id,
-                    ),
-                    max_retry_num=self.fs.max_retry_num,
-                )
-
-                self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})
-
-            self._upload_multiple_chunks(bucket, key)
+            self.multipart_uploader.initiate_upload()
+            self.multipart_uploader.upload_multiple_chunks(self.buffer)
 
         if self.autocommit and final:
             self.commit()
 
         return not final
 
-    def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
-        if self.buffer:
-            self.buffer.seek(0)
-            while True:
-                chunk = self.buffer.read(self.part_size)
-                if not chunk:
-                    break
-
-                self._write_to_staging_buffer(chunk)
-
-    def _write_to_staging_buffer(self, chunk: bytes) -> None:
-        self.staging_buffer.write(chunk)
-        if self.staging_buffer.tell() >= self.staging_buffer_size:
-            self._flush_staging_buffer()
-
-    def _flush_staging_buffer(self) -> None:
-        if self.staging_buffer.tell() == 0:
-            return
-
-        self.staging_buffer.seek(0)
-        staging_dir = next(self.staging_dirs)
-        with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
-            tmp.write(self.staging_buffer.read())
-            self.staging_files.append(tmp.name)
-
-        self.staging_buffer = io.BytesIO()
-
-    def _upload_staged_files(self) -> None:
-        self._flush_staging_buffer()
-        futures = []
-        for i, staging_file in enumerate(self.staging_files):
-            part_number = i + 1
-            futures.append(
-                self.executor.submit(
-                    self._upload_part_from_file, staging_file, part_number
-                )
-            )
-
-        for future in futures:
-            part_info = future.result()
-            self.parts.append(part_info)
-
-        self.staging_files = []
-
-    def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
-        with open(staging_file, "rb") as f:
-            content = f.read()
-
-        out = retryable_func_executor(
-            lambda: self.fs.tos_client.upload_part(
-                bucket=self.bucket,
-                key=self.key,
-                part_number=part_number,
-                upload_id=self.mpu.upload_id,
-                content=content,
-            ),
-            max_retry_num=self.fs.max_retry_num,
-        )
-
-        os.remove(staging_file)
-        return PartInfo(
-            part_number=part_number,
-            etag=out.etag,
-            part_size=len(content),
-            offset=None,
-            hash_crc64_ecma=None,
-            is_completed=None,
-        )
-
     def _fetch_range(self, start: int, end: int) -> bytes:
         if start == end:
             logger.debug(
@@ -2225,9 +2131,9 @@ def commit(self) -> None:
         if self.tell() == 0:
             if self.buffer is not None:
                 logger.debug("Empty file committed %s", self)
-                self._abort_mpu()
+                self.multipart_uploader.abort_upload()
                 self.fs.touch(self.path, **self.kwargs)
-        elif not self.staging_files:
+        elif not self.multipart_uploader.staging_files:
             if self.buffer is not None:
                 logger.debug("One-shot upload of %s", self)
                 self.buffer.seek(0)
@@ -2241,17 +2147,9 @@ def commit(self) -> None:
             else:
                 raise RuntimeError
         else:
-            self._upload_staged_files()
             logger.debug("Complete multi-part upload for %s ", self)
-            write_result = retryable_func_executor(
-                lambda: self.fs.tos_client.complete_multipart_upload(
-                    self.bucket,
-                    self.key,
-                    upload_id=self.mpu.upload_id,
-                    parts=self.parts,
-                ),
-                max_retry_num=self.fs.max_retry_num,
-            )
+            self.multipart_uploader._upload_staged_files()
+            write_result = self.multipart_uploader.complete_upload()
 
         if self.fs.version_aware:
             self.version_id = write_result.version_id
@@ -2260,15 +2158,5 @@
 
     def discard(self) -> None:
         """Close the file without writing."""
-        self._abort_mpu()
-        self.buffer = None  # file becomes unusable
-
-    def _abort_mpu(self) -> None:
-        if self.mpu:
-            retryable_func_executor(
-                lambda: self.fs.tos_client.abort_multipart_upload(
-                    self.bucket, self.key, self.mpu.upload_id
-                ),
-                max_retry_num=self.fs.max_retry_num,
-            )
-            self.mpu = None
+        self.multipart_uploader.abort_upload()
+        self.buffer = None
diff --git a/tosfs/mpu.py b/tosfs/mpu.py
new file mode 100644
index 0000000..f574fa6
--- /dev/null
+++ b/tosfs/mpu.py
@@ -0,0 +1,157 @@
+# ByteDance Volcengine EMR, Copyright 2024.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The module contains the MultipartUploader class for the tosfs package."""
+
+import io
+import itertools
+import os
+import tempfile
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, Optional
+
+from tos.models2 import CreateMultipartUploadOutput, PartInfo
+
+from tosfs.stability import retryable_func_executor
+
+if TYPE_CHECKING:
+    from tosfs.core import TosFileSystem
+
+
+class MultipartUploader:
+    """A class to upload large files to the object store using multipart upload."""
+
+    def __init__(
+        self,
+        fs: "TosFileSystem",
+        bucket: str,
+        key: str,
+        part_size: int,
+        thread_pool_size: int,
+        staging_buffer_size: int,
+        multipart_threshold: int,
+    ):
+        """Instantiate a MultipartUploader object."""
+        self.fs = fs
+        self.bucket = bucket
+        self.key = key
+        self.part_size = part_size
+        self.thread_pool_size = thread_pool_size
+        self.staging_buffer_size = staging_buffer_size
+        self.multipart_threshold = multipart_threshold
+        self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
+        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
+        self.staging_files: list[str] = []
+        self.staging_buffer: io.BytesIO = io.BytesIO()
+        self.parts: list = []
+        self.mpu: Optional[CreateMultipartUploadOutput] = None
+
+    def initiate_upload(self) -> None:
+        """Initiate the multipart upload."""
+        self.mpu = retryable_func_executor(
+            lambda: self.fs.tos_client.create_multipart_upload(self.bucket, self.key),
+            max_retry_num=self.fs.max_retry_num,
+        )
+
+    def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
+        """Upload multiple chunks of data to the object store."""
+        if buffer:
+            buffer.seek(0)
+            while True:
+                chunk = buffer.read(self.part_size)
+                if not chunk:
+                    break
+                self._write_to_staging_buffer(chunk)
+
+    def _write_to_staging_buffer(self, chunk: bytes) -> None:
+        self.staging_buffer.write(chunk)
+        if self.staging_buffer.tell() >= self.staging_buffer_size:
+            self._flush_staging_buffer()
+
+    def _flush_staging_buffer(self) -> None:
+        if self.staging_buffer.tell() == 0:
+            return
+
+        self.staging_buffer.seek(0)
+        staging_dir = next(self.staging_dirs)
+        with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+            tmp.write(self.staging_buffer.read())
+            self.staging_files.append(tmp.name)
+
+        self.staging_buffer = io.BytesIO()
+
+    def _upload_staged_files(self) -> None:
+        self._flush_staging_buffer()
+        futures = []
+        for i, staging_file in enumerate(self.staging_files):
+            part_number = i + 1
+            futures.append(
+                self.executor.submit(
+                    self._upload_part_from_file, staging_file, part_number
+                )
+            )
+
+        for future in futures:
+            part_info = future.result()
+            self.parts.append(part_info)
+
+        self.staging_files = []
+
+    def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
+        with open(staging_file, "rb") as f:
+            content = f.read()
+
+        out = retryable_func_executor(
+            lambda: self.fs.tos_client.upload_part(
+                bucket=self.bucket,
+                key=self.key,
+                part_number=part_number,
+                upload_id=self.mpu.upload_id,
+                content=content,
+            ),
+            max_retry_num=self.fs.max_retry_num,
+        )
+
+        os.remove(staging_file)
+        return PartInfo(
+            part_number=part_number,
+            etag=out.etag,
+            part_size=len(content),
+            offset=None,
+            hash_crc64_ecma=None,
+            is_completed=None,
+        )
+
+    def complete_upload(self) -> Any:
+        """Complete the multipart upload and return the result."""
+        return retryable_func_executor(
+            lambda: self.fs.tos_client.complete_multipart_upload(
+                self.bucket,
+                self.key,
+                upload_id=self.mpu.upload_id,
+                parts=self.parts,
+            ),
+            max_retry_num=self.fs.max_retry_num,
+        )
+
+    def abort_upload(self) -> None:
+        """Abort the multipart upload."""
+        if self.mpu:
+            retryable_func_executor(
+                lambda: self.fs.tos_client.abort_multipart_upload(
+                    self.bucket, self.key, self.mpu.upload_id
+                ),
+                max_retry_num=self.fs.max_retry_num,
+            )
+            self.mpu = None
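
Usage sketch (illustrative only, not part of the patch): the snippet below drives the extracted MultipartUploader directly, in the same initiate_upload() -> upload_multiple_chunks() -> _upload_staged_files() -> complete_upload() order that TosFileObject now follows. The TosFileSystem constructor arguments, endpoint, credential environment variables, bucket and key are assumed placeholders, not values taken from this repository.

# Illustrative sketch only -- not part of the patch. It exercises MultipartUploader
# in the same order TosFileObject._upload_chunk()/commit() does. The TosFileSystem
# constructor arguments, endpoint, credentials, bucket and key are assumptions.
import io
import os

from tosfs.core import TosFileSystem
from tosfs.mpu import MultipartUploader

fs = TosFileSystem(
    endpoint_url="tos-cn-beijing.volces.com",    # placeholder endpoint
    region="cn-beijing",                         # placeholder region
    key=os.environ.get("TOS_ACCESS_KEY", ""),    # placeholder credential env vars
    secret=os.environ.get("TOS_SECRET_KEY", ""),
)

uploader = MultipartUploader(
    fs=fs,
    bucket="example-bucket",       # placeholder bucket
    key="tmp/large-object.bin",    # placeholder key
    part_size=fs.multipart_size,
    thread_pool_size=fs.multipart_thread_pool_size,
    staging_buffer_size=fs.multipart_staging_buffer_size,
    multipart_threshold=fs.multipart_threshold,
)

# 16 MiB of random data, intended to span several parts at the default part size.
payload = io.BytesIO(os.urandom(16 * 1024 * 1024))

uploader.initiate_upload()                # create_multipart_upload on the TOS client
uploader.upload_multiple_chunks(payload)  # slice into part_size chunks, stage to local files
uploader._upload_staged_files()           # upload staged parts through the thread pool
result = uploader.complete_upload()       # complete_multipart_upload with the collected parts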