From 0bdef04d913b20798047c3d730e29df47ca732c3 Mon Sep 17 00:00:00 2001
From: yanghua <yanghua1127@gmail.com>
Date: Thu, 5 Sep 2024 14:07:14 +0800
Subject: [PATCH 1/2] Core: Implement cp_file API

---
 tosfs/consts.py           |   6 ++
 tosfs/core.py             | 219 +++++++++++++++++++++++++++++++++++---
 tosfs/tests/test_tosfs.py |  44 ++++++++
 tosfs/utils.py            |  17 ++-
 4 files changed, 273 insertions(+), 13 deletions(-)

diff --git a/tosfs/consts.py b/tosfs/consts.py
index 01cc726..fc82585 100644
--- a/tosfs/consts.py
+++ b/tosfs/consts.py
@@ -26,3 +26,9 @@
     "InternalError",
     "ServiceUnavailable",
 }
+
+MANAGED_COPY_THRESHOLD = 5 * 2**30
+
+RETRY_NUM = 5
+PART_MIN_SIZE = 5 * 2**20
+PART_MAX_SIZE = 5 * 2**30
diff --git a/tosfs/core.py b/tosfs/core.py
index 59aea2c..deb1cd0 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -33,9 +33,14 @@
     UploadPartOutput,
 )
 
-from tosfs.consts import TOS_SERVER_RESPONSE_CODE_NOT_FOUND
+from tosfs.consts import (
+    MANAGED_COPY_THRESHOLD,
+    PART_MAX_SIZE,
+    RETRY_NUM,
+    TOS_SERVER_RESPONSE_CODE_NOT_FOUND,
+)
 from tosfs.exceptions import TosfsError
-from tosfs.utils import find_bucket_key, retryable_func_wrapper
+from tosfs.utils import find_bucket_key, get_brange, retryable_func_wrapper
 
 # environment variable names
 ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
