diff --git a/tosfs/core.py b/tosfs/core.py index 6f58a41..7f9621e 100644 --- a/tosfs/core.py +++ b/tosfs/core.py @@ -13,6 +13,7 @@ # limitations under the License. """The core module of TOSFS.""" +import io import logging import mimetypes import os @@ -21,9 +22,16 @@ import tos from fsspec import AbstractFileSystem +from fsspec.spec import AbstractBufferedFile from fsspec.utils import setup_logging as setup_logger from tos.models import CommonPrefixInfo -from tos.models2 import ListedObject, ListedObjectVersion +from tos.models2 import ( + CreateMultipartUploadOutput, + ListedObject, + ListedObjectVersion, + PartInfo, + UploadPartOutput, +) from tosfs.consts import TOS_SERVER_RESPONSE_CODE_NOT_FOUND from tosfs.exceptions import TosfsError @@ -58,6 +66,7 @@ class TosFileSystem(AbstractFileSystem): """ retries = 5 + default_block_size = 5 * 2**20 def __init__( self, @@ -67,7 +76,10 @@ def __init__( region: Optional[str] = None, version_aware: bool = False, credentials_provider: Optional[object] = None, - **kwargs: Union[str, bool, float, None], + default_block_size: Optional[int] = None, + default_fill_cache: bool = True, + default_cache_type: str = "readahead", + **kwargs: Any, ) -> None: """Initialise the TosFileSystem.""" self.tos_client = tos.TosClientV2( @@ -78,8 +90,77 @@ def __init__( credentials_provider=credentials_provider, ) self.version_aware = version_aware + self.default_block_size = default_block_size or self.default_block_size + self.default_fill_cache = default_fill_cache + self.default_cache_type = default_cache_type + super().__init__(**kwargs) + def _open( + self, + path: str, + mode: str = "rb", + block_size: Optional[int] = None, + version_id: Optional[str] = None, + fill_cache: Optional[bool] = None, + cache_type: Optional[str] = None, + autocommit: bool = True, + **kwargs: Any, + ) -> AbstractBufferedFile: + """Open a file for reading or writing. 
+ + Parameters + ---------- + path: string + Path of file on TOS + mode: string + One of 'r', 'w', 'a', 'rb', 'wb', or 'ab'. These have the same meaning + as they do for the built-in `open` function. + block_size: int + Size of data-node blocks if reading + fill_cache: bool + If seeking to new a part of the file beyond the current buffer, + with this True, the buffer will be filled between the sections to + best support random access. When reading only a few specific chunks + out of a file, performance may be better if False. + version_id : str + Explicit version of the object to open. This requires that the s3 + filesystem is version aware and bucket versioning is enabled on the + relevant bucket. + cache_type : str + See fsspec's documentation for available cache_type values. Set to "none" + if no caching is desired. If None, defaults to ``self.default_cache_type``. + autocommit : bool + If True, writes will be committed to the filesystem on flush or close. + kwargs: dict-like + Additional parameters. 
+ + """ + if block_size is None: + block_size = self.default_block_size + if fill_cache is None: + fill_cache = self.default_fill_cache + + if not self.version_aware and version_id: + raise ValueError( + "version_id cannot be specified if the filesystem " + "is not version aware" + ) + + if cache_type is None: + cache_type = self.default_cache_type + + return TosFile( + self, + path, + mode, + block_size=block_size, + version_id=version_id, + fill_cache=fill_cache, + cache_type=cache_type, + autocommit=autocommit, + ) + def ls( self, path: str, @@ -1116,10 +1197,10 @@ def _split_path(self, path: str) -> Tuple[str, str, Optional[str]]: Examples -------- - >>> split_path("tos://mybucket/path/to/file") + >>> self._split_path("tos://mybucket/path/to/file") ['mybucket', 'path/to/file', None] # pylint: disable=line-too-long - >>> split_path("tos://mybucket/path/to/versioned_file?versionId=some_version_id") + >>> self._split_path("tos://mybucket/path/to/versioned_file?versionId=some_version_id") ['mybucket', 'path/to/versioned_file', 'some_version_id'] """ @@ -1175,3 +1256,210 @@ def _fill_bucket_info(bucket_name: str) -> dict: "type": "directory", "name": bucket_name, } + + +class TosFile(AbstractBufferedFile): + """File-like operations for TOS.""" + + retries = 5 + part_min = 5 * 2**20 + part_max = 5 * 2**30 + + def __init__( + self, + fs: TosFileSystem, + path: str, + mode: str = "rb", + block_size: Union[int, str] = "default", + autocommit: bool = True, + cache_type: str = "readahead", + **kwargs: Any, + ): + """Instantiate a TOS file.""" + bucket, key, path_version_id = fs._split_path(path) + if not key: + raise ValueError("Attempt to open non key-like path: %s" % path) + super().__init__( + fs, + path, + mode, + block_size=block_size, + autocommit=autocommit, + cache_type=cache_type, + **kwargs, + ) + self.fs = fs + self.bucket = bucket + self.key = key + self.path = path + self.mode = mode + self.autocommit = autocommit + self.mpu: CreateMultipartUploadOutput = 
class TosFile(AbstractBufferedFile):
    """File-like operations for TOS.

    Written data is buffered in ``self.buffer``. On commit, data that never
    triggered a multipart upload is sent with a single ``put_object`` call;
    otherwise the buffered blocks are sent as parts of a multipart upload
    (MPU) which is completed on commit.
    """

    retries = 5
    # NOTE(review): presumably the service's minimum / maximum size of one
    # multipart part (5 MiB / 5 GiB) -- confirm against the TOS limits.
    # Only ``part_max`` is referenced in this class.
    part_min = 5 * 2**20
    part_max = 5 * 2**30

    def __init__(
        self,
        fs: TosFileSystem,
        path: str,
        mode: str = "rb",
        block_size: Union[int, str] = "default",
        autocommit: bool = True,
        cache_type: str = "readahead",
        **kwargs: Any,
    ):
        """Instantiate a TOS file.

        Parameters
        ----------
        fs : TosFileSystem
            Filesystem that owns this file; its ``tos_client`` performs all
            remote calls.
        path : str
            Path of the object; must contain a key, not only a bucket.
        mode : str
            Open mode, passed to ``AbstractBufferedFile``.
        block_size : int or str
            Block size, passed to ``AbstractBufferedFile``.
        autocommit : bool
            If True, the upload is committed when the final chunk is
            flushed.
        cache_type : str
            Read cache type, passed to ``AbstractBufferedFile``.
        kwargs : dict-like
            Passed to ``AbstractBufferedFile``.

        Raises
        ------
        ValueError
            If ``path`` has no key component.

        """
        bucket, key, _ = fs._split_path(path)
        if not key:
            raise ValueError(f"Attempt to open non key-like path: {path}")
        super().__init__(
            fs,
            path,
            mode,
            block_size=block_size,
            autocommit=autocommit,
            cache_type=cache_type,
            **kwargs,
        )
        self.fs = fs
        self.bucket = bucket
        self.key = key
        self.path = path
        self.mode = mode
        self.autocommit = autocommit
        # Handle of the in-flight multipart upload; None until one is created
        # and again after it has been aborted.
        self.mpu: Optional[CreateMultipartUploadOutput] = None
        # Parts uploaded so far; None until a multipart upload is initiated.
        self.parts: Optional[list] = None
        self.append_block = False
        self.buffer: Optional[io.BytesIO] = io.BytesIO()

    def _record_part(
        self, part_number: int, etag: str, part_size: Optional[int]
    ) -> None:
        """Remember an uploaded part so that ``commit`` can complete the MPU.

        Every entry is a ``PartInfo`` so the list handed to
        ``complete_multipart_upload`` is homogeneous.
        """
        if self.parts is None:
            self.parts = []
        self.parts.append(
            PartInfo(
                part_number=part_number,
                etag=etag,
                part_size=part_size,
                offset=None,
                hash_crc64_ecma=None,
                is_completed=None,
            )
        )

    def _initiate_upload(self) -> None:
        """Create the remote multipart upload, unless a one-shot PUT suffices."""
        if self.autocommit and not self.append_block and self.tell() < self.blocksize:
            # only happens when closing small file, use one-shot PUT
            return
        logger.debug("Initiate upload for %s", self)
        self.parts = []

        self.mpu = self.fs.tos_client.create_multipart_upload(self.bucket, self.key)

        if self.append_block:
            # use existing data in key when appending,
            # and block is big enough
            # NOTE(review): no copy source is passed here -- confirm against
            # the SDK's ``upload_part_copy`` signature before enabling
            # ``append_block`` (it is never set to True in this class).
            out = self.fs.tos_client.upload_part_copy(
                bucket=self.bucket,
                key=self.key,
                part_number=1,
                upload_id=self.mpu.upload_id,
            )
            # The size of the copied part is not known at this point.
            self._record_part(out.part_number, out.etag, None)

    def _upload_chunk(self, final: bool = False) -> bool:
        """Write one part of a multi-block file upload.

        Parameters
        ----------
        final: bool
            This is the last block, so should complete file, if
            self.autocommit is True.

        Returns
        -------
        bool
            ``not final``.

        """
        if self.buffer is not None:
            logger.debug(
                "Upload for %s, final=%s, loc=%s, buffer loc=%s",
                self,
                final,
                self.loc,
                self.buffer.tell(),
            )

        # A small file that is being closed is sent by ``commit`` with a
        # one-shot PUT, so nothing is uploaded as a part here.
        one_shot = (
            self.autocommit
            and not self.append_block
            and final
            and self.tell() < self.blocksize
        )
        if not one_shot:
            self._upload_multiple_chunks(self.bucket, self.key)

        if self.autocommit and final:
            self.commit()

        return not final

    @staticmethod
    def _merge_short_tail(
        previous_chunk: bytes, current_chunk: bytes, part_max: int
    ) -> Tuple[bytes, Optional[bytes]]:
        """Fold a short trailing chunk into the chunk before it.

        Returns the merged data as a single part when it does not exceed
        ``part_max``; otherwise the merged data split into two halves.
        """
        remainder = previous_chunk + current_chunk
        if len(remainder) <= part_max:
            return remainder, None
        partition = len(remainder) // 2
        return remainder[:partition], remainder[partition:]

    def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
        """Upload the buffered data as parts of the current multipart upload.

        The buffer is read in ``self.blocksize`` pieces. A trailing piece
        shorter than one block is merged into the preceding piece (see
        ``_merge_short_tail``) instead of being uploaded on its own.

        Raises
        ------
        RuntimeError
            If there is data to upload but no multipart upload was initiated.

        """
        if self.buffer is None:
            return

        self.buffer.seek(0)
        current_chunk: Optional[bytes] = self.buffer.read(self.blocksize)

        if current_chunk and self.mpu is None:
            # Fail with a clear message instead of an AttributeError on
            # ``self.mpu.upload_id`` below.
            raise RuntimeError(
                f"Multipart upload has not been initiated for {self.path}"
            )

        while current_chunk:
            previous_chunk = current_chunk
            current_chunk = self.buffer.read(self.blocksize)

            if 0 < len(current_chunk) < self.blocksize:
                previous_chunk, current_chunk = self._merge_short_tail(
                    previous_chunk, current_chunk, self.part_max
                )

            part = len(self.parts or []) + 1
            logger.debug("Upload chunk %s, %s", self, part)

            out: UploadPartOutput = self.fs.tos_client.upload_part(
                bucket=bucket,
                key=key,
                part_number=part,
                upload_id=self.mpu.upload_id,
                content=previous_chunk,
            )
            self._record_part(part, out.etag, len(previous_chunk))

    def commit(self) -> None:
        """Complete multipart upload or PUT.

        Raises
        ------
        RuntimeError
            If data was written, no part was uploaded, and the buffer is
            already gone.

        """
        logger.debug("Commit %s", self)
        # Stays None on the empty-file path, where ``touch`` yields no result
        # to read a version id from.
        write_result = None

        if self.tell() == 0:
            if self.buffer is not None:
                logger.debug("Empty file committed %s", self)
                self._abort_mpu()
                self.fs.touch(self.path, **self.kwargs)
        elif not self.parts:
            if self.buffer is None:
                raise RuntimeError(
                    f"Nothing left to commit for {self.path}: buffer is gone"
                )
            logger.debug("One-shot upload of %s", self)
            self.buffer.seek(0)
            data = self.buffer.read()
            write_result = self.fs.tos_client.put_object(
                self.bucket, self.key, content=data
            )
        else:
            logger.debug("Complete multi-part upload for %s ", self)
            write_result = self.fs.tos_client.complete_multipart_upload(
                self.bucket, self.key, upload_id=self.mpu.upload_id, parts=self.parts
            )

        if self.fs.version_aware and write_result is not None:
            self.version_id = write_result.version_id

        self.buffer = None

    def discard(self) -> None:
        """Close the file without writing."""
        self._abort_mpu()
        self.buffer = None  # file becomes unusable

    def _abort_mpu(self) -> None:
        """Abort the in-flight multipart upload, if there is one."""
        if self.mpu:
            self.fs.tos_client.abort_multipart_upload(
                self.bucket, self.key, self.mpu.upload_id
            )
            self.mpu = None
