diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml
index 42ca06b..d835ddf 100644
--- a/.github/workflows/build_and_publish.yml
+++ b/.github/workflows/build_and_publish.yml
@@ -21,7 +21,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v4
         with:
-          python-version: '3.9'
+          python-version: '3.11'
           architecture: 'x64'
       - name: Restore dependency cache
diff --git a/Makefile b/Makefile
index eab2cf2..deecf79 100644
--- a/Makefile
+++ b/Makefile
@@ -58,7 +58,7 @@ clean-env: ## remove environment
 lint: ## static code analysis
 	black --line-length 120 --check fixcloudutils tests
 	flake8 fixcloudutils
-	mypy --python-version 3.9 --strict --install-types --non-interactive fixcloudutils tests
+	mypy --python-version 3.11 --strict --install-types --non-interactive fixcloudutils tests
 
 test: ## run tests quickly with the default Python
 	pytest
diff --git a/fixcloudutils/redis/cache.py b/fixcloudutils/redis/cache.py
new file mode 100644
index 0000000..491ab8b
--- /dev/null
+++ b/fixcloudutils/redis/cache.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import base64
+import hashlib
+import logging
+import pickle
+from asyncio import Task, Queue, CancelledError, create_task
+from datetime import datetime, timedelta
+from typing import Any, Optional, TypeVar, Callable, ParamSpec, NewType
+
+from attr import frozen
+from prometheus_client import Counter
+from redis.asyncio import Redis
+
+from fixcloudutils.asyncio import stop_running_task
+from fixcloudutils.asyncio.periodic import Periodic
+from fixcloudutils.redis.pub_sub import RedisPubSubListener, RedisPubSubPublisher
+from fixcloudutils.service import Service
+from fixcloudutils.types import Json
+from fixcloudutils.util import utc, uuid_str
+
+log = logging.getLogger("fixcloudutils.redis.cache")
+RedisKey = NewType("RedisKey", str)
+P = ParamSpec("P")
+T = TypeVar("T")
+CacheHit = Counter("redis_cache", "Redis Cache", ["key", "stage"])
+
+
+@frozen
+class RedisCacheSet:
+ key: RedisKey
+ value: Any
+ ttl: timedelta
+
+
+@frozen
+class RedisCacheEvict:
+ key: RedisKey
+
+
+# Actions that can be put on the RedisCache processing queue
+RedisAction = RedisCacheSet | RedisCacheEvict
+
+
+@frozen
+class RedisCacheEntry:
+ key: Any
+ redis_key: RedisKey
+ value: Any
+ deadline: datetime
+
+
+class RedisCache(Service):
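+    """
+    A two-level memoization cache: results live in a per-instance local cache and in Redis.
+    Local entries expire after ttl_memory, Redis entries after ttl_redis, and evictions are
+    propagated to all instances sharing the same key via a Redis pub/sub channel.
+    """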
+ def __init__(
+ self,
+ redis: Redis,
+ key: str,
+ *,
+ ttl_memory: Optional[timedelta] = None,
+ ttl_redis: Optional[timedelta] = None,
+ cleaner_task_frequency: timedelta = timedelta(seconds=10),
+ ) -> None:
+ self.redis = redis
+ self.key = key
+ self.ttl_memory = ttl_memory or timedelta(minutes=5)
+ self.ttl_redis = ttl_redis or timedelta(minutes=30)
+ self.queue: Queue[RedisAction] = Queue(128)
+ self.should_run = True
+ self.started = False
+ self.process_queue_task: Optional[Task[None]] = None
+ self.event_listener = RedisPubSubListener(redis, f"cache_evict:{key}", self._handle_evict_message)
+ self.event_publisher = RedisPubSubPublisher(redis, f"cache_evict:{key}", str(uuid_str()))
+ self.local_cache: dict[Any, RedisCacheEntry] = {}
+ self.cached_functions: dict[str, Any] = {}
+ self.cleaner_task = Periodic("wipe_local_cache", self._wipe_outdated_from_local_cache, cleaner_task_frequency)
+
+ async def start(self) -> None:
+ if self.started:
+ return
+ self.started = True
+ self.should_run = True
+ await self.cleaner_task.start()
+ self.process_queue_task = create_task(self._process_queue())
+
+ async def stop(self) -> None:
+ if not self.started:
+ return
+ self.should_run = False
+ await stop_running_task(self.process_queue_task)
+ await self.cleaner_task.stop()
+ self.started = False
+
+ async def evict(self, fn_name: str, key: str) -> None:
+ log.info(f"{self.key}:{fn_name} Evict {key}")
+ await self.queue.put(RedisCacheEvict(self._redis_key(fn_name, key)))
+
+ def evict_with(self, fn: Callable[P, T]) -> Callable[P, T]:
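+        """Return an async function that accepts fn's arguments and evicts the cache entry derived from them."""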
+ async def evict_fn(*args: Any, **kwargs: Any) -> None:
+ key = self._redis_key(fn.__name__, None, *args, **kwargs)
+ log.info(f"{self.key}:{fn.__name__} Evict args based key: {key}")
+ await self.queue.put(RedisCacheEvict(key))
+
+ return evict_fn # type: ignore
+
+ def call(
+ self,
+ fn: Callable[P, T],
+ *,
+ key: Optional[str] = None,
+ ttl_memory: Optional[timedelta] = None,
+ ttl_redis: Optional[timedelta] = None,
+ ) -> Callable[P, T]:
+ """
+ This is the memoization function.
+ If a value for this function call is available in the local cache, it will be returned.
+ If a value for this function call is available in redis, it will be returned and added to the local cache.
+ Otherwise, the function will be called and the result will be added to redis and the local cache.
+
+ :param fn: The function that should be memoized.
+ :param key: The key to use for memoization. If not provided, the function name and the arguments will be used.
+ :param ttl_redis: The time to live for the redis entry. If not provided, the default ttl will be used.
+ :param ttl_memory: The time to live for the local cache entry. If not provided, the default ttl will be used.
+ :return: The result of the function call.
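+
+        Example (sketch only; assumes a connected Redis client and an async caller):
+
+        >>> cache = RedisCache(redis, "example")
+        >>> async def load(uid: int) -> int:
+        ...     return uid * 2
+        >>> cached_load = cache.call(load, ttl_memory=timedelta(minutes=1))
+        >>> await cached_load(21)  # first call executes load and stores 42
+        >>> await cached_load(21)  # served from the local cache afterwards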
+ """
+
+ async def handle_call(*args: Any, **kwargs: Any) -> T:
+ # check if the value is available in the local cache
+ fns = fn.__name__
+ local_cache_key = key or (fns, *args, *kwargs.values())
+ if local_value := self.local_cache.get(local_cache_key):
+ log.info(f"{self.key}:{fns} Serve result from local cache.")
+ CacheHit.labels(self.key, "local").inc()
+ return local_value.value # type: ignore
+ # check if the value is available in redis
+ redis_key = self._redis_key(fns, key, *args, **kwargs)
+ if redis_value := await self.redis.get(redis_key):
+ log.info(f"{self.key}:{fns} Serve result from redis cache.")
+ CacheHit.labels(self.key, "redis").inc()
+ result: T = pickle.loads(base64.b64decode(redis_value))
+ self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
+ return result
+ # call the function
+ result = await fn(*args, **kwargs) # type: ignore
+ CacheHit.labels(self.key, "call").inc()
+ self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
+ await self.queue.put(RedisCacheSet(redis_key, result, ttl_redis or self.ttl_redis))
+ return result
+
+ return handle_call # type: ignore
+
+ async def _process_queue(self) -> None:
+ """
+        Local queue processor: performs Redis writes and evictions outside the caller's execution context.
+ """
+ while self.should_run:
+ try:
+ entry = await self.queue.get()
+ if isinstance(entry, RedisCacheSet):
+ log.info(f"{self.key}: Store cached value in redis as {entry.key}")
+ value = base64.b64encode(pickle.dumps(entry.value))
+ await self.redis.set(name=entry.key, value=value, ex=entry.ttl)
+ elif isinstance(entry, RedisCacheEvict):
+ log.info(f"{self.key}: Delete cached value from redis key {entry.key}")
+ # delete the entry
+ await self.redis.delete(entry.key)
+ # inform all other cache instances to evict the key
+ await self.event_publisher.publish("evict", {"redis_key": entry.key})
+ # delete from local cache
+ self._remove_from_local_cache(entry.key)
+ else:
+ log.warning(f"Unknown entry in queue: {entry}")
+ except CancelledError:
+ return # ignore
+ except Exception as ex:
+ log.warning("Failed to process queue", exc_info=ex)
+
+ async def _handle_evict_message(self, uid: str, at: datetime, publisher: str, kind: str, data: Json) -> None:
+ """
+ PubSub listener for evict messages.
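+        Handles messages of kind "evict" whose payload has the form {"redis_key": <key>}.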
+ """
+ log.info(f"Received message: {kind} {data} from {publisher} at {at} by {uid}")
+ if kind == "evict" and (redis_key := data.get("redis_key")):
+ # delete from local cache
+ self._remove_from_local_cache(redis_key)
+ else: # pragma: no cover
+ log.warning(f"Unknown message: {kind} {data}")
+
+ async def _wipe_outdated_from_local_cache(self) -> None:
+ """
+ Periodically called by this cache instance to remove outdated entries from the local cache.
+ """
+ now = utc()
+ for key, entry in list(self.local_cache.items()):
+ if entry.deadline < now:
+ log.info(f"Evicting {key} from local cache")
+ del self.local_cache[key]
+
+ def _add_to_local_cache(self, key: Any, redis_key: RedisKey, value: Any, ttl: timedelta) -> None:
+ entry = RedisCacheEntry(key=key, redis_key=redis_key, value=value, deadline=utc() + ttl)
+ self.local_cache[key] = entry
+
+ def _remove_from_local_cache(self, redis_key: str) -> None:
+ local_key: Optional[Any] = None
+ for key, entry in self.local_cache.items():
+ if entry.redis_key == redis_key:
+ local_key = key
+ break
+        if local_key is not None:
+ log.info(f"Evicting {redis_key} from local cache")
+ del self.local_cache[local_key]
+
+ def _redis_key(self, fn_name: str, fn_key: Optional[str], *args: Any, **kwargs: Any) -> RedisKey:
+ if fn_key is None:
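+            # no explicit key given: derive a stable digest from the pickled arguments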
+ sha = hashlib.sha256()
+ for a in args:
+ sha.update(pickle.dumps(a))
+ for k, v in kwargs.items():
+ sha.update(pickle.dumps(k))
+ sha.update(pickle.dumps(v))
+ fn_key = sha.hexdigest()
+ return RedisKey(f"cache:{self.key}:{fn_name}:{fn_key}")
+
+
+class redis_cached: # noqa
+ """
+ Decorator for caching function calls in redis.
+
+ Usage:
+ >>> redis = Redis()
+    >>> cache = RedisCache(redis, "test")
+ >>> @redis_cached(cache)
+ ... async def f(a: int, b: int) -> int:
+ ... return a + b
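+
+    Awaiting f(1, 2) then routes through RedisCache.call: local cache first, then Redis,
+    then the wrapped function itself (sketch; assumes a reachable Redis server).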
+ """
+
+ def __init__(
+ self,
+ cache: RedisCache,
+ *,
+ key: Optional[str] = None,
+ ttl_memory: Optional[timedelta] = None,
+ ttl_redis: Optional[timedelta] = None,
+ ):
+ self.cache = cache
+ self.key = key
+ self.ttl_memory = ttl_memory
+ self.ttl_redis = ttl_redis
+
+ def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
+ return self.cache.call(fn, key=self.key, ttl_memory=self.ttl_memory, ttl_redis=self.ttl_redis)
diff --git a/pyproject.toml b/pyproject.toml
index 4514bfb..11e9b8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "fixcloudutils"
-version = "1.11.0"
+version = "1.12.0"
 authors = [{ name = "Some Engineering Inc." }]
 description = "Utilities for fixcloud."
 license = { file = "LICENSE" }
diff --git a/tests/cache_test.py b/tests/cache_test.py
new file mode 100644
index 0000000..f4f090a
--- /dev/null
+++ b/tests/cache_test.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023. Some Engineering
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import os
+from contextlib import AsyncExitStack
+from datetime import timedelta
+
+import pytest
+from redis.asyncio import Redis
+
+from conftest import eventually
+from fixcloudutils.redis.cache import RedisCache
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(os.environ.get("REDIS_RUNNING") is None, reason="Redis is not running")
+async def test_cache(redis: Redis) -> None:
+ t0 = timedelta(seconds=0.1)
+ t1 = timedelta(seconds=0.5)
+ t2 = timedelta(seconds=1)
+ t3 = timedelta(seconds=60)
+
+ async with AsyncExitStack() as stack:
+ cache1 = await stack.enter_async_context(
+ RedisCache(redis, "test", ttl_redis=t2, ttl_memory=t1, cleaner_task_frequency=t0)
+ )
+ cache2 = await stack.enter_async_context(
+ RedisCache(redis, "test", ttl_redis=t2, ttl_memory=t1, cleaner_task_frequency=t0)
+ )
+ call_count = 0
+
+ async def complex_function(a: int, b: int) -> int:
+ nonlocal call_count
+ call_count += 1
+ return a + b
+
+ async def local_cache_is_empty(cache: RedisCache) -> bool:
+ return len(cache.local_cache) == 0
+
+ key = cache1._redis_key(complex_function.__name__, None, 1, 2)
+ assert await cache1.call(complex_function)(1, 2) == 3
+ assert call_count == 1
+ assert len(cache1.local_cache) == 1
+ # should come from internal memory cache
+ assert await cache1.call(complex_function)(1, 2) == 3
+ assert call_count == 1
+ await eventually(redis.exists, key, timeout=2)
+ # should come from redis cache
+ assert len(cache2.local_cache) == 0
+ assert await cache2.call(complex_function)(1, 2) == 3
+ assert call_count == 1
+ assert len(cache2.local_cache) == 1
+ # after ttl expires, the local cache is empty
+ await eventually(local_cache_is_empty, cache1, timeout=1)
+ await eventually(local_cache_is_empty, cache2, timeout=1)
+ # after redis ttl the cache is evicted
+ await eventually(redis.exists, key, fn=lambda x: not x, timeout=2)
+
+        # calling this method again should trigger a new call and a new cache entry
+        # we use a longer redis ttl to test the explicit eviction of the redis cache
+        for a in range(100):
+            assert await cache1.call(complex_function, ttl_redis=t3)(a, 2) == a + 2
+        assert call_count == 101
+        assert len(cache1.local_cache) == 100
+        for a in range(100):
+            assert await cache1.call(complex_function)(a, 2) == a + 2
+            assert await cache2.call(complex_function)(a, 2) == a + 2
+        assert len(cache1.local_cache) == 100
+        assert len(cache2.local_cache) == 100
+        # no more calls are done
+        assert call_count == 101
+
+        # evicting all entries should remove them from the local caches of all instances
+        for a in range(100):
+            await cache1.evict_with(complex_function)(a, 2)
+        await eventually(local_cache_is_empty, cache1, timeout=1)
+        await eventually(local_cache_is_empty, cache2, timeout=1)
diff --git a/tests/conftest.py b/tests/conftest.py
index 3dcd4da..a7c4d6c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,8 +24,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-from typing import List, AsyncIterator
+import asyncio
+from datetime import timedelta
+from typing import List, AsyncIterator, Awaitable, TypeVar, Optional, Callable, Any, ParamSpec
 
 from arango.client import ArangoClient
 from attr import define
@@ -35,12 +36,39 @@
 from redis.backoff import ExponentialBackoff
 
 from fixcloudutils.arangodb.async_arangodb import AsyncArangoDB
+from fixcloudutils.util import utc
+
+PT = ParamSpec("PT")
+T = TypeVar("T")
+
+
+async def eventually(
+ func: Callable[PT, Awaitable[T]],
+ *args: Any,
+ fn: Optional[Callable[[Any], bool]] = None,
+ timeout: float = 10,
+ interval: float = 0.1,
+) -> T:
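+    """
+    Await func(*args) repeatedly until the result (or fn(result), if fn is given) is truthy,
+    polling every `interval` seconds; raises TimeoutError once `timeout` seconds have elapsed.
+    """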
+ deadline = utc() + timedelta(seconds=timeout)
+ ex: Optional[Exception] = None
+ while True:
+ try:
+ result = await func(*args)
+ truthy = fn(result) if fn else bool(result)
+ if truthy:
+ return result
+ except Exception as e:
+ ex = e
+ if utc() > deadline:
+ raise TimeoutError(f"Timeout after {timeout} seconds") from ex
+ await asyncio.sleep(interval)
 
 
 @fixture
 async def redis() -> AsyncIterator[Redis]:
     backoff = ExponentialBackoff()  # type: ignore
-    redis = Redis(host="localhost", port=6379, decode_responses=True, retry=Retry(backoff, 10))
+    redis = Redis(host="localhost", port=6379, db=0, decode_responses=True, retry=Retry(backoff, 10))
+    await redis.flushdb()  # wipe redis
     yield redis
     await redis.close(True)
diff --git a/tox.ini b/tox.ini
index 746c1e5..7f0507c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -20,7 +20,7 @@ commands = black --line-length 120 --check --diff --target-version py39 .
 commands = flake8 fixcloudutils
 
 [testenv:mypy]
-commands= python -m mypy --install-types --non-interactive --python-version 3.9 --strict fixcloudutils tests
+commands= python -m mypy --install-types --non-interactive --python-version 3.11 --strict fixcloudutils tests
 
 [testenv:tests]
 commands = pytest