diff --git a/pyproject.toml b/pyproject.toml
index a01d5fbd3..8a2ab5a74 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dynamic = [
   "version",
 ]
 dependencies = [
-  "aiofiles<25,>=24.1",
+  "anyio>=4.5.1,<5",
   "aiohttp>=3.10.11,<3.12",
   "cryptography>=43.0.1,<45",
   "grpcio>=1.53.2,<1.69",
@@ -35,7 +35,7 @@ dependencies = [
   "pydantic-settings<3,>=2.3",
   "pyyaml<7,>=6.0.1",
   "requests<2.33,>=2.32",
-  "simple-sqlite3-orm<0.7,>=0.6",
+  "simple-sqlite3-orm<0.8,>=0.7",
   "typing-extensions>=4.6.3",
   "urllib3<2.3,>=2.2.2",
   "uvicorn[standard]>=0.30,<0.35",
diff --git a/requirements.txt b/requirements.txt
index 0f054df2e..dbf3e8c60 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # Automatically generated from pyproject.toml by gen_requirements_txt.py script.
 # DO NOT EDIT! Only for reference use.
-aiofiles<25,>=24.1
+anyio>=4.5.1,<5
 aiohttp>=3.10.11,<3.12
 cryptography>=43.0.1,<45
 grpcio>=1.53.2,<1.69
@@ -11,7 +11,7 @@ pydantic<3,>=2.10
 pydantic-settings<3,>=2.3
 pyyaml<7,>=6.0.1
 requests<2.33,>=2.32
-simple-sqlite3-orm<0.7,>=0.6
+simple-sqlite3-orm<0.8,>=0.7
 typing-extensions>=4.6.3
 urllib3<2.3,>=2.2.2
 uvicorn[standard]>=0.30,<0.35
diff --git a/src/ota_proxy/__init__.py b/src/ota_proxy/__init__.py
index 7425d2811..bcce9177a 100644
--- a/src/ota_proxy/__init__.py
+++ b/src/ota_proxy/__init__.py
@@ -33,7 +33,7 @@
 )
 
 
-async def run_otaproxy(
+def run_otaproxy(
     host: str,
     port: int,
     *,
@@ -45,6 +45,7 @@
     enable_https: bool,
     external_cache_mnt_point: str | None = None,
 ):
+    import anyio
     import uvicorn
 
     from . import App, OTACache
@@ -69,4 +70,4 @@ async def run_otaproxy(
         http="h11",
     )
     _server = uvicorn.Server(_config)
-    await _server.serve()
+    anyio.run(_server.serve, backend="asyncio", backend_options={"use_uvloop": True})
diff --git a/src/ota_proxy/__main__.py b/src/ota_proxy/__main__.py
index 2c7aad4ff..1814ddd8a 100644
--- a/src/ota_proxy/__main__.py
+++ b/src/ota_proxy/__main__.py
@@ -16,11 +16,8 @@
 from __future__ import annotations
 
 import argparse
-import asyncio
 import logging
 
-import uvloop
-
 from . import run_otaproxy
 from .config import config as cfg
 
@@ -78,17 +75,14 @@
 args = parser.parse_args()
 
 logger.info(f"launch ota_proxy at {args.host}:{args.port}")
-uvloop.install()
-asyncio.run(
-    run_otaproxy(
-        host=args.host,
-        port=args.port,
-        cache_dir=args.cache_dir,
-        cache_db_f=args.cache_db_file,
-        enable_cache=args.enable_cache,
-        upper_proxy=args.upper_proxy,
-        enable_https=args.enable_https,
-        init_cache=args.init_cache,
-        external_cache_mnt_point=args.external_cache_mnt_point,
-    )
+run_otaproxy(
+    host=args.host,
+    port=args.port,
+    cache_dir=args.cache_dir,
+    cache_db_f=args.cache_db_file,
+    enable_cache=args.enable_cache,
+    upper_proxy=args.upper_proxy,
+    enable_https=args.enable_https,
+    init_cache=args.init_cache,
+    external_cache_mnt_point=args.external_cache_mnt_point,
 )
diff --git a/src/ota_proxy/cache_streaming.py b/src/ota_proxy/cache_streaming.py
index 17e14b02f..5ffaf6016 100644
--- a/src/ota_proxy/cache_streaming.py
+++ b/src/ota_proxy/cache_streaming.py
@@ -22,11 +22,11 @@
 import os
 import threading
 import weakref
-from concurrent.futures import Executor
 from pathlib import Path
 from typing import AsyncGenerator, AsyncIterator, Callable, Coroutine
 
-import aiofiles
+import anyio
+from anyio import open_file
 
 from otaclient_common.common import get_backoff
 from otaclient_common.typing import StrOrPath
@@ -101,11 +101,10 @@ def __init__(
         *,
         base_dir: StrOrPath,
         commit_cache_cb: _CACHE_ENTRY_REGISTER_CALLBACK,
-        executor: Executor,
         below_hard_limit_event: threading.Event,
     ):
         self.fpath = Path(base_dir) / self._tmp_file_naming(cache_identifier)
-        self.save_path = Path(base_dir) / cache_identifier
+        self.save_path = anyio.Path(base_dir) / cache_identifier
         self.cache_meta: CacheMeta | None = None
 
         self._commit_cache_cb = commit_cache_cb
@@ -113,7 +112,6 @@ def __init__(
         self._writer_finished = asyncio.Event()
         self._writer_failed = asyncio.Event()
 
-        self._executor = executor
         self._space_availability_event = below_hard_limit_event
 
         self._bytes_written = 0
@@ -147,7 +145,7 @@ async def _provider_write_cache(
         """
         logger.debug(f"start to cache for {cache_meta=}...")
         try:
-            async with aiofiles.open(self.fpath, "wb", executor=self._executor) as f:
+            async with await open_file(self.fpath, "wb") as f:
                 _written = 0
                 while _data := (yield _written):
                     if not self._space_availability_event.is_set():
@@ -179,7 +177,7 @@ async def _provider_write_cache(
                 await self._commit_cache_cb(cache_meta)
                 # finalize the cache file, skip finalize if the target file is
                 # already presented.
-                if not self.save_path.is_file():
+                if not await self.save_path.is_file():
                     os.link(self.fpath, self.save_path)
         except Exception as e:
             logger.warning(f"failed to write cache for {cache_meta=}: {e!r}")
@@ -202,7 +200,7 @@ async def _subscriber_stream_cache(self) -> AsyncIterator[bytes]:
         """
         err_count, _bytes_read = 0, 0
         try:
-            async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f:
+            async with await open_file(self.fpath, "rb") as f:
                 while (
                     not self._writer_finished.is_set()
                     or _bytes_read < self._bytes_written
diff --git a/src/ota_proxy/ota_cache.py b/src/ota_proxy/ota_cache.py
index 3c8fbacf0..968bccfb2 100644
--- a/src/ota_proxy/ota_cache.py
+++ b/src/ota_proxy/ota_cache.py
@@ -20,12 +20,13 @@
 import shutil
 import threading
 import time
-from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import AsyncIterator, Mapping, Optional
 from urllib.parse import SplitResult, quote, urlsplit
 
 import aiohttp
+import anyio
+import anyio.to_thread
 from multidict import CIMultiDict, CIMultiDictProxy
 
 from otaclient_common.common import get_backoff
@@ -133,10 +134,6 @@ def __init__(
             db_f.unlink(missing_ok=True)
             self._init_cache = True  # force init cache on db file cleanup
 
-        self._executor = ThreadPoolExecutor(
-            thread_name_prefix="ota_cache_fileio_executor"
-        )
-
         self._external_cache_data_dir = None
         self._external_cache_mp = None
         if external_cache_mnt_point and mount_external_cache(external_cache_mnt_point):
@@ -145,7 +142,7 @@ def __init__(
             )
             self._external_cache_mp = external_cache_mnt_point
             self._external_cache_data_dir = (
-                Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
+                anyio.Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
             )
 
         self._storage_below_hard_limit_event = threading.Event()
@@ -189,11 +186,18 @@ async def start(self):
         # reuse the previously left ota_cache
         else:
             # cleanup unfinished tmp files
-            for tmp_f in self._base_dir.glob(f"{cfg.TMP_FILE_PREFIX}*"):
-                tmp_f.unlink(missing_ok=True)
+            async for tmp_f in anyio.Path(self._base_dir).glob(
+                f"{cfg.TMP_FILE_PREFIX}*"
+            ):
+                await tmp_f.unlink(missing_ok=True)
 
         # dispatch a background task to pulling the disk usage info
-        self._executor.submit(self._background_check_free_space)
+        _free_space_check_thread = threading.Thread(
+            target=self._background_check_free_space,
+            daemon=True,
+            name="ota_cache_free_space_checker",
+        )
+        _free_space_check_thread.start()
 
         # init cache helper(and connect to ota_cache db)
         self._lru_helper = LRUCacheHelper(
@@ -222,7 +226,6 @@ async def close(self):
         if not self._closed:
             self._closed = True
             await self._session.close()
-            self._executor.shutdown(wait=True)
 
             if self._cache_enabled:
                 self._lru_helper.close()
@@ -311,7 +314,7 @@ async def _reserve_space(self, size: int) -> bool:
             logger.debug(
                 f"rotate on bucket({size=}), num of entries to be cleaned {len(_hashes)=}"
             )
-            self._executor.submit(self._cache_entries_cleanup, _hashes)
+            await anyio.to_thread.run_sync(self._cache_entries_cleanup, _hashes)
             return True
         else:
             logger.debug(f"rotate on bucket({size=}) failed, no enough entries")
@@ -429,7 +432,7 @@ async def _retrieve_file_by_cache_lookup(
         # NOTE: db_entry.file_sha256 can be either
         #       1. valid sha256 value for corresponding plain uncompressed OTA file
         #       2. URL based sha256 value for corresponding requested URL
-        cache_file = self._base_dir / cache_identifier
+        cache_file = anyio.Path(self._base_dir / cache_identifier)
 
         # check if cache file exists
         # NOTE(20240729): there is an edge condition that the finished cached file is not yet renamed,
@@ -437,11 +440,11 @@
        #                  cache_commit_callback to rename the tmp file.
         _retry_count_max, _factor, _backoff_max = 6, 0.01, 0.1  # 0.255s in total
         for _retry_count in range(_retry_count_max):
-            if cache_file.is_file():
+            if await cache_file.is_file():
                 break
             await asyncio.sleep(get_backoff(_retry_count, _factor, _backoff_max))
 
-        if not cache_file.is_file():
+        if not await cache_file.is_file():
             logger.warning(
                 f"dangling cache entry found, remove db entry: {meta_db_entry}"
             )
@@ -452,7 +455,7 @@
        #       do the job. If cache is invalid, otaclient will use CacheControlHeader's retry_cache
        #       directory to indicate invalid cache.
         return (
-            read_file(cache_file, executor=self._executor),
+            read_file(cache_file),
             meta_db_entry.export_headers_to_client(),
         )
 
@@ -470,11 +473,11 @@
         cache_identifier = client_cache_policy.file_sha256
         cache_file = self._external_cache_data_dir / cache_identifier
-        cache_file_zst = cache_file.with_suffix(
-            f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}"
+        cache_file_zst = anyio.Path(
+            cache_file.with_suffix(f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}")
         )
 
-        if cache_file_zst.is_file():
+        if await cache_file_zst.is_file():
             _header = CIMultiDict()
             _header[HEADER_OTA_FILE_CACHE_CONTROL] = (
                 OTAFileCacheControl.export_kwargs_as_header(
                     file_sha256=cache_identifier,
                     file_compression_alg=cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG,
                 )
             )
-            return read_file(cache_file_zst, executor=self._executor), _header
+            return read_file(cache_file_zst), _header
 
-        if cache_file.is_file():
+        if await cache_file.is_file():
             _header = CIMultiDict()
             _header[HEADER_OTA_FILE_CACHE_CONTROL] = (
                 OTAFileCacheControl.export_kwargs_as_header(
                     file_sha256=cache_identifier
                 )
             )
-            return read_file(cache_file, executor=self._executor), _header
+            return read_file(cache_file), _header
 
     async def _retrieve_file_by_new_caching(
         self,
@@ -534,7 +537,6 @@
         tracker = CacheTracker(
             cache_identifier=cache_identifier,
             base_dir=self._base_dir,
-            executor=self._executor,
             commit_cache_cb=self._commit_cache_callback,
             below_hard_limit_event=self._storage_below_hard_limit_event,
         )
diff --git a/src/ota_proxy/utils.py b/src/ota_proxy/utils.py
index a852b1597..8b5e2d617 100644
--- a/src/ota_proxy/utils.py
+++ b/src/ota_proxy/utils.py
@@ -1,18 +1,17 @@
 from __future__ import annotations
 
-from concurrent.futures import Executor
 from hashlib import sha256
 from os import PathLike
 from typing import AsyncIterator
 
-import aiofiles
+from anyio import open_file
 
 from .config import config as cfg
 
 
-async def read_file(fpath: PathLike, *, executor: Executor) -> AsyncIterator[bytes]:
-    """Open and read a file asynchronously with aiofiles."""
-    async with aiofiles.open(fpath, "rb", executor=executor) as f:
+async def read_file(fpath: PathLike) -> AsyncIterator[bytes]:
+    """Open and read a file asynchronously."""
+    async with await open_file(fpath, "rb") as f:
         while data := await f.read(cfg.CHUNK_SIZE):
             yield data
diff --git a/src/otaclient/_otaproxy_ctx.py b/src/otaclient/_otaproxy_ctx.py
index e08a50638..b3e495429 100644
--- a/src/otaclient/_otaproxy_ctx.py
+++ b/src/otaclient/_otaproxy_ctx.py
@@ -19,7 +19,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import atexit
 import logging
 import multiprocessing as mp
@@ -78,18 +77,16 @@ def otaproxy_process(*, init_cache: bool) -> None:
     logger.info(f"wait for {upper_proxy=} online...")
     ensure_otaproxy_start(str(upper_proxy))
 
-    asyncio.run(
-        run_otaproxy(
-            host=host,
-            port=port,
-            init_cache=init_cache,
-            cache_dir=local_otaproxy_cfg.BASE_DIR,
-            cache_db_f=local_otaproxy_cfg.DB_FILE,
-            upper_proxy=upper_proxy,
-            enable_cache=proxy_info.enable_local_ota_proxy_cache,
-            enable_https=proxy_info.gateway_otaproxy,
-            external_cache_mnt_point=external_cache_mnt_point,
-        )
+    run_otaproxy(
+        host=host,
+        port=port,
+        init_cache=init_cache,
+        cache_dir=local_otaproxy_cfg.BASE_DIR,
+        cache_db_f=local_otaproxy_cfg.DB_FILE,
+        upper_proxy=upper_proxy,
+        enable_cache=proxy_info.enable_local_ota_proxy_cache,
+        enable_https=proxy_info.gateway_otaproxy,
+        external_cache_mnt_point=external_cache_mnt_point,
     )
diff --git a/tests/test_ota_proxy/test_cache_streaming.py b/tests/test_ota_proxy/test_cache_streaming.py
index 8efdbcd1c..8701ad2bc 100644
--- a/tests/test_ota_proxy/test_cache_streaming.py
+++ b/tests/test_ota_proxy/test_cache_streaming.py
@@ -87,7 +87,6 @@ async def _worker(
         _tracker = CacheTracker(
             cache_identifier=self.URL,
             base_dir=self.base_dir,
-            executor=None,  # type: ignore
             commit_cache_cb=None,  # type: ignore
             below_hard_limit_event=None,  # type: ignore
         )
diff --git a/tests/test_ota_proxy/test_subprocess_launch_otaproxy.py b/tests/test_ota_proxy/test_subprocess_launch_otaproxy.py
index 817e91d86..17b3611b0 100644
--- a/tests/test_ota_proxy/test_subprocess_launch_otaproxy.py
+++ b/tests/test_ota_proxy/test_subprocess_launch_otaproxy.py
@@ -15,7 +15,6 @@
 
 from __future__ import annotations
 
-import asyncio
 import multiprocessing as mp
 import time
 from pathlib import Path
@@ -27,17 +26,15 @@ def otaproxy_process(cache_dir: str):
     ota_cache_dir = Path(cache_dir)
     ota_cache_db = ota_cache_dir / "cache_db"
 
-    asyncio.run(
-        run_otaproxy(
-            host="127.0.0.1",
-            port=8082,
-            init_cache=True,
-            cache_dir=str(ota_cache_dir),
-            cache_db_f=str(ota_cache_db),
-            upper_proxy="",
-            enable_cache=True,
-            enable_https=False,
-        ),
+    run_otaproxy(
+        host="127.0.0.1",
+        port=8082,
+        init_cache=True,
+        cache_dir=str(ota_cache_dir),
+        cache_db_f=str(ota_cache_db),
+        upper_proxy="",
+        enable_cache=True,
+        enable_https=False,
     )
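
Note (reviewer addition, not part of the diff): a minimal sketch of the anyio primitives this change switches to, assuming anyio>=4 with uvloop available. Everything below is illustrative only; names such as demo_read and the chunk size are hypothetical and do not exist in the project.

    # Illustrative sketch only (not project code): the anyio patterns adopted above.
    import hashlib

    import anyio
    import anyio.to_thread


    async def demo_read(path: str) -> int:
        p = anyio.Path(path)  # pathlib-like API with awaitable methods
        if not await p.is_file():
            return 0
        total = 0
        # anyio.open_file replaces aiofiles.open; note the extra `await`.
        async with await anyio.open_file(p, "rb") as f:
            while chunk := await f.read(1024 * 1024):
                # offload CPU-bound work to a worker thread (replaces Executor.submit)
                await anyio.to_thread.run_sync(hashlib.sha256, chunk)
                total += len(chunk)
        return total


    if __name__ == "__main__":
        # anyio.run replaces uvloop.install() + asyncio.run()
        anyio.run(
            demo_read, __file__, backend="asyncio", backend_options={"use_uvloop": True}
        )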