Skip to content

Commit

Permalink
Add test cases for stability
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Sep 25, 2024
1 parent 81df130 commit 112a525
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 24 deletions.
23 changes: 16 additions & 7 deletions tosfs/mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
def _write_to_staging_buffer(self, chunk: bytes) -> None:
self.staging_buffer.write(chunk)
if self.staging_buffer.tell() >= self.part_size:
self._flush_staging_buffer()
self._flush_staging_buffer(False)

def _flush_staging_buffer(self) -> None:
def _flush_staging_buffer(self, final: bool = False) -> None:
if self.staging_buffer.tell() == 0:
return

Expand All @@ -93,13 +93,22 @@ def _flush_staging_buffer(self) -> None:
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

# Move remaining data to a new buffer
remaining_data = self.staging_buffer.read()
self.staging_buffer = io.BytesIO()
self.staging_buffer.write(remaining_data)
if not final:
# Move remaining data to a new buffer
remaining_data = self.staging_buffer.read()
self.staging_buffer = io.BytesIO()
self.staging_buffer.write(remaining_data)
else:
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

self.staging_buffer = io.BytesIO()

def _upload_staged_files(self) -> None:
self._flush_staging_buffer()
self._flush_staging_buffer(True)
futures = []
for i, staging_file in enumerate(self.staging_files):
part_number = i + 1
Expand Down
13 changes: 10 additions & 3 deletions tosfs/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Optional, Tuple

import requests
from requests import RequestException
from requests.exceptions import (
ChunkedEncodingError,
ConnectTimeout,
Expand Down Expand Up @@ -70,6 +71,7 @@
ConnectionResetError,
ConnectionError,
ChunkedEncodingError,
RequestException,
}

MAX_RETRY_NUM = 20
Expand Down Expand Up @@ -142,9 +144,14 @@ def _is_retryable_tos_server_exception(e: TosError) -> bool:


def _is_retryable_tos_client_exception(e: TosError) -> bool:
return isinstance(e, TosClientError) and any(
isinstance(e.cause, excp) for excp in TOS_CLIENT_RETRYABLE_EXCEPTIONS
)
if isinstance(e, TosClientError):
cause = e.cause
while cause is not None:
for excp in TOS_CLIENT_RETRYABLE_EXCEPTIONS:
if isinstance(cause, excp):
return True
cause = getattr(cause, "cause", None)
return False


def _get_sleep_time(err: TosError, retry_count: int) -> float:
Expand Down
34 changes: 20 additions & 14 deletions tosfs/tests/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -73,6 +59,26 @@
),
True,
),
(
TosClientError(
"{'message': 'http request timeout', "
"'case': \"('Connection aborted.', "
"ConnectionResetError(104, 'Connection reset by peer'))\", "
"'request_url': "
"'http://proton-ci.tos-cn-beijing.volces.com/"
"nHnbR/yAlen'}",
TosClientError(
"http request timeout",
ConnectionError(
ProtocolError(
"Connection aborted.",
ConnectionResetError(104, "Connection reset by peer"),
)
),
),
),
True,
),
],
)
def test_is_retry_exception(
Expand Down
55 changes: 55 additions & 0 deletions tosfs/tests/test_stability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from time import sleep

from tosfs.utils import random_str


def test_write_breakpoint_continuation(tosfs, bucket, temporary_workspace):
file_name = f"{random_str()}.txt"
first_part = random_str(10 * 1024 * 1024)
second_part = random_str(10 * 1024 * 1024)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "w") as f:
f.write(first_part)
# mock a very long block(business processing or network issue)
sleep(60)
f.write(second_part)

assert tosfs.info(f"{bucket}/{temporary_workspace}/{file_name}")["size"] == len(
first_part + second_part
)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
assert f.read() == first_part + second_part


def test_read_breakpoint_continuation(tosfs, bucket, temporary_workspace):
file_name = f"{random_str()}.txt"
first_part = random_str(10 * 1024 * 1024)
second_part = random_str(10 * 1024 * 1024)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "w") as f:
f.write(first_part)
f.write(second_part)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
read_first_part = f.read(10 * 1024 * 1024)
assert read_first_part == first_part
# mock a very long block(business processing or network issue)
sleep(60)
read_second_part = f.read(10 * 1024 * 1024)
assert read_second_part == second_part
assert read_first_part + read_second_part == first_part + second_part

0 comments on commit 112a525

Please sign in to comment.