diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml
index 42ca06b..d835ddf 100644
--- a/.github/workflows/build_and_publish.yml
+++ b/.github/workflows/build_and_publish.yml
@@ -21,7 +21,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v4
         with:
-          python-version: '3.9'
+          python-version: '3.11'
           architecture: 'x64'

      - name: Restore dependency cache
diff --git a/Makefile b/Makefile
index eab2cf2..deecf79 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@ clean-env: ## remove environment
 lint: ## static code analysis
 	black --line-length 120 --check fixcloudutils tests
 	flake8 fixcloudutils
-	mypy --python-version 3.9 --strict --install-types --non-interactive fixcloudutils tests
+	mypy --python-version 3.11 --strict --install-types --non-interactive fixcloudutils tests

 test: ## run tests quickly with the default Python
 	pytest
diff --git a/fixcloudutils/redis/cache.py b/fixcloudutils/redis/cache.py
new file mode 100644
index 0000000..491ab8b
--- /dev/null
+++ b/fixcloudutils/redis/cache.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import base64
+import hashlib
+import logging
+import pickle
+from asyncio import Task, Queue, CancelledError, create_task
+from datetime import datetime, timedelta
+from typing import Any, Optional, TypeVar, Callable, ParamSpec, NewType
+
+from attr import frozen
+from prometheus_client import Counter
+from redis.asyncio import Redis
+
+from fixcloudutils.asyncio import stop_running_task
+from fixcloudutils.asyncio.periodic import Periodic
+from fixcloudutils.redis.pub_sub import RedisPubSubListener, RedisPubSubPublisher
+from fixcloudutils.service import Service
+from fixcloudutils.types import Json
+from fixcloudutils.util import utc, uuid_str
+
+log = logging.getLogger("fixcloudutils.redis.cache")
+RedisKey = NewType("RedisKey", str)
+P = ParamSpec("P")
+T = TypeVar("T")
+CacheHit = Counter("redis_cache", "Redis Cache", ["key", "stage"])
+
+
+@frozen
+class RedisCacheSet:
+    key: RedisKey
+    value: Any
+    ttl: timedelta
+
+
+@frozen
+class RedisCacheEvict:
+    key: RedisKey
+
+
+# Actions that can be sent to a RedisCache queue listener
+RedisAction = RedisCacheSet | RedisCacheEvict
+
+
+@frozen
+class RedisCacheEntry:
+    key: Any
+    redis_key: RedisKey
+    value: Any
+    deadline: datetime
+
+
+class RedisCache(Service):
+    def __init__(
+        self,
+        redis: Redis,
+        key: str,
+        *,
+        ttl_memory: Optional[timedelta] = None,
+        ttl_redis: Optional[timedelta] = None,
+        cleaner_task_frequency: timedelta = timedelta(seconds=10),
+    ) -> None:
+        self.redis = redis
+        self.key = key
+        self.ttl_memory = ttl_memory or timedelta(minutes=5)
+        self.ttl_redis = ttl_redis or timedelta(minutes=30)
+        self.queue: Queue[RedisAction] = Queue(128)
+        self.should_run = True
+        self.started = False
+        self.process_queue_task: Optional[Task[None]] = None
+        self.event_listener = RedisPubSubListener(redis, f"cache_evict:{key}", self._handle_evict_message)
+        self.event_publisher = RedisPubSubPublisher(redis, f"cache_evict:{key}", str(uuid_str()))
+        self.local_cache: dict[Any, RedisCacheEntry] = {}
+        self.cached_functions: dict[str, Any] = {}
+        self.cleaner_task = Periodic("wipe_local_cache", self._wipe_outdated_from_local_cache, cleaner_task_frequency)
+
+    async def start(self) -> None:
+        if self.started:
+            return
+        self.started = True
+        self.should_run = True
+        await self.cleaner_task.start()
+        self.process_queue_task = create_task(self._process_queue())
+
+    async def stop(self) -> None:
+        if not self.started:
+            return
+        self.should_run = False
+        await stop_running_task(self.process_queue_task)
+        await self.cleaner_task.stop()
+        self.started = False
+
+    async def evict(self, fn_name: str, key: str) -> None:
+        log.info(f"{self.key}:{fn_name} Evict {key}")
+        await self.queue.put(RedisCacheEvict(self._redis_key(fn_name, key)))
+
+    def evict_with(self, fn: Callable[P, T]) -> Callable[P, T]:
+        async def evict_fn(*args: Any, **kwargs: Any) -> None:
+            key = self._redis_key(fn.__name__, None, *args, **kwargs)
+            log.info(f"{self.key}:{fn.__name__} Evict args based key: {key}")
+            await self.queue.put(RedisCacheEvict(key))
+
+        return evict_fn  # type: ignore
+
+    def call(
+        self,
+        fn: Callable[P, T],
+        *,
+        key: Optional[str] = None,
+        ttl_memory: Optional[timedelta] = None,
+        ttl_redis: Optional[timedelta] = None,
+    ) -> Callable[P, T]:
+        """
+        This is the memoization function.
+        If a value for this function call is available in the local cache, it will be returned.
+        If a value for this function call is available in redis, it will be returned and added to the local cache.
+        Otherwise, the function will be called and the result will be added to redis and the local cache.
+
+        :param fn: The function that should be memoized.
+        :param key: The key to use for memoization. If not provided, the function name and the arguments will be used.
+        :param ttl_redis: The time to live for the redis entry. If not provided, the default ttl will be used.
+        :param ttl_memory: The time to live for the local cache entry. If not provided, the default ttl will be used.
+        :return: The result of the function call.
+        """
+
+        async def handle_call(*args: Any, **kwargs: Any) -> T:
+            # check if the value is available in the local cache
+            fns = fn.__name__
+            local_cache_key = key or (fns, *args, *kwargs.values())
+            if local_value := self.local_cache.get(local_cache_key):
+                log.info(f"{self.key}:{fns} Serve result from local cache.")
+                CacheHit.labels(self.key, "local").inc()
+                return local_value.value  # type: ignore
+            # check if the value is available in redis
+            redis_key = self._redis_key(fns, key, *args, **kwargs)
+            if redis_value := await self.redis.get(redis_key):
+                log.info(f"{self.key}:{fns} Serve result from redis cache.")
+                CacheHit.labels(self.key, "redis").inc()
+                result: T = pickle.loads(base64.b64decode(redis_value))
+                self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
+                return result
+            # call the function
+            result = await fn(*args, **kwargs)  # type: ignore
+            CacheHit.labels(self.key, "call").inc()
+            self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
+            await self.queue.put(RedisCacheSet(redis_key, result, ttl_redis or self.ttl_redis))
+            return result
+
+        return handle_call  # type: ignore
+
+    async def _process_queue(self) -> None:
+        """
+        Local queue processor which will execute tasks in a separate execution context.
+        """
+        while self.should_run:
+            try:
+                entry = await self.queue.get()
+                if isinstance(entry, RedisCacheSet):
+                    log.info(f"{self.key}: Store cached value in redis as {entry.key}")
+                    value = base64.b64encode(pickle.dumps(entry.value))
+                    await self.redis.set(name=entry.key, value=value, ex=entry.ttl)
+                elif isinstance(entry, RedisCacheEvict):
+                    log.info(f"{self.key}: Delete cached value from redis key {entry.key}")
+                    # delete the entry
+                    await self.redis.delete(entry.key)
+                    # inform all other cache instances to evict the key
+                    await self.event_publisher.publish("evict", {"redis_key": entry.key})
+                    # delete from local cache
+                    self._remove_from_local_cache(entry.key)
+                else:
+                    log.warning(f"Unknown entry in queue: {entry}")
+            except CancelledError:
+                return  # ignore
+            except Exception as ex:
+                log.warning("Failed to process queue", exc_info=ex)
+
+    async def _handle_evict_message(self, uid: str, at: datetime, publisher: str, kind: str, data: Json) -> None:
+        """
+        PubSub listener for evict messages.
+        """
+        log.info(f"Received message: {kind} {data} from {publisher} at {at} by {uid}")
+        if kind == "evict" and (redis_key := data.get("redis_key")):
+            # delete from local cache
+            self._remove_from_local_cache(redis_key)
+        else:  # pragma: no cover
+            log.warning(f"Unknown message: {kind} {data}")
+
+    async def _wipe_outdated_from_local_cache(self) -> None:
+        """
+        Periodically called by this cache instance to remove outdated entries from the local cache.
+        """
+        now = utc()
+        for key, entry in list(self.local_cache.items()):
+            if entry.deadline < now:
+                log.info(f"Evicting {key} from local cache")
+                del self.local_cache[key]
+
+    def _add_to_local_cache(self, key: Any, redis_key: RedisKey, value: Any, ttl: timedelta) -> None:
+        entry = RedisCacheEntry(key=key, redis_key=redis_key, value=value, deadline=utc() + ttl)
+        self.local_cache[key] = entry
+
+    def _remove_from_local_cache(self, redis_key: str) -> None:
+        local_key: Optional[Any] = None
+        for key, entry in self.local_cache.items():
+            if entry.redis_key == redis_key:
+                local_key = key
+                break
+        if local_key:
+            log.info(f"Evicting {redis_key} from local cache")
+            del self.local_cache[local_key]
+
+    def _redis_key(self, fn_name: str, fn_key: Optional[str], *args: Any, **kwargs: Any) -> RedisKey:
+        if fn_key is None:
+            sha = hashlib.sha256()
+            for a in args:
+                sha.update(pickle.dumps(a))
+            for k, v in kwargs.items():
+                sha.update(pickle.dumps(k))
+                sha.update(pickle.dumps(v))
+            fn_key = sha.hexdigest()
+        return RedisKey(f"cache:{self.key}:{fn_name}:{fn_key}")
+
+
+class redis_cached:  # noqa
+    """
+    Decorator for caching function calls in redis.
+
+    Usage:
+    >>> redis = Redis()
+    >>> cache = RedisCache(redis, "test")
+    >>> @redis_cached(cache)
+    ... async def f(a: int, b: int) -> int:
+    ...     return a + b
+    """
+
+    def __init__(
+        self,
+        cache: RedisCache,
+        *,
+        key: Optional[str] = None,
+        ttl_memory: Optional[timedelta] = None,
+        ttl_redis: Optional[timedelta] = None,
+    ):
+        self.cache = cache
+        self.key = key
+        self.ttl_memory = ttl_memory
+        self.ttl_redis = ttl_redis
+
+    def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
+        return self.cache.call(fn, key=self.key, ttl_memory=self.ttl_memory, ttl_redis=self.ttl_redis)
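Editor's note: a minimal usage sketch of the new cache. The "example" name and the local Redis endpoint are assumptions; the test below enters RedisCache as an async context manager, which wraps start()/stop().

import asyncio
from datetime import timedelta

from redis.asyncio import Redis

from fixcloudutils.redis.cache import RedisCache


async def main() -> None:
    redis = Redis(host="localhost", port=6379, decode_responses=True)
    async with RedisCache(redis, "example", ttl_memory=timedelta(minutes=1)) as cache:

        async def expensive(a: int, b: int) -> int:
            return a + b

        cached = cache.call(expensive)
        print(await cached(1, 2))  # computed, then stored locally and queued for redis
        print(await cached(1, 2))  # served from the local cache
        await cache.evict_with(expensive)(1, 2)  # queue eviction for this argument combination
    await redis.close()


asyncio.run(main())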
diff --git a/pyproject.toml b/pyproject.toml
index 4514bfb..11e9b8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "fixcloudutils"
-version = "1.11.0"
+version = "1.12.0"
 authors = [{ name = "Some Engineering Inc." }]
 description = "Utilities for fixcloud."
 license = { file = "LICENSE" }
diff --git a/tests/cache_test.py b/tests/cache_test.py
new file mode 100644
index 0000000..f4f090a
--- /dev/null
+++ b/tests/cache_test.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import os
+from contextlib import AsyncExitStack
+from datetime import timedelta
+
+import pytest
+from redis.asyncio import Redis
+
+from conftest import eventually
+from fixcloudutils.redis.cache import RedisCache
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
+async def test_cache(redis: Redis) -> None:
+    t0 = timedelta(seconds=0.1)
+    t1 = timedelta(seconds=0.5)
+    t2 = timedelta(seconds=1)
+    t3 = timedelta(seconds=60)
+
+    async with AsyncExitStack() as stack:
+        cache1 = await stack.enter_async_context(
+            RedisCache(redis, "test", ttl_redis=t2, ttl_memory=t1, cleaner_task_frequency=t0)
+        )
+        cache2 = await stack.enter_async_context(
+            RedisCache(redis, "test", ttl_redis=t2, ttl_memory=t1, cleaner_task_frequency=t0)
+        )
+        call_count = 0
+
+        async def complex_function(a: int, b: int) -> int:
+            nonlocal call_count
+            call_count += 1
+            return a + b
+
+        async def local_cache_is_empty(cache: RedisCache) -> bool:
+            return len(cache.local_cache) == 0
+
+        key = cache1._redis_key(complex_function.__name__, None, 1, 2)
+        assert await cache1.call(complex_function)(1, 2) == 3
+        assert call_count == 1
+        assert len(cache1.local_cache) == 1
+        # should come from internal memory cache
+        assert await cache1.call(complex_function)(1, 2) == 3
+        assert call_count == 1
+        await eventually(redis.exists, key, timeout=2)
+        # should come from redis cache
+        assert len(cache2.local_cache) == 0
+        assert await cache2.call(complex_function)(1, 2) == 3
+        assert call_count == 1
+        assert len(cache2.local_cache) == 1
+        # after ttl expires, the local cache is empty
+        await eventually(local_cache_is_empty, cache1, timeout=1)
+        await eventually(local_cache_is_empty, cache2, timeout=1)
+        # after the redis ttl, the redis entry is evicted
+        await eventually(redis.exists, key, fn=lambda x: not x, timeout=2)
+
+        # calling this method again should trigger a new call and a new cache entry
+        # we use a longer redis ttl to test the eviction of the redis cache
+        for a in range(100):
+            assert await cache1.call(complex_function, ttl_redis=t3)(a, 2) == a + 2
+        assert call_count == 101
+        assert len(cache1.local_cache) == 100
+        for a in range(100):
+            assert await cache1.call(complex_function)(a, 2) == a + 2
+            assert await cache2.call(complex_function)(a, 2) == a + 2
+        assert len(cache1.local_cache) == 100
+        assert len(cache2.local_cache) == 100
+        # no more calls are done
+        assert call_count == 101
+
+        # evicting all entries should remove them from all caches
+        for a in range(100):
+            await cache1.evict_with(complex_function)(a, 2)
+        await eventually(local_cache_is_empty, cache1, timeout=1)
+        await eventually(local_cache_is_empty, cache2, timeout=1)
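Editor's note on the serialization format: cached values are pickled and base64 encoded before being written to redis (see _process_queue above), which is why the fixture below can safely create the client with decode_responses=True — the stored value is plain base64 text. A round-trip sketch:

import base64
import pickle

value = {"a": 1, "b": [1, 2, 3]}
encoded = base64.b64encode(pickle.dumps(value))  # what _process_queue stores in redis
# b64decode also accepts the str that a decode_responses=True client returns
decoded = pickle.loads(base64.b64decode(encoded))
assert decoded == value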
diff --git a/tests/conftest.py b/tests/conftest.py
index 3dcd4da..a7c4d6c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,8 +24,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-from typing import List, AsyncIterator
+import asyncio
+from datetime import timedelta
+from typing import List, AsyncIterator, Awaitable, TypeVar, Optional, Callable, Any, ParamSpec

 from arango.client import ArangoClient
 from attr import define
@@ -35,12 +36,39 @@
 from redis.backoff import ExponentialBackoff

 from fixcloudutils.arangodb.async_arangodb import AsyncArangoDB
+from fixcloudutils.util import utc
+
+PT = ParamSpec("PT")
+T = TypeVar("T")
+
+
+async def eventually(
+    func: Callable[PT, Awaitable[T]],
+    *args: Any,
+    fn: Optional[Callable[[Any], bool]] = None,
+    timeout: float = 10,
+    interval: float = 0.1,
+) -> T:
+    deadline = utc() + timedelta(seconds=timeout)
+    ex: Optional[Exception] = None
+    while True:
+        try:
+            result = await func(*args)
+            truthy = fn(result) if fn else bool(result)
+            if truthy:
+                return result
+        except Exception as e:
+            ex = e
+        if utc() > deadline:
+            raise TimeoutError(f"Timeout after {timeout} seconds") from ex
+        await asyncio.sleep(interval)


 @fixture
 async def redis() -> AsyncIterator[Redis]:
     backoff = ExponentialBackoff()  # type: ignore
-    redis = Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
+    redis = Redis(host="localhost", port=6379, db=0, decode_responses=True, retry=Retry(backoff, 10))
+    await redis.flushdb()  # wipe redis
     yield redis
     await redis.close(True)
diff --git a/tox.ini b/tox.ini
index 746c1e5..7f0507c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -20,7 +20,7 @@ commands = black --line-length 120 --check --diff --target-version py39 .
 commands = flake8 fixcloudutils

 [testenv:mypy]
-commands= python -m mypy --install-types --non-interactive --python-version 3.9 --strict fixcloudutils tests
+commands= python -m mypy --install-types --non-interactive --python-version 3.11 --strict fixcloudutils tests

 [testenv:tests]
 commands = pytest
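Editor's note: the eventually helper added to conftest.py polls an async function until its result is truthy, or until fn(result) is truthy when a predicate is given, raising TimeoutError past the deadline. A small sketch of how the tests use it; the increment coroutine is made up for illustration:

import asyncio

from conftest import eventually


async def demo() -> None:
    counter = 0

    async def increment() -> int:
        nonlocal counter
        counter += 1
        return counter

    # poll every 10ms until increment has been awaited five times
    result = await eventually(increment, fn=lambda c: c >= 5, timeout=2, interval=0.01)
    assert result == 5


asyncio.run(demo())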