Merge branch 'main' into feat/csv_parser
Bodong-Yang authored Dec 25, 2024
2 parents 3fdec17 + fefed56 commit aea9581
Showing 31 changed files with 211 additions and 203 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -24,7 +24,7 @@ dynamic = [
"version",
]
dependencies = [
"aiofiles<25,>=24.1",
"anyio>=4.5.1,<5",
"aiohttp>=3.10.11,<3.12",
"cryptography>=43.0.1,<45",
"grpcio>=1.53.2,<1.69",
@@ -35,9 +35,9 @@ dependencies = [
"pydantic-settings<3,>=2.3",
"pyyaml<7,>=6.0.1",
"requests<2.33,>=2.32",
"simple-sqlite3-orm<0.7,>=0.6",
"simple-sqlite3-orm<0.8,>=0.7",
"typing-extensions>=4.6.3",
"urllib3<2.3,>=2.2.2",
"urllib3>=2.2.2,<2.4",
"uvicorn[standard]>=0.30,<0.35",
"zstandard<0.24,>=0.22",
]
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,6 +1,6 @@
# Automatically generated from pyproject.toml by gen_requirements_txt.py script.
# DO NOT EDIT! Only for reference use.
- aiofiles<25,>=24.1
+ anyio>=4.5.1,<5
aiohttp>=3.10.11,<3.12
cryptography>=43.0.1,<45
grpcio>=1.53.2,<1.69
@@ -11,8 +11,8 @@ pydantic<3,>=2.10
pydantic-settings<3,>=2.3
pyyaml<7,>=6.0.1
requests<2.33,>=2.32
- simple-sqlite3-orm<0.7,>=0.6
+ simple-sqlite3-orm<0.8,>=0.7
typing-extensions>=4.6.3
- urllib3<2.3,>=2.2.2
+ urllib3>=2.2.2,<2.4
uvicorn[standard]>=0.30,<0.35
zstandard<0.24,>=0.22
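
Note: the aiofiles → anyio swap above drives most of the code changes in this commit: async file I/O that previously went through aiofiles (and an explicit thread-pool executor) now uses anyio's file API. A minimal sketch of the anyio side, using a hypothetical /tmp path rather than anything from this repo:

```python
import anyio


async def main() -> None:
    # anyio.open_file returns an awaitable that resolves to an async
    # file object, hence the `await` before the async context manager.
    async with await anyio.open_file("/tmp/example.bin", "wb") as f:
        await f.write(b"hello")

    # anyio.Path mirrors pathlib.Path, but its querying methods
    # (is_file, unlink, glob, ...) are coroutines.
    path = anyio.Path("/tmp/example.bin")
    print(await path.is_file())


anyio.run(main)
```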
5 changes: 3 additions & 2 deletions src/ota_proxy/__init__.py
@@ -33,7 +33,7 @@
)


- async def run_otaproxy(
+ def run_otaproxy(
host: str,
port: int,
*,
@@ -45,6 +45,7 @@ async def run_otaproxy(
enable_https: bool,
external_cache_mnt_point: str | None = None,
):
+ import anyio
import uvicorn

from . import App, OTACache
@@ -69,4 +70,4 @@ async def run_otaproxy(
http="h11",
)
_server = uvicorn.Server(_config)
- await _server.serve()
+ anyio.run(_server.serve, backend="asyncio", backend_options={"use_uvloop": True})
26 changes: 10 additions & 16 deletions src/ota_proxy/__main__.py
@@ -16,11 +16,8 @@
from __future__ import annotations

import argparse
- import asyncio
import logging

- import uvloop

from . import run_otaproxy
from .config import config as cfg

@@ -78,17 +75,14 @@
args = parser.parse_args()

logger.info(f"launch ota_proxy at {args.host}:{args.port}")
- uvloop.install()
- asyncio.run(
-     run_otaproxy(
-         host=args.host,
-         port=args.port,
-         cache_dir=args.cache_dir,
-         cache_db_f=args.cache_db_file,
-         enable_cache=args.enable_cache,
-         upper_proxy=args.upper_proxy,
-         enable_https=args.enable_https,
-         init_cache=args.init_cache,
-         external_cache_mnt_point=args.external_cache_mnt_point,
-     )
+ run_otaproxy(
+     host=args.host,
+     port=args.port,
+     cache_dir=args.cache_dir,
+     cache_db_f=args.cache_db_file,
+     enable_cache=args.enable_cache,
+     upper_proxy=args.upper_proxy,
+     enable_https=args.enable_https,
+     init_cache=args.init_cache,
+     external_cache_mnt_point=args.external_cache_mnt_point,
)
14 changes: 6 additions & 8 deletions src/ota_proxy/cache_streaming.py
@@ -22,11 +22,11 @@
import os
import threading
import weakref
- from concurrent.futures import Executor
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Callable, Coroutine

- import aiofiles
+ import anyio
+ from anyio import open_file

from otaclient_common.common import get_backoff
from otaclient_common.typing import StrOrPath
@@ -101,19 +101,17 @@ def __init__(
*,
base_dir: StrOrPath,
commit_cache_cb: _CACHE_ENTRY_REGISTER_CALLBACK,
- executor: Executor,
below_hard_limit_event: threading.Event,
):
self.fpath = Path(base_dir) / self._tmp_file_naming(cache_identifier)
- self.save_path = Path(base_dir) / cache_identifier
+ self.save_path = anyio.Path(base_dir) / cache_identifier
self.cache_meta: CacheMeta | None = None
self._commit_cache_cb = commit_cache_cb

self._writer_ready = asyncio.Event()
self._writer_finished = asyncio.Event()
self._writer_failed = asyncio.Event()

- self._executor = executor
self._space_availability_event = below_hard_limit_event

self._bytes_written = 0
@@ -147,7 +145,7 @@ async def _provider_write_cache(
"""
logger.debug(f"start to cache for {cache_meta=}...")
try:
- async with aiofiles.open(self.fpath, "wb", executor=self._executor) as f:
+ async with await open_file(self.fpath, "wb") as f:
_written = 0
while _data := (yield _written):
if not self._space_availability_event.is_set():
@@ -179,7 +177,7 @@ async def _provider_write_cache(
await self._commit_cache_cb(cache_meta)
# finalize the cache file, skip finalize if the target file is
# already presented.
- if not self.save_path.is_file():
+ if not await self.save_path.is_file():
os.link(self.fpath, self.save_path)
except Exception as e:
logger.warning(f"failed to write cache for {cache_meta=}: {e!r}")
@@ -202,7 +200,7 @@ async def _subscriber_stream_cache(self) -> AsyncIterator[bytes]:
"""
err_count, _bytes_read = 0, 0
try:
- async with aiofiles.open(self.fpath, "rb", executor=self._executor) as f:
+ async with await open_file(self.fpath, "rb") as f:
while (
not self._writer_finished.is_set()
or _bytes_read < self._bytes_written
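
Note: CacheTracker now writes the temp cache file through anyio and finalizes it with a hard link, first checking save_path (now an anyio.Path) without blocking the event loop. A condensed sketch of that finalize step, with hypothetical paths:

```python
import os

import anyio


async def finalize_cache(tmp_fpath: str, save_path: anyio.Path) -> None:
    # Skip finalizing if another writer already published this entry.
    if not await save_path.is_file():
        # A hard link is atomic on the same filesystem, so subscribers
        # never observe a half-written final file.
        os.link(tmp_fpath, save_path)


async def main() -> None:
    async with await anyio.open_file("/tmp/entry.tmp", "wb") as f:
        await f.write(b"cached bytes")
    await finalize_cache("/tmp/entry.tmp", anyio.Path("/tmp/entry"))


anyio.run(main)
```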
14 changes: 2 additions & 12 deletions src/ota_proxy/lru_cache_helper.py
@@ -17,27 +17,17 @@
from __future__ import annotations

import bisect
- import logging
import sqlite3
import time
from pathlib import Path

from simple_sqlite3_orm import utils

- from otaclient_common.logging import BurstSuppressFilter
+ from otaclient_common.logging import get_burst_suppressed_logger

from .db import AsyncCacheMetaORM, CacheMeta

- burst_suppressed_logger = logging.getLogger(f"{__name__}.db_error")
- # NOTE: for request_error, only allow max 6 lines of logging per 30 seconds
- burst_suppressed_logger.addFilter(
-     BurstSuppressFilter(
-         f"{__name__}.db_error",
-         upper_logger_name=__name__,
-         burst_round_length=30,
-         burst_max=6,
-     )
- )
+ burst_suppressed_logger = get_burst_suppressed_logger(f"{__name__}.db_error")


class LRUCacheHelper:
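
Note: this module and server_app.py below previously assembled a rate-limited logger by hand; both now call otaclient_common.logging.get_burst_suppressed_logger. The helper's implementation is not part of this diff; a plausible shape, inferred from the removed boilerplate (the real code in otaclient_common may differ):

```python
import logging

from otaclient_common.logging import BurstSuppressFilter


def get_burst_suppressed_logger(
    logger_name: str,
    *,
    burst_round_length: int = 30,
    burst_max: int = 6,
) -> logging.Logger:
    """Return a logger that emits at most ``burst_max`` records per
    ``burst_round_length`` seconds (behavior inferred from the
    removed call sites in this diff)."""
    upper_logger_name = logger_name.rsplit(".", 1)[0]
    suppressed_logger = logging.getLogger(logger_name)
    suppressed_logger.addFilter(
        BurstSuppressFilter(
            logger_name,
            upper_logger_name=upper_logger_name,
            burst_round_length=burst_round_length,
            burst_max=burst_max,
        )
    )
    return suppressed_logger
```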
46 changes: 24 additions & 22 deletions src/ota_proxy/ota_cache.py
@@ -20,12 +20,13 @@
import shutil
import threading
import time
- from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import AsyncIterator, Mapping, Optional
from urllib.parse import SplitResult, quote, urlsplit

import aiohttp
+ import anyio
+ import anyio.to_thread
from multidict import CIMultiDict, CIMultiDictProxy

from otaclient_common.common import get_backoff
@@ -133,10 +134,6 @@ def __init__(
db_f.unlink(missing_ok=True)
self._init_cache = True # force init cache on db file cleanup

- self._executor = ThreadPoolExecutor(
-     thread_name_prefix="ota_cache_fileio_executor"
- )

self._external_cache_data_dir = None
self._external_cache_mp = None
if external_cache_mnt_point and mount_external_cache(external_cache_mnt_point):
@@ -145,7 +142,7 @@ def __init__(
)
self._external_cache_mp = external_cache_mnt_point
self._external_cache_data_dir = (
- Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
+ anyio.Path(external_cache_mnt_point) / cfg.EXTERNAL_CACHE_DATA_DNAME
)

self._storage_below_hard_limit_event = threading.Event()
@@ -189,11 +186,18 @@ async def start(self):

# reuse the previously left ota_cache
else: # cleanup unfinished tmp files
- for tmp_f in self._base_dir.glob(f"{cfg.TMP_FILE_PREFIX}*"):
-     tmp_f.unlink(missing_ok=True)
+ async for tmp_f in anyio.Path(self._base_dir).glob(
+     f"{cfg.TMP_FILE_PREFIX}*"
+ ):
+     await tmp_f.unlink(missing_ok=True)

# dispatch a background task to pulling the disk usage info
- self._executor.submit(self._background_check_free_space)
+ _free_space_check_thread = threading.Thread(
+     target=self._background_check_free_space,
+     daemon=True,
+     name="ota_cache_free_space_checker",
+ )
+ _free_space_check_thread.start()

# init cache helper(and connect to ota_cache db)
self._lru_helper = LRUCacheHelper(
@@ -222,7 +226,6 @@ async def close(self):
if not self._closed:
self._closed = True
await self._session.close()
- self._executor.shutdown(wait=True)

if self._cache_enabled:
self._lru_helper.close()
@@ -311,7 +314,7 @@ async def _reserve_space(self, size: int) -> bool:
logger.debug(
f"rotate on bucket({size=}), num of entries to be cleaned {len(_hashes)=}"
)
- self._executor.submit(self._cache_entries_cleanup, _hashes)
+ await anyio.to_thread.run_sync(self._cache_entries_cleanup, _hashes)
return True
else:
logger.debug(f"rotate on bucket({size=}) failed, no enough entries")
@@ -429,19 +432,19 @@ async def _retrieve_file_by_cache_lookup(
# NOTE: db_entry.file_sha256 can be either
# 1. valid sha256 value for corresponding plain uncompressed OTA file
# 2. URL based sha256 value for corresponding requested URL
- cache_file = self._base_dir / cache_identifier
+ cache_file = anyio.Path(self._base_dir / cache_identifier)

# check if cache file exists
# NOTE(20240729): there is an edge condition that the finished cached file is not yet renamed,
# but the database entry has already been inserted. Here we wait for 3 rounds for
# cache_commit_callback to rename the tmp file.
_retry_count_max, _factor, _backoff_max = 6, 0.01, 0.1 # 0.255s in total
for _retry_count in range(_retry_count_max):
- if cache_file.is_file():
+ if await cache_file.is_file():
break
await asyncio.sleep(get_backoff(_retry_count, _factor, _backoff_max))

- if not cache_file.is_file():
+ if not await cache_file.is_file():
logger.warning(
f"dangling cache entry found, remove db entry: {meta_db_entry}"
)
@@ -452,7 +455,7 @@ async def _retrieve_file_by_cache_lookup(
# do the job. If cache is invalid, otaclient will use CacheControlHeader's retry_cache
# directory to indicate invalid cache.
return (
- read_file(cache_file, executor=self._executor),
+ read_file(cache_file),
meta_db_entry.export_headers_to_client(),
)

Expand All @@ -470,28 +473,28 @@ async def _retrieve_file_by_external_cache(

cache_identifier = client_cache_policy.file_sha256
cache_file = self._external_cache_data_dir / cache_identifier
- cache_file_zst = cache_file.with_suffix(
-     f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}"
+ cache_file_zst = anyio.Path(
+     cache_file.with_suffix(f".{cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG}")
)

- if cache_file_zst.is_file():
+ if await cache_file_zst.is_file():
_header = CIMultiDict()
_header[HEADER_OTA_FILE_CACHE_CONTROL] = (
OTAFileCacheControl.export_kwargs_as_header(
file_sha256=cache_identifier,
file_compression_alg=cfg.EXTERNAL_CACHE_STORAGE_COMPRESS_ALG,
)
)
- return read_file(cache_file_zst, executor=self._executor), _header
+ return read_file(cache_file_zst), _header

- if cache_file.is_file():
+ if await cache_file.is_file():
_header = CIMultiDict()
_header[HEADER_OTA_FILE_CACHE_CONTROL] = (
OTAFileCacheControl.export_kwargs_as_header(
file_sha256=cache_identifier
)
)
- return read_file(cache_file, executor=self._executor), _header
+ return read_file(cache_file), _header

async def _retrieve_file_by_new_caching(
self,
Expand Down Expand Up @@ -534,7 +537,6 @@ async def _retrieve_file_by_new_caching(
tracker = CacheTracker(
cache_identifier=cache_identifier,
base_dir=self._base_dir,
- executor=self._executor,
commit_cache_cb=self._commit_cache_callback,
below_hard_limit_event=self._storage_below_hard_limit_event,
)
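
Note: with the shared ThreadPoolExecutor removed, ota_cache.py runs the long-lived free-space poll on a dedicated daemon thread and dispatches one-off blocking cleanup through anyio.to_thread.run_sync. Unlike executor.submit, run_sync suspends the awaiting task until the callable returns, so _reserve_space now waits for cleanup to complete. A minimal sketch with a stand-in blocking function:

```python
import time

import anyio
import anyio.to_thread


def cleanup_entries(hashes: list[str]) -> int:
    # Stand-in for the blocking cache-entry cleanup.
    time.sleep(0.1)
    return len(hashes)


async def main() -> None:
    # Runs on a worker thread; only this task is suspended while the
    # event loop keeps serving other tasks.
    removed = await anyio.to_thread.run_sync(cleanup_entries, ["a", "b"])
    print(f"removed {removed} entries")


anyio.run(main)
```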
12 changes: 2 additions & 10 deletions src/ota_proxy/server_app.py
@@ -23,7 +23,7 @@
import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy

- from otaclient_common.logging import BurstSuppressFilter
+ from otaclient_common.logging import get_burst_suppressed_logger

from ._consts import (
BHEADER_AUTHORIZATION,
@@ -46,16 +46,8 @@
from .ota_cache import OTACache

logger = logging.getLogger(__name__)
- burst_suppressed_logger = logging.getLogger(f"{__name__}.request_error")
- # NOTE: for request_error, only allow max 6 lines of logging per 30 seconds
- burst_suppressed_logger.addFilter(
-     BurstSuppressFilter(
-         f"{__name__}.request_error",
-         upper_logger_name=__name__,
-         burst_round_length=30,
-         burst_max=6,
-     )
- )
+ burst_suppressed_logger = get_burst_suppressed_logger(f"{__name__}.request_error")

# only expose app
__all__ = ("App",)
9 changes: 4 additions & 5 deletions src/ota_proxy/utils.py
@@ -1,18 +1,17 @@
from __future__ import annotations

- from concurrent.futures import Executor
from hashlib import sha256
from os import PathLike
from typing import AsyncIterator

- import aiofiles
+ from anyio import open_file

from .config import config as cfg


- async def read_file(fpath: PathLike, *, executor: Executor) -> AsyncIterator[bytes]:
-     """Open and read a file asynchronously with aiofiles."""
-     async with aiofiles.open(fpath, "rb", executor=executor) as f:
+ async def read_file(fpath: PathLike) -> AsyncIterator[bytes]:
+     """Open and read a file asynchronously."""
+     async with await open_file(fpath, "rb") as f:
while data := await f.read(cfg.CHUNK_SIZE):
yield data

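
Note: read_file is an async generator, so callers iterate it with `async for` rather than awaiting it; the executor parameter could be dropped because anyio handles the thread offloading internally. A short consumer sketch (the import path assumes this repo's src layout; the file path is hypothetical):

```python
import anyio

from ota_proxy.utils import read_file


async def main() -> None:
    total = 0
    # Each iteration yields at most cfg.CHUNK_SIZE bytes.
    async for chunk in read_file("/tmp/entry"):
        total += len(chunk)
    print(f"streamed {total} bytes")


anyio.run(main)
```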
(The remaining changed files in this commit are not loaded in this view.)
