Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ota_proxy: migrate to anyio, drop deps of aiofiles #467

Merged
merged 13 commits into from
Dec 23, 2024
Merged
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dynamic = [
"version",
]
dependencies = [
"aiofiles<25,>=24.1",
"anyio>=4.5.1,<5",
"aiohttp>=3.10.11,<3.12",
"cryptography>=43.0.1,<45",
"grpcio>=1.53.2,<1.69",
Expand All @@ -35,7 +35,7 @@ dependencies = [
"pydantic-settings<3,>=2.3",
"pyyaml<7,>=6.0.1",
"requests<2.33,>=2.32",
"simple-sqlite3-orm<0.7,>=0.6",
"simple-sqlite3-orm<0.8,>=0.7",
"typing-extensions>=4.6.3",
"urllib3<2.3,>=2.2.2",
"uvicorn[standard]>=0.30,<0.35",
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Automatically generated from pyproject.toml by gen_requirements_txt.py script.
# DO NOT EDIT! Only for reference use.
aiofiles<25,>=24.1
anyio>=4.5.1,<5
aiohttp>=3.10.11,<3.12
cryptography>=43.0.1,<45
grpcio>=1.53.2,<1.69
Expand All @@ -11,7 +11,7 @@ pydantic<3,>=2.10
pydantic-settings<3,>=2.3
pyyaml<7,>=6.0.1
requests<2.33,>=2.32
simple-sqlite3-orm<0.7,>=0.6
simple-sqlite3-orm<0.8,>=0.7
typing-extensions>=4.6.3
urllib3<2.3,>=2.2.2
uvicorn[standard]>=0.30,<0.35
Expand Down
5 changes: 3 additions & 2 deletions src/ota_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)


async def run_otaproxy(
def run_otaproxy(
host: str,
port: int,
*,
Expand All @@ -45,6 +45,7 @@ async def run_otaproxy(
enable_https: bool,
external_cache_mnt_point: str | None = None,
):
import anyio
import uvicorn

from . import App, OTACache
Expand All @@ -69,4 +70,4 @@ async def run_otaproxy(
http="h11",
)
_server = uvicorn.Server(_config)
await _server.serve()
anyio.run(_server.serve, backend="asyncio", backend_options={"use_uvloop": True})
26 changes: 10 additions & 16 deletions src/ota_proxy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
from __future__ import annotations

import argparse
import asyncio
import logging

import uvloop

from . import run_otaproxy
from .config import config as cfg

Expand Down Expand Up @@ -78,17 +75,14 @@
args = parser.parse_args()

logger.info(f"launch ota_proxy at {args.host}:{args.port}")
uvloop.install()
asyncio.run(
run_otaproxy(
host=args.host,
port=args.port,
cache_dir=args.cache_dir,
cache_db_f=args.cache_db_file,
enable_cache=args.enable_cache,
upper_proxy=args.upper_proxy,
enable_https=args.enable_https,
init_cache=args.init_cache,
external_cache_mnt_point=args.external_cache_mnt_point,
)
run_otaproxy(
host=args.host,
port=args.port,
cache_dir=args.cache_dir,
cache_db_f=args.cache_db_file,
enable_cache=args.enable_cache,
upper_proxy=args.upper_proxy,
enable_https=args.enable_https,
init_cache=args.init_cache,
external_cache_mnt_point=args.external_cache_mnt_point,
)
14 changes: 6 additions & 8 deletions src/ota_proxy/cache_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import os
import threading
import weakref
from concurrent.futures import Executor
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Callable, Coroutine

import aiofiles
import anyio
from anyio import open_file

from otaclient_common.common import get_backoff
from otaclient_common.typing import StrOrPath
Expand Down Expand Up @@ -101,19 +101,17 @@ def __init__(
*,
base_dir: StrOrPath,
commit_cache_cb: _CACHE_ENTRY_REGISTER_CALLBACK,
executor: Executor,
below_hard_limit_event: threading.Event,
):
self.fpath = Path(base_dir) / self._tmp_file_naming(cache_identifier)
self.save_path = Path(base_dir) / cache_identifier
self.save_path = anyio.Path(base_dir) / cache_identifier
self.cache_meta: CacheMeta | None = None
self._commit_cache_cb = commit_cache_cb

self._writer_ready = asyncio.Event()
self._writer_finished = asyncio.Event()
self._writer_failed = asyncio.Event()

self._executor = executor
self._space_availability_event = below_hard_limit_event

self._bytes_written = 0
Expand Down Expand Up @@ -147,7 +145,7 @@ async def _provider_write_cache(
"""
logger.debug(f"start to cache for {cache_meta=}...")
try:
async with aiofiles.open(self.fpath, "wb", executor=self._executor) as f:
async with await open_file(self.fpath, "wb") as f:
_written = 0
while _data := (yield _written):
if not self._space_availability_event.is_set():
Expand Down Expand Up @@ -179,7 +177,7 @@ async def _provider_write_cache(
await self._commit_cache_cb(cache_meta)
# finalize the cache file, skip finalize if the target file is
# already presented.
if not self.save_path.is_file():
if not await self.save_path.is_file():
os.link(self.fpath, self.save_path)
except Exception as e:
logger.warning(f"failed to write cache for {cache_meta=}: {e!r}")
Expand All @@ -202,7 +200,7 @@ async def _subscriber_stream_cache(self) -> AsyncIterator[bytes]:
"""
err_count, _bytes_read = 0, 0
try:
async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f:
async with await open_file(self.fpath, "rb") as f:
while (
not self._writer_finished.is_set()
or _bytes_read < self._bytes_written
Expand Down
46 changes: 24 additions & 22 deletions src/ota_proxy/ota_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import AsyncIterator, Mapping, Optional
from urllib.parse import SplitResult, quote, urlsplit

import aiohttp
import anyio
import anyio.to_thread
from multidict import CIMultiDict, CIMultiDictProxy

from otaclient_common.common import get_backoff
Expand Down Expand Up @@ -133,10 +134,6 @@ def __init__(
db_f.unlink(missing_ok=True)
self._init_cache = True # force init cache on db file cleanup

self._executor = ThreadPoolExecutor(
thread_name_prefix="ota_cache_fileio_executor"
)

self._external_cache_data_dir = None
self._external_cache_mp = None
if external_cache_mnt_point and mount_external_cache(external_cache_mnt_point):
Expand All @@ -145,7 +142,7 @@ def __init__(
)
self._external_cache_mp = external_cache_mnt_point
self._external_cache_data_dir = (
Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
anyio.Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
)

self._storage_below_hard_limit_event = threading.Event()
Expand Down Expand Up @@ -189,11 +186,18 @@ async def start(self):

# reuse the previously left ota_cache
else: # cleanup unfinished tmp files
for tmp_f in self._base_dir.glob(f"{cfg.TMP_FILE_PREFIX}*"):
tmp_f.unlink(missing_ok=True)
async for tmp_f in anyio.Path(self._base_dir).glob(
f"{cfg.TMP_FILE_PREFIX}*"
):
await tmp_f.unlink(missing_ok=True)

# dispatch a background task to pulling the disk usage info
self._executor.submit(self._background_check_free_space)
_free_space_check_thread = threading.Thread(
target=self._background_check_free_space,
daemon=True,
name="ota_cache_free_space_checker",
)
_free_space_check_thread.start()

# init cache helper(and connect to ota_cache db)
self._lru_helper = LRUCacheHelper(
Expand Down Expand Up @@ -222,7 +226,6 @@ async def close(self):
if not self._closed:
self._closed = True
await self._session.close()
self._executor.shutdown(wait=True)

if self._cache_enabled:
self._lru_helper.close()
Expand Down Expand Up @@ -311,7 +314,7 @@ async def _reserve_space(self, size: int) -> bool:
logger.debug(
f"rotate on bucket({size=}), num of entries to be cleaned {len(_hashes)=}"
)
self._executor.submit(self._cache_entries_cleanup, _hashes)
await anyio.to_thread.run_sync(self._cache_entries_cleanup, _hashes)
return True
else:
logger.debug(f"rotate on bucket({size=}) failed, no enough entries")
Expand Down Expand Up @@ -429,19 +432,19 @@ async def _retrieve_file_by_cache_lookup(
# NOTE: db_entry.file_sha256 can be either
# 1. valid sha256 value for corresponding plain uncompressed OTA file
# 2. URL based sha256 value for corresponding requested URL
cache_file = self._base_dir / cache_identifier
cache_file = anyio.Path(self._base_dir / cache_identifier)

# check if cache file exists
# NOTE(20240729): there is an edge condition that the finished cached file is not yet renamed,
# but the database entry has already been inserted. Here we wait for 3 rounds for
# cache_commit_callback to rename the tmp file.
_retry_count_max, _factor, _backoff_max = 6, 0.01, 0.1 # 0.255s in total
for _retry_count in range(_retry_count_max):
if cache_file.is_file():
if await cache_file.is_file():
break
await asyncio.sleep(get_backoff(_retry_count, _factor, _backoff_max))

if not cache_file.is_file():
if not await cache_file.is_file():
logger.warning(
f"dangling cache entry found, remove db entry: {meta_db_entry}"
)
Expand All @@ -452,7 +455,7 @@ async def _retrieve_file_by_cache_lookup(
# do the job. If cache is invalid, otaclient will use CacheControlHeader's retry_cache
# directory to indicate invalid cache.
return (
read_file(cache_file, executor=self._executor),
read_file(cache_file),
meta_db_entry.export_headers_to_client(),
)

Expand All @@ -470,28 +473,28 @@ async def _retrieve_file_by_external_cache(

cache_identifier = client_cache_policy.file_sha256
cache_file = self._external_cache_data_dir / cache_identifier
cache_file_zst = cache_file.with_suffix(
f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}"
cache_file_zst = anyio.Path(
cache_file.with_suffix(f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}")
)

if cache_file_zst.is_file():
if await cache_file_zst.is_file():
_header = CIMultiDict()
_header[HEADER_OTA_FILE_CACHE_CONTROL] = (
OTAFileCacheControl.export_kwargs_as_header(
file_sha256=cache_identifier,
file_compression_alg=cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG,
)
)
return read_file(cache_file_zst, executor=self._executor), _header
return read_file(cache_file_zst), _header

if cache_file.is_file():
if await cache_file.is_file():
_header = CIMultiDict()
_header[HEADER_OTA_FILE_CACHE_CONTROL] = (
OTAFileCacheControl.export_kwargs_as_header(
file_sha256=cache_identifier
)
)
return read_file(cache_file, executor=self._executor), _header
return read_file(cache_file), _header

async def _retrieve_file_by_new_caching(
self,
Expand Down Expand Up @@ -534,7 +537,6 @@ async def _retrieve_file_by_new_caching(
tracker = CacheTracker(
cache_identifier=cache_identifier,
base_dir=self._base_dir,
executor=self._executor,
commit_cache_cb=self._commit_cache_callback,
below_hard_limit_event=self._storage_below_hard_limit_event,
)
Expand Down
9 changes: 4 additions & 5 deletions src/ota_proxy/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from __future__ import annotations

from concurrent.futures import Executor
from hashlib import sha256
from os import PathLike
from typing import AsyncIterator

import aiofiles
from anyio import open_file

from .config import config as cfg


async def read_file(fpath: PathLike, *, executor: Executor) -> AsyncIterator[bytes]:
"""Open and read a file asynchronously with aiofiles."""
async with aiofiles.open(fpath, "rb", executor=executor) as f:
async def read_file(fpath: PathLike) -> AsyncIterator[bytes]:
    """Stream the contents of *fpath* as an async iterator of byte chunks.

    The file is opened with anyio's async file support and read in
    ``cfg.CHUNK_SIZE``-sized pieces, yielding each non-empty chunk until
    EOF. The file handle is closed automatically when iteration ends.
    """
    # anyio.open_file is a coroutine returning an async context manager,
    # hence the `await` before `async with`.
    async with await open_file(fpath, "rb") as stream:
        chunk = await stream.read(cfg.CHUNK_SIZE)
        while chunk:
            yield chunk
            chunk = await stream.read(cfg.CHUNK_SIZE)

Expand Down
23 changes: 10 additions & 13 deletions src/otaclient/_otaproxy_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from __future__ import annotations

import asyncio
import atexit
import logging
import multiprocessing as mp
Expand Down Expand Up @@ -78,18 +77,16 @@ def otaproxy_process(*, init_cache: bool) -> None:
logger.info(f"wait for {upper_proxy=} online...")
ensure_otaproxy_start(str(upper_proxy))

asyncio.run(
run_otaproxy(
host=host,
port=port,
init_cache=init_cache,
cache_dir=local_otaproxy_cfg.BASE_DIR,
cache_db_f=local_otaproxy_cfg.DB_FILE,
upper_proxy=upper_proxy,
enable_cache=proxy_info.enable_local_ota_proxy_cache,
enable_https=proxy_info.gateway_otaproxy,
external_cache_mnt_point=external_cache_mnt_point,
)
run_otaproxy(
host=host,
port=port,
init_cache=init_cache,
cache_dir=local_otaproxy_cfg.BASE_DIR,
cache_db_f=local_otaproxy_cfg.DB_FILE,
upper_proxy=upper_proxy,
enable_cache=proxy_info.enable_local_ota_proxy_cache,
enable_https=proxy_info.gateway_otaproxy,
external_cache_mnt_point=external_cache_mnt_point,
)


Expand Down
1 change: 0 additions & 1 deletion tests/test_ota_proxy/test_cache_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ async def _worker(
_tracker = CacheTracker(
cache_identifier=self.URL,
base_dir=self.base_dir,
executor=None, # type: ignore
commit_cache_cb=None, # type: ignore
below_hard_limit_event=None, # type: ignore
)
Expand Down
Loading
Loading