@@ -66,7 +71,6 @@ class TosFileSystem(AbstractFileSystem):
     """
 
     protocol = ("tos", "tosfs")
-    retries = 5
     default_block_size = 5 * 2**20
 
     def __init__(
@@ -125,7 +129,7 @@ def _open(
             best support random access. When reading only a few specific chunks
             out of a file, performance may be better if False.
         version_id : str
-            Explicit version of the object to open.  This requires that the s3
+            Explicit version of the object to open.  This requires that the tos
             filesystem is version aware and bucket versioning is enabled on the
             relevant bucket.
         cache_type : str
@@ -636,7 +640,7 @@ def _read_chunks(body: BinaryIO, f: BinaryIO) -> None:
                     chunk = body.read(2**16)
                 except tos.exceptions.TosClientError as e:
                     failed_reads += 1
-                    if failed_reads >= self.retries:
+                    if failed_reads >= RETRY_NUM:
                         raise e
                     try:
                         body.close()
@@ -772,6 +776,202 @@ def find(
         else:
             return [o["name"] for o in out]
 
+    def cp_file(
+        self,
+        path1: str,
+        path2: str,
+        preserve_etag: Optional[bool] = None,
+        managed_copy_threshold: Optional[int] = MANAGED_COPY_THRESHOLD,
+        **kwargs: Any,
+    ) -> None:
+        """Copy file between locations on tos.
+
+        Parameters
+        ----------
+        path1 : str
+            The source path of the file to copy.
+        path2 : str
+            The destination path of the file to copy.
+        preserve_etag : bool, optional
+            Whether to preserve etag while copying. If the file is uploaded
+            as a single part, then it will be always equivalent to the md5
+            hash of the file hence etag will always be preserved. But if the
+            file is uploaded in multi parts, then this option will try to
+            reproduce the same multipart upload while copying and preserve
+            the generated etag.
+        managed_copy_threshold : int, optional
+            The threshold size of the file to copy using managed copy. If the
+            size of the file is greater than this threshold, then the file
+            will be copied using managed copy (default is 5 * 2**30).
+        **kwargs : Any, optional
+            Additional arguments.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the source file does not exist.
+        ValueError
+            If the destination is a versioned file.
+        TosClientError
+            If there is a client error while copying the file.
+        TosServerError
+            If there is a server error while copying the file.
+        TosfsError
+            If there is an unknown error while copying the file.
+        """
+        path1 = self._strip_protocol(path1)
+        bucket, key, vers = self._split_path(path1)
+
+        info = self.info(path1, bucket, key, version_id=vers)
+        size = info["size"]
+
+        _, _, parts_suffix = info.get("ETag", "").strip('"').partition("-")
+        if preserve_etag and parts_suffix:
+            self._copy_etag_preserved(path1, path2, size, total_parts=int(parts_suffix))
+        elif size <= min(
+            MANAGED_COPY_THRESHOLD,
+            (
+                managed_copy_threshold
+                if managed_copy_threshold
+                else MANAGED_COPY_THRESHOLD
+            ),
+        ):
+            self._copy_basic(path1, path2, **kwargs)
+        else:
+            # if the preserve_etag is true, either the file is uploaded
+            # on multiple parts or the size is lower than 5GB
+            assert not preserve_etag
+
+            # serial multipart copy
+            self._copy_managed(path1, path2, size, **kwargs)
+
+    def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None:
+        """Copy file between locations on tos.
+
+        Not allowed where the origin is larger than 5GB.
+        """
+        buc1, key1, ver1 = self._split_path(path1)
+        buc2, key2, ver2 = self._split_path(path2)
+        if ver2:
+            raise ValueError("Cannot copy to a versioned file!")
+        try:
+            self.tos_client.copy_object(
+                bucket=buc2,
+                key=key2,
+                src_bucket=buc1,
+                src_key=key1,
+                src_version_id=ver1,
+            )
+        except tos.exceptions.TosClientError as e:
+            raise e
+        except tos.exceptions.TosServerError as e:
+            raise e
+        except Exception as e:
+            raise TosfsError("Copy failed (%r -> %r): %s" % (path1, path2, e)) from e
+
+    def _copy_etag_preserved(
+        self, path1: str, path2: str, size: int, total_parts: int, **kwargs: Any
+    ) -> None:
+        """Copy file between tos locations as multiple-part while preserving the etag.
+
+        (using the same part sizes for each part
+        """
+        bucket1, key1, version1 = self._split_path(path1)
+        bucket2, key2, version2 = self._split_path(path2)
+
+        upload_id = None
+
+        try:
+            mpu = self.tos_client.create_multipart_upload(bucket2, key2)
+            upload_id = mpu.upload_id
+
+            parts = []
+            brange_first = 0
+
+            for i in range(1, total_parts + 1):
+                part_size = min(size - brange_first, PART_MAX_SIZE)
+                brange_last = brange_first + part_size - 1
+                if brange_last > size:
+                    brange_last = size - 1
+
+                part = self.tos_client.upload_part_copy(
+                    bucket=bucket2,
+                    key=key2,
+                    part_number=i,
+                    upload_id=upload_id,
+                    src_bucket=bucket1,
+                    src_key=key1,
+                    copy_source_range_start=brange_first,
+                    copy_source_range_end=brange_last,
+                )
+                parts.append(
+                    PartInfo(
+                        part_number=part.part_number,
+                        etag=part.etag,
+                        part_size=part_size,
+                        offset=None,
+                        hash_crc64_ecma=None,
+                        is_completed=None,
+                    )
+                )
+                brange_first += part_size
+
+            self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts)
+        except Exception as e:
+            self.tos_client.abort_multipart_upload(bucket2, key2, upload_id)
+            raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e
+
+    def _copy_managed(
+        self, path1: str, path2: str, size: int, block: int = 5 * 2**30, **kwargs: Any
+    ) -> None:
+        """Copy file between locations on tos as multiple-part.
+
+        block: int
+            The size of the pieces, must be larger than 5MB and at most 5GB.
+            Smaller blocks mean more calls, only useful for testing.
+        """
+        if block < 5 * 2**20 or block > 5 * 2**30:
+            raise ValueError("Copy block size must be 5MB<=block<=5GB")
+
+        bucket1, key1, version1 = self._split_path(path1)
+        bucket2, key2, version2 = self._split_path(path2)
+
+        upload_id = None
+
+        try:
+            mpu = self.tos_client.create_multipart_upload(bucket2, key2)
+            upload_id = mpu.upload_id
+            out = [
+                self.tos_client.upload_part_copy(
+                    bucket=bucket2,
+                    key=key2,
+                    part_number=i + 1,
+                    upload_id=upload_id,
+                    src_bucket=bucket1,
+                    src_key=key1,
+                    copy_source_range_start=brange_first,
+                    copy_source_range_end=brange_last,
+                )
+                for i, (brange_first, brange_last) in enumerate(get_brange(size, block))
+            ]
+
+            parts = [
+                PartInfo(
+                    part_number=i + 1,
+                    etag=o.etag,
+                    part_size=min(block, size - i * block),
+                    offset=None,
+                    hash_crc64_ecma=None,
+                    is_completed=None,
+                )
+                for i, o in enumerate(out)
+            ]
+
+            self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts)
+        except Exception as e:
+            self.tos_client.abort_multipart_upload(bucket2, key2, upload_id)
+            raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e
+
     def _find_file_dir(
         self, key: str, path: str, prefix: str, withdirs: bool, kwargs: Any
     ) -> List[dict]:
@@ -1393,14 +1593,9 @@ def _fill_bucket_info(bucket_name: str) -> dict:
             "name": bucket_name,
         }
 
-
 class TosFile(AbstractBufferedFile):
     """File-like operations for TOS."""
 
-    retries = 5
-    part_min = 5 * 2**20
-    part_max = 5 * 2**30
-
     def __init__(
         self,
         fs: TosFileSystem,
@@ -1530,7 +1725,7 @@ def handle_remainder(
                 # Use the helper function in the main code
                 if 0 < current_chunk_size < self.blocksize:
                     previous_chunk, current_chunk = handle_remainder(
-                        previous_chunk, current_chunk, self.blocksize, self.part_max
+                        previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE
                     )
 
                 part = len(self.parts) + 1 if self.parts is not None else 1
@@ -1577,7 +1772,7 @@ def fetch() -> bytes:
                 bucket, key, version_id, range_start=start, range_end=end
             ).read()
 
-        return retryable_func_wrapper(fetch, retries=self.fs.retries)
+        return retryable_func_wrapper(fetch, retries=RETRY_NUM)
 
     def commit(self) -> None:
         """Complete multipart upload or PUT."""
diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py
index c693ca5..aa1b527 100644
--- a/tosfs/tests/test_tosfs.py
+++ b/tosfs/tests/test_tosfs.py
@@ -459,6 +459,50 @@ def test_find(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> No
     tosfs.rmdir(f"{bucket}/{temporary_workspace}")
 
 
+def test_cp_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
+    file_name = random_str()
+    file_content = "hello world"
+    src_path = f"{bucket}/{temporary_workspace}/{file_name}"
+    dest_path = f"{bucket}/{temporary_workspace}/copy_{file_name}"
+
+    with tosfs.open(src_path, "w") as f:
+        f.write(file_content)
+
+    tosfs.cp_file(src_path, dest_path)
+    assert tosfs.exists(dest_path)
+
+    with tosfs.open(dest_path, "r") as f:
+        assert f.read() == file_content
+
+    with pytest.raises(FileNotFoundError):
+        tosfs.cp_file(f"{bucket}/{temporary_workspace}/nonexistent", dest_path)
+
+    sub_dir_name = random_str()
+    dest_path = f"{bucket}/{temporary_workspace}/{sub_dir_name}"
+    tosfs.cp_file(src_path, dest_path)
+    assert tosfs.exists(dest_path)
+    with tosfs.open(dest_path, "r") as f:
+        assert f.read() == file_content
+
+    file_content = "a" * 2048  # 2KB content
+    with tosfs.open(src_path, "w") as f:
+        f.write(file_content)
+
+    tosfs.cp_file(src_path, dest_path, managed_copy_threshold=1024)
+    assert tosfs.exists(dest_path)
+
+    with tosfs.open(dest_path, "r") as f:
+        assert f.read() == file_content
+
+    # Test cp_file with preserve_etag=True
+    dest_path_with_etag = f"{bucket}/{temporary_workspace}/etag_{file_name}"
+    tosfs.cp_file(dest_path, dest_path_with_etag, preserve_etag=True)
+    assert tosfs.exists(dest_path_with_etag)
+    with tosfs.open(dest_path_with_etag, "r") as f:
+        assert f.read() == file_content
+    assert tosfs.info(dest_path_with_etag)["ETag"] == tosfs.info(dest_path)["ETag"]
+
+
 ###########################################################
 #                File operation tests                     #
 ###########################################################
diff --git a/tosfs/utils.py b/tosfs/utils.py
index 43d8074..d949877 100644
--- a/tosfs/utils.py
+++ b/tosfs/utils.py
@@ -19,7 +19,7 @@
 import string
 import tempfile
 import time
-from typing import Any, Optional, Tuple
+from typing import Any, Generator, Optional, Tuple
 
 import tos
 
@@ -91,6 +91,21 @@ def find_bucket_key(tos_path: str) -> Tuple[str, str]:
     return bucket, tos_key
 
 
+def get_brange(size: int, block: int) -> Generator[Tuple[int, int], None, None]:
+    """Chunk up a file into zero-based byte ranges.
+
+    Parameters
+    ----------
+    size : int
+        file size
+    block : int
+        block size
+
+    """
+    for offset in range(0, size, block):
+        yield offset, min(offset + block - 1, size - 1)
+
+
 def retryable_func_wrapper(
     func: Any, *, args: tuple[()] = (), kwargs: Optional[Any] = None, retries: int = 5
 ) -> Any:

From a3b45ae36768f19f6776f708c298474db13b3a23 Mon Sep 17 00:00:00 2001
From: yanghua <yanghua1127@gmail.com>
Date: Thu, 5 Sep 2024 14:15:02 +0800
Subject: [PATCH 2/2] Reformat code

---
 tosfs/core.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tosfs/core.py b/tosfs/core.py
index deb1cd0..6f42b3b 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -818,6 +818,7 @@ def cp_file(
             If there is a server error while copying the file.
         TosfsError
             If there is an unknown error while copying the file.
+
         """
         path1 = self._strip_protocol(path1)
         bucket, key, vers = self._split_path(path1)
@@ -872,10 +873,7 @@ def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None:
     def _copy_etag_preserved(
         self, path1: str, path2: str, size: int, total_parts: int, **kwargs: Any
     ) -> None:
-        """Copy file between tos locations as multiple-part while preserving the etag.
-
-        (using the same part sizes for each part
-        """
+        """Copy file as multiple-part while preserving the etag."""
         bucket1, key1, version1 = self._split_path(path1)
         bucket2, key2, version2 = self._split_path(path2)
 
@@ -922,15 +920,21 @@ def _copy_etag_preserved(
             raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e
 
     def _copy_managed(
-        self, path1: str, path2: str, size: int, block: int = 5 * 2**30, **kwargs: Any
+        self,
+        path1: str,
+        path2: str,
+        size: int,
+        block: int = MANAGED_COPY_THRESHOLD,
+        **kwargs: Any,
     ) -> None:
         """Copy file between locations on tos as multiple-part.
 
         block: int
-            The size of the pieces, must be larger than 5MB and at most 5GB.
+            The size of the pieces, must be larger than 5MB and at
+            most MANAGED_COPY_THRESHOLD.
             Smaller blocks mean more calls, only useful for testing.
         """
-        if block < 5 * 2**20 or block > 5 * 2**30:
+        if block < 5 * 2**20 or block > MANAGED_COPY_THRESHOLD:
             raise ValueError("Copy block size must be 5MB<=block<=5GB")
 
         bucket1, key1, version1 = self._split_path(path1)
@@ -1593,6 +1597,7 @@ def _fill_bucket_info(bucket_name: str) -> dict:
             "name": bucket_name,
         }
 
+
 class TosFile(AbstractBufferedFile):
     """File-like operations for TOS."""