diff --git a/src/ota_proxy/ota_cache.py b/src/ota_proxy/ota_cache.py index 87efea741..e9d1e0a86 100644 --- a/src/ota_proxy/ota_cache.py +++ b/src/ota_proxy/ota_cache.py @@ -26,7 +26,7 @@ from urllib.parse import SplitResult, quote, urlsplit import aiohttp -from multidict import CIMultiDictProxy +from multidict import CIMultiDict, CIMultiDictProxy from otaclient_common.common import get_backoff from otaclient_common.typing import StrOrPath @@ -409,7 +409,7 @@ async def _do_request() -> AsyncIterator[bytes]: async def _retrieve_file_by_cache( self, cache_identifier: str, *, retry_cache: bool - ) -> Optional[Tuple[AsyncIterator[bytes], Mapping[str, str]]]: + ) -> tuple[AsyncIterator[bytes], CIMultiDict[str]] | None: """ Returns: A tuple of bytes iterator and headers dict for back to client. @@ -458,7 +458,8 @@ async def _retrieve_file_by_cache( async def _retrieve_file_by_external_cache( self, client_cache_policy: OTAFileCacheControl - ) -> tuple[AsyncIterator[bytes], Mapping[str, str]] | None: + ) -> tuple[AsyncIterator[bytes], CIMultiDict[str]] | None: + # skip if not external cache or otaclient doesn't sent valid file_sha256 if not self._external_cache_data_dir or not client_cache_policy.file_sha256: return @@ -469,18 +470,23 @@ async def _retrieve_file_by_external_cache( ) if cache_file_zst.is_file(): - return read_file(cache_file_zst, executor=self._executor), { - HEADER_OTA_FILE_CACHE_CONTROL: OTAFileCacheControl.export_kwargs_as_header( + _header = CIMultiDict() + _header[HEADER_OTA_FILE_CACHE_CONTROL] = ( + OTAFileCacheControl.export_kwargs_as_header( file_sha256=cache_identifier, file_compression_alg=cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG, ) - } - elif cache_file.is_file(): - return read_file(cache_file, executor=self._executor), { - HEADER_OTA_FILE_CACHE_CONTROL: OTAFileCacheControl.export_kwargs_as_header( + ) + return read_file(cache_file_zst, executor=self._executor), _header + + if cache_file.is_file(): + _header = CIMultiDict() + _header[HEADER_OTA_FILE_CACHE_CONTROL] = ( + OTAFileCacheControl.export_kwargs_as_header( file_sha256=cache_identifier ) - } + ) + return read_file(cache_file, executor=self._executor), _header # exposed API @@ -488,7 +494,7 @@ async def retrieve_file( self, raw_url: str, headers_from_client: Dict[str, str], - ) -> Optional[Tuple[AsyncIterator[bytes], Mapping[str, str]]]: + ) -> tuple[AsyncIterator[bytes], CIMultiDict[str] | CIMultiDictProxy[str]] | None: """Retrieve a file descriptor for the requested . This method retrieves a file descriptor for incoming client request.