Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refine downloader #309

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 38 additions & 40 deletions otaclient/app/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from __future__ import annotations

import errno
import logging
import os
Expand All @@ -30,12 +32,10 @@
ByteString,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Protocol,
Tuple,
Union,
)
from urllib.parse import urlsplit

Expand Down Expand Up @@ -189,12 +189,10 @@ class DecompressionAdapterProtocol(Protocol):
"""DecompressionAdapter protocol for Downloader."""

@abstractmethod
def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]:
"""Decompresses the source stream.

This Method take a src_stream of compressed file and
return another stream that yields decompressed data chunks.
"""
def stream_reader(
self, src_stream: IO[bytes] | ByteString
) -> zstandard.ZstdDecompressionReader:
pass


class ZstdDecompressionAdapter(DecompressionAdapterProtocol):
Expand All @@ -203,8 +201,10 @@ class ZstdDecompressionAdapter(DecompressionAdapterProtocol):
def __init__(self) -> None:
self._dctx = zstandard.ZstdDecompressor()

def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]:
yield from self._dctx.read_to_iter(src_stream)
def stream_reader(
self, src_stream: IO[bytes] | ByteString
) -> zstandard.ZstdDecompressionReader:
return self._dctx.stream_reader(src_stream)


# downloader implementation
Expand All @@ -225,7 +225,7 @@ class Downloader:
MAX_TRAFFIC_STATS_COLLECT_PER_ROUND = 512

def __init__(self) -> None:
self._local = threading.local()
self._thread_local = threading.local()
self._executor = ThreadPoolExecutor(
max_workers=min(self.MAX_DOWNLOAD_THREADS, (os.cpu_count() or 1) + 4),
thread_name_prefix="downloader",
Expand Down Expand Up @@ -274,25 +274,29 @@ def _thread_initializer(self):
)
session.mount("https://", adapter)
session.mount("http://", adapter)
self._local.session = session
self._thread_local.session = session

# ------ compression support ------ #
self._local._compression_support_matrix = {}
self._thread_local._compression_support_matrix = {}
# zstd decompression adapter
self._local._zstd = ZstdDecompressionAdapter()
self._local._compression_support_matrix["zst"] = self._local._zstd
self._local._compression_support_matrix["zstd"] = self._local._zstd
self._thread_local._zstd = zstd_adapter = ZstdDecompressionAdapter()
self._thread_local._compression_support_matrix["zst"] = zstd_adapter
self._thread_local._compression_support_matrix["zstd"] = zstd_adapter

# ------ setup buffer ------ #
self._thread_local.buffer = buffer = bytearray(self.CHUNK_SIZE)
self._thread_local.view = memoryview(buffer)

@property
def _session(self) -> requests.Session:
"""A thread-local private session."""
return self._local.session
return self._thread_local.session

def _get_decompressor(
self, compression_alg: Any
) -> Optional[DecompressionAdapterProtocol]:
"""Get thread-local private decompressor adapter accordingly."""
return self._local._compression_support_matrix.get(compression_alg)
return self._thread_local._compression_support_matrix.get(compression_alg)

@property
def downloaded_bytes(self) -> int:
Expand Down Expand Up @@ -471,6 +475,7 @@ def _download_task(
headers: Optional[Dict[str, str]] = None,
compression_alg: Optional[str] = None,
) -> Tuple[int, int, int]:
"""The task entry to be executed in the download worker pool."""
_thread_id = threading.get_ident()

proxies = proxies or self._proxies
Expand All @@ -482,7 +487,7 @@ def _download_task(
headers, digest=digest, compression_alg=compression_alg, proxies=proxies
)

_hash_inst = self._hash_func()
digestobj = self._hash_func()
# NOTE: downloaded_file_size is the number of bytes we return to the caller(if compressed,
# the number will be of the decompressed file)
downloaded_file_size = 0
Expand All @@ -504,37 +509,30 @@ def _download_task(
url, dst, digest, compression_alg, resp.headers
)

# support for compresed file
stream = raw_resp = resp.raw
if decompressor := self._get_decompressor(compression_alg):
for _chunk in decompressor.iter_chunk(resp.raw):
_hash_inst.update(_chunk)
_dst.write(_chunk)
downloaded_file_size += len(_chunk)

_traffic_on_wire = resp.raw.tell()
self._workers_downloaded_bytes[_thread_id] += (
_traffic_on_wire - traffic_on_wire
)
stream = decompressor.stream_reader(stream)

buffer, view = self._thread_local.buffer, self._thread_local.view
while read_size := stream.readinto(buffer):
digestobj.update(view[:read_size])
_dst.write(view[:read_size])

traffic_on_wire = _traffic_on_wire
# un-compressed file
else:
for _chunk in resp.iter_content(chunk_size=self.CHUNK_SIZE):
_hash_inst.update(_chunk)
_dst.write(_chunk)
downloaded_file_size += read_size

chunk_len = len(_chunk)
downloaded_file_size += chunk_len
traffic_on_wire += chunk_len
self._workers_downloaded_bytes[_thread_id] += chunk_len
_new_traffic_on_wire = raw_resp.tell()
self._workers_downloaded_bytes[_thread_id] += (
_new_traffic_on_wire - traffic_on_wire
)
traffic_on_wire = _new_traffic_on_wire

# checking the download result
if size and size != downloaded_file_size:
_msg = f"partial download detected: {size=},{downloaded_file_size=}"
logger.warning(_msg)
raise ChunkStreamingError(url, dst, _msg)

if digest and ((calc_digest := _hash_inst.hexdigest()) != digest):
if digest and ((calc_digest := digestobj.hexdigest()) != digest):
_msg = (
"sha256hash check failed detected: "
f"act={calc_digest}, exp={digest}, {url=}"
Expand Down
Loading