diff --git a/otaclient/app/downloader.py b/otaclient/app/downloader.py index 34ef9190e..b101b097f 100644 --- a/otaclient/app/downloader.py +++ b/otaclient/app/downloader.py @@ -13,6 +13,8 @@ # limitations under the License. +from __future__ import annotations + import errno import logging import os @@ -30,12 +32,10 @@ ByteString, Callable, Dict, - Iterator, Mapping, Optional, Protocol, Tuple, - Union, ) from urllib.parse import urlsplit @@ -189,12 +189,10 @@ class DecompressionAdapterProtocol(Protocol): """DecompressionAdapter protocol for Downloader.""" @abstractmethod - def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]: - """Decompresses the source stream. - - This Method take a src_stream of compressed file and - return another stream that yields decompressed data chunks. - """ + def stream_reader( + self, src_stream: IO[bytes] | ByteString + ) -> zstandard.ZstdDecompressionReader: + pass class ZstdDecompressionAdapter(DecompressionAdapterProtocol): @@ -203,8 +201,10 @@ class ZstdDecompressionAdapter(DecompressionAdapterProtocol): def __init__(self) -> None: self._dctx = zstandard.ZstdDecompressor() - def iter_chunk(self, src_stream: Union[IO[bytes], ByteString]) -> Iterator[bytes]: - yield from self._dctx.read_to_iter(src_stream) + def stream_reader( + self, src_stream: IO[bytes] | ByteString + ) -> zstandard.ZstdDecompressionReader: + return self._dctx.stream_reader(src_stream) # downloader implementation @@ -225,7 +225,7 @@ class Downloader: MAX_TRAFFIC_STATS_COLLECT_PER_ROUND = 512 def __init__(self) -> None: - self._local = threading.local() + self._thread_local = threading.local() self._executor = ThreadPoolExecutor( max_workers=min(self.MAX_DOWNLOAD_THREADS, (os.cpu_count() or 1) + 4), thread_name_prefix="downloader", @@ -274,25 +274,31 @@ def _thread_initializer(self): ) session.mount("https://", adapter) session.mount("http://", adapter) - self._local.session = session + self._thread_local.session = session # ------ compression support ------ # - self._local._compression_support_matrix = {} + self._thread_local._compression_support_matrix = {} # zstd decompression adapter - self._local._zstd = ZstdDecompressionAdapter() - self._local._compression_support_matrix["zst"] = self._local._zstd - self._local._compression_support_matrix["zstd"] = self._local._zstd + self._thread_local._zstd = ZstdDecompressionAdapter() + self._thread_local._compression_support_matrix["zst"] = self._thread_local._zstd + self._thread_local._compression_support_matrix["zstd"] = ( + self._thread_local._zstd + ) + + # ------ setup buffer ------ # + self._thread_local.buffer = buffer = bytearray(self.CHUNK_SIZE) + self._thread_local.view = memoryview(buffer) @property def _session(self) -> requests.Session: """A thread-local private session.""" - return self._local.session + return self._thread_local.session def _get_decompressor( self, compression_alg: Any ) -> Optional[DecompressionAdapterProtocol]: """Get thread-local private decompressor adapter accordingly.""" - return self._local._compression_support_matrix.get(compression_alg) + return self._thread_local._compression_support_matrix.get(compression_alg) @property def downloaded_bytes(self) -> int: @@ -471,6 +477,7 @@ def _download_task( headers: Optional[Dict[str, str]] = None, compression_alg: Optional[str] = None, ) -> Tuple[int, int, int]: + """The task entry to be executed in the download worker pool.""" _thread_id = threading.get_ident() proxies = proxies or self._proxies @@ -482,7 +489,7 @@ def _download_task( headers, digest=digest, compression_alg=compression_alg, proxies=proxies ) - _hash_inst = self._hash_func() + digestobj = self._hash_func() # NOTE: downloaded_file_size is the number of bytes we return to the caller(if compressed, # the number will be of the decompressed file) downloaded_file_size = 0 @@ -504,29 +511,22 @@ def _download_task( url, dst, digest, compression_alg, resp.headers ) - # support for compresed file + stream = raw_resp = resp.raw if decompressor := self._get_decompressor(compression_alg): - for _chunk in decompressor.iter_chunk(resp.raw): - _hash_inst.update(_chunk) - _dst.write(_chunk) - downloaded_file_size += len(_chunk) - - _traffic_on_wire = resp.raw.tell() - self._workers_downloaded_bytes[_thread_id] += ( - _traffic_on_wire - traffic_on_wire - ) + stream = decompressor.stream_reader(stream) - traffic_on_wire = _traffic_on_wire - # un-compressed file - else: - for _chunk in resp.iter_content(chunk_size=self.CHUNK_SIZE): - _hash_inst.update(_chunk) - _dst.write(_chunk) + buffer, view = self._thread_local.buffer, self._thread_local.view + while read_size := stream.readinto(buffer): + digestobj.update(view[:read_size]) + _dst.write(view[:read_size]) - chunk_len = len(_chunk) - downloaded_file_size += chunk_len - traffic_on_wire += chunk_len - self._workers_downloaded_bytes[_thread_id] += chunk_len + downloaded_file_size += read_size + + _new_traffic_on_wire = raw_resp.tell() + self._workers_downloaded_bytes[_thread_id] += ( + _new_traffic_on_wire - traffic_on_wire + ) + traffic_on_wire = _new_traffic_on_wire # checking the download result if size and size != downloaded_file_size: @@ -534,7 +534,7 @@ def _download_task( logger.warning(_msg) raise ChunkStreamingError(url, dst, _msg) - if digest and ((calc_digest := _hash_inst.hexdigest()) != digest): + if digest and ((calc_digest := digestobj.hexdigest()) != digest): _msg = ( "sha256hash check failed detected: " f"act={calc_digest}, exp={digest}, {url=}"