Skip to content

Commit

Permalink
refine downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang committed May 27, 2024
1 parent 2090ab4 commit a0b0811
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions otaclient/app/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from __future__ import annotations

import errno
import logging
import os
Expand All @@ -30,12 +32,10 @@
ByteString,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Protocol,
Tuple,
Union,
)
from urllib.parse import urlsplit

Expand Down Expand Up @@ -189,12 +189,10 @@ class DecompressionAdapterProtocol(Protocol):
"""DecompressionAdapter protocol for Downloader."""

@abstractmethod
def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]:
"""Decompresses the source stream.
This Method take a src_stream of compressed file and
return another stream that yields decompressed data chunks.
"""
def stream_reader(
self, src_stream: IO[bytes] | ByteString
) -> zstandard.ZstdDecompressionReader:
pass


class ZstdDecompressionAdapter(DecompressionAdapterProtocol):
Expand All @@ -203,8 +201,10 @@ class ZstdDecompressionAdapter(DecompressionAdapterProtocol):
def __init__(self) -> None:
self._dctx = zstandard.ZstdDecompressor()

def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]:
yield from self._dctx.read_to_iter(src_stream)
def stream_reader(
self, src_stream: IO[bytes] | ByteString
) -> zstandard.ZstdDecompressionReader:
return self._dctx.stream_reader(src_stream)


# downloader implementation
Expand All @@ -225,7 +225,7 @@ class Downloader:
MAX_TRAFFIC_STATS_COLLECT_PER_ROUND = 512

def __init__(self) -> None:
self._local = threading.local()
self._thread_local = threading.local()
self._executor = ThreadPoolExecutor(
max_workers=min(self.MAX_DOWNLOAD_THREADS, (os.cpu_count() or 1) + 4),
thread_name_prefix="downloader",
Expand Down Expand Up @@ -274,25 +274,31 @@ def _thread_initializer(self):
)
session.mount("https://", adapter)
session.mount("http://", adapter)
self._local.session = session
self._thread_local.session = session

# ------ compression support ------ #
self._local._compression_support_matrix = {}
self._thread_local._compression_support_matrix = {}
# zstd decompression adapter
self._local._zstd = ZstdDecompressionAdapter()
self._local._compression_support_matrix["zst"] = self._local._zstd
self._local._compression_support_matrix["zstd"] = self._local._zstd
self._thread_local._zstd = ZstdDecompressionAdapter()
self._thread_local._compression_support_matrix["zst"] = self._thread_local._zstd
self._thread_local._compression_support_matrix["zstd"] = (
self._thread_local._zstd
)

# ------ setup buffer ------ #
self._thread_local.buffer = buffer = bytearray(self.CHUNK_SIZE)
self._thread_local.view = memoryview(buffer)

@property
def _session(self) -> requests.Session:
"""A thread-local private session."""
return self._local.session
return self._thread_local.session

def _get_decompressor(
self, compression_alg: Any
) -> Optional[DecompressionAdapterProtocol]:
"""Get thread-local private decompressor adapter accordingly."""
return self._local._compression_support_matrix.get(compression_alg)
return self._thread_local._compression_support_matrix.get(compression_alg)

@property
def downloaded_bytes(self) -> int:
Expand Down Expand Up @@ -471,6 +477,7 @@ def _download_task(
headers: Optional[Dict[str, str]] = None,
compression_alg: Optional[str] = None,
) -> Tuple[int, int, int]:
"""The task entry to be executed in the download worker pool."""
_thread_id = threading.get_ident()

proxies = proxies or self._proxies
Expand All @@ -482,7 +489,7 @@ def _download_task(
headers, digest=digest, compression_alg=compression_alg, proxies=proxies
)

_hash_inst = self._hash_func()
digestobj = self._hash_func()
# NOTE: downloaded_file_size is the number of bytes we return to the caller(if compressed,
# the number will be of the decompressed file)
downloaded_file_size = 0
Expand All @@ -504,37 +511,30 @@ def _download_task(
url, dst, digest, compression_alg, resp.headers
)

# support for compresed file
stream = raw_resp = resp.raw
if decompressor := self._get_decompressor(compression_alg):
for _chunk in decompressor.iter_chunk(resp.raw):
_hash_inst.update(_chunk)
_dst.write(_chunk)
downloaded_file_size += len(_chunk)

_traffic_on_wire = resp.raw.tell()
self._workers_downloaded_bytes[_thread_id] += (
_traffic_on_wire - traffic_on_wire
)
stream = decompressor.stream_reader(stream)

traffic_on_wire = _traffic_on_wire
# un-compressed file
else:
for _chunk in resp.iter_content(chunk_size=self.CHUNK_SIZE):
_hash_inst.update(_chunk)
_dst.write(_chunk)
buffer, view = self._thread_local.buffer, self._thread_local.view
while read_size := stream.readinto(buffer):
digestobj.update(view[:read_size])
_dst.write(view[:read_size])

chunk_len = len(_chunk)
downloaded_file_size += chunk_len
traffic_on_wire += chunk_len
self._workers_downloaded_bytes[_thread_id] += chunk_len
downloaded_file_size += read_size

_new_traffic_on_wire = raw_resp.tell()
self._workers_downloaded_bytes[_thread_id] += (
_new_traffic_on_wire - traffic_on_wire
)
traffic_on_wire = _new_traffic_on_wire

# checking the download result
if size and size != downloaded_file_size:
_msg = f"partial download detected: {size=},{downloaded_file_size=}"
logger.warning(_msg)
raise ChunkStreamingError(url, dst, _msg)

if digest and ((calc_digest := _hash_inst.hexdigest()) != digest):
if digest and ((calc_digest := digestobj.hexdigest()) != digest):
_msg = (
"sha256hash check failed detected: "
f"act={calc_digest}, exp={digest}, {url=}"
Expand Down

0 comments on commit a0b0811

Please sign in to comment.