def test_file_write(
    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
) -> None:
    """Write the same payload in text and binary mode and check the size."""
    payload = "hello world"
    target = f"{bucket}/{temporary_workspace}/{random_str()}"

    with tosfs.open(target, "w") as handle:
        handle.write(payload)
    assert tosfs.info(target)["size"] == len(payload)

    # Overwrite the touched object through a binary handle.
    tosfs.touch(target)
    with tosfs.open(target, "wb") as handle:
        handle.write(payload.encode("utf-8"))
    assert tosfs.info(target)["size"] == len(payload)

    tosfs.rm_file(target)


def test_file_write_encdec(
    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
) -> None:
    """Round-trip several encodings through the file API and the raw client."""
    object_key = f"{temporary_workspace}/{random_str()}"
    target = f"{bucket}/{object_key}"

    def fetch_raw() -> bytes:
        # Read back through the SDK client, bypassing the filesystem layer.
        return tosfs.tos_client.get_object(bucket=bucket, key=object_key).read()

    text = "你好"
    with tosfs.open(target, "wb") as handle:
        handle.write(text.encode("gbk"))
    assert fetch_raw().decode("gbk") == text

    tosfs.touch(target)

    text = "\u00af\\_(\u30c4)_/\u00af"
    with tosfs.open(target, "wb") as handle:
        handle.write(text.encode("utf-16-le"))
    assert fetch_raw().decode("utf-16-le") == text

    with tosfs.open(target, "w", encoding="utf-8") as handle:
        handle.write("\u00af\\_(\u30c4)_/\u00af")
    assert fetch_raw().decode("utf-8") == text

    text = "Hello, World!"
    with tosfs.open(target, "w", encoding="ibm500") as handle:
        handle.write(text)
    assert fetch_raw().decode("ibm500") == text

    tosfs.rm_file(target)


def test_file_write_mpu(
    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
) -> None:
    """Write more than one block so that the multipart upload path is used."""
    target = f"{bucket}/{temporary_workspace}/{random_str()}"

    # mock a content let the write logic trigger mpu:
    payload = "a" * 13 * 1024 * 1024
    four_mib = 4 * 1024 * 1024

    with tosfs.open(target, "w", block_size=four_mib) as handle:
        handle.write(payload)

    assert tosfs.info(target)["size"] == len(payload)

    tosfs.rm_file(target)