Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Use Hash for Cache Entries #22

Merged
merged 2 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 55 additions & 75 deletions fixcloudutils/redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pickle
from asyncio import Task, Queue, CancelledError, create_task
from datetime import datetime, timedelta
from typing import Any, Optional, TypeVar, Callable, ParamSpec, NewType
from typing import Any, Optional, TypeVar, Callable, ParamSpec, NewType, Hashable

from attr import frozen
from prometheus_client import Counter
Expand All @@ -40,6 +40,8 @@
@frozen
class RedisCacheSet:
key: RedisKey
fn_name: str
fn_key: str
value: Any
ttl: timedelta

Expand All @@ -57,6 +59,7 @@ class RedisCacheEvict:
class RedisCacheEntry:
key: Any
redis_key: RedisKey
fn_key: str
value: Any
deadline: datetime

Expand Down Expand Up @@ -101,23 +104,15 @@ async def stop(self) -> None:
await self.cleaner_task.stop()
self.started = False

async def evict(self, fn_name: str, key: str) -> None:
log.info(f"{self.key}:{fn_name} Evict {key}")
await self.queue.put(RedisCacheEvict(self._redis_key(fn_name, key)))

def evict_with(self, fn: Callable[P, T]) -> Callable[P, T]:
async def evict_fn(*args: Any, **kwargs: Any) -> None:
key = self._redis_key(fn.__name__, None, *args, **kwargs)
log.info(f"{self.key}:{fn.__name__} Evict args based key: {key}")
await self.queue.put(RedisCacheEvict(key))

return evict_fn # type: ignore
async def evict(self, key: str) -> None:
log.info(f"{self.key}: Evict {key}")
await self.queue.put(RedisCacheEvict(self._redis_key(key)))

def call(
self,
fn: Callable[P, T],
key: str,
*,
key: Optional[str] = None,
ttl_memory: Optional[timedelta] = None,
ttl_redis: Optional[timedelta] = None,
) -> Callable[P, T]:
Expand All @@ -137,24 +132,25 @@ def call(
async def handle_call(*args: Any, **kwargs: Any) -> T:
# check if the value is available in the local cache
fns = fn.__name__
local_cache_key = key or (fns, *args, *kwargs.values())
fn_key = self._fn_key(fns, *args, *kwargs.values())
local_cache_key = (key, fn_key)
if local_value := self.local_cache.get(local_cache_key):
log.info(f"{self.key}:{fns} Serve result from local cache.")
log.info(f"{self.key}:{fns}:{key} Serve result from local cache.")
CacheHit.labels(self.key, "local").inc()
return local_value.value # type: ignore
# check if the value is available in redis
redis_key = self._redis_key(fns, key, *args, **kwargs)
if redis_value := await self.redis.get(redis_key):
log.info(f"{self.key}:{fns} Serve result from redis cache.")
redis_key = self._redis_key(key)
if redis_value := await self.redis.hget(redis_key, fn_key): # type: ignore
log.info(f"{self.key}:{fns}:{key} Serve result from redis cache.")
CacheHit.labels(self.key, "redis").inc()
result: T = pickle.loads(base64.b64decode(redis_value))
self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
self._add_to_local_cache(local_cache_key, redis_key, fn_key, result, ttl_memory or self.ttl_memory)
return result
# call the function
result = await fn(*args, **kwargs) # type: ignore
CacheHit.labels(self.key, "call").inc()
self._add_to_local_cache(local_cache_key, redis_key, result, ttl_memory or self.ttl_memory)
await self.queue.put(RedisCacheSet(redis_key, result, ttl_redis or self.ttl_redis))
self._add_to_local_cache(local_cache_key, redis_key, fn_key, result, ttl_memory or self.ttl_memory)
await self.queue.put(RedisCacheSet(redis_key, fns, fn_key, result, ttl_redis or self.ttl_redis))
return result

return handle_call # type: ignore
Expand All @@ -167,9 +163,10 @@ async def _process_queue(self) -> None:
try:
entry = await self.queue.get()
if isinstance(entry, RedisCacheSet):
log.info(f"{self.key}: Store cached value in redis as {entry.key}")
value = base64.b64encode(pickle.dumps(entry.value))
await self.redis.set(name=entry.key, value=value, ex=entry.ttl)
log.info(f"{self.key}:{entry.fn_name} Store cached value in redis as {entry.key}:{entry.fn_key}")
value = base64.b64encode(pickle.dumps(entry.value)).decode("utf-8")
await self.redis.hset(name=entry.key, key=entry.fn_key, value=value) # type: ignore
await self.redis.expire(name=entry.key, time=entry.ttl)
elif isinstance(entry, RedisCacheEvict):
log.info(f"{self.key}: Delete cached value from redis key {entry.key}")
# delete the entry
Expand Down Expand Up @@ -206,56 +203,39 @@ async def _wipe_outdated_from_local_cache(self) -> None:
log.info(f"Evicting {key} from local cache")
del self.local_cache[key]

def _add_to_local_cache(self, key: Any, redis_key: RedisKey, value: Any, ttl: timedelta) -> None:
entry = RedisCacheEntry(key=key, redis_key=redis_key, value=value, deadline=utc() + ttl)
def _add_to_local_cache(self, key: Any, redis_key: RedisKey, fn_key: str, value: Any, ttl: timedelta) -> None:
entry = RedisCacheEntry(key=key, redis_key=redis_key, fn_key=fn_key, value=value, deadline=utc() + ttl)
self.local_cache[key] = entry

def _remove_from_local_cache(self, redis_key: str) -> None:
local_key: Optional[Any] = None
for key, entry in self.local_cache.items():
if entry.redis_key == redis_key:
local_key = key
break
if local_key:
log.info(f"Evicting {redis_key} from local cache")
del self.local_cache[local_key]

def _redis_key(self, fn_name: str, fn_key: Optional[str], *args: Any, **kwargs: Any) -> RedisKey:
if fn_key is None:
sha = hashlib.sha256()
for a in args:
sha.update(pickle.dumps(a))
for k, v in kwargs.items():
sha.update(pickle.dumps(k))
sha.update(pickle.dumps(v))
fn_key = sha.hexdigest()
return RedisKey(f"cache:{self.key}:{fn_name}:{fn_key}")


class redis_cached: # noqa
"""
Decorator for caching function calls in redis.

Usage:
>>> redis = Redis()
>>> cache = RedisCache(redis, "test", "test1")
>>> @redis_cached(cache)
... async def f(a: int, b: int) -> int:
... return a + b
"""

def __init__(
self,
cache: RedisCache,
*,
key: Optional[str] = None,
ttl_memory: Optional[timedelta] = None,
ttl_redis: Optional[timedelta] = None,
):
self.cache = cache
self.key = key
self.ttl_memory = ttl_memory
self.ttl_redis = ttl_redis

def __call__(self, fn: Callable[P, T]) -> Callable[P, T]:
return self.cache.call(fn, key=self.key, ttl_memory=self.ttl_memory, ttl_redis=self.ttl_redis)
def _remove_from_local_cache(self, redis_key: RedisKey) -> None:
entries = [key for key, entry in self.local_cache.items() if entry.redis_key == redis_key]
if entries:
log.info(f"Evicting {len(entries)} entries for key {redis_key} from local cache")
for k in entries:
del self.local_cache[k]

def _redis_key(self, key: str) -> RedisKey:
return RedisKey(f"cache:{self.key}:{key}")

def _fn_key(self, fn_name: str, *args: Any, **kwargs: Any) -> str:
counter = 0

def object_hash(obj: Any) -> bytes:
nonlocal counter
if isinstance(obj, Hashable):
counter += 1
return f"{counter}:{hash(obj)}".encode("utf-8")
else:
return pickle.dumps(obj)

if len(args) == 0 and len(kwargs) == 0:
return fn_name

sha = hashlib.sha256()
sha.update(fn_name.encode("utf-8"))
for a in args:
sha.update(object_hash(a))
for k, v in kwargs.items():
sha.update(object_hash(k))
sha.update(object_hash(v))
return sha.hexdigest()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fixcloudutils"
version = "1.12.0"
version = "1.13.0"
authors = [{ name = "Some Engineering Inc." }]
description = "Utilities for fixcloud."
license = { file = "LICENSE" }
Expand Down
26 changes: 15 additions & 11 deletions tests/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,45 @@ async def complex_function(a: int, b: int) -> int:
async def local_cache_is_empty(cache: RedisCache) -> bool:
return len(cache.local_cache) == 0

key = cache1._redis_key(complex_function.__name__, None, 1, 2)
assert await cache1.call(complex_function)(1, 2) == 3
c1 = "customer_1"
redis_key = cache1._redis_key(c1)
fn_key = cache1._fn_key(complex_function.__name__, 1, 2)
assert await cache1.call(complex_function, c1)(1, 2) == 3
assert call_count == 1
assert len(cache1.local_cache) == 1
# should come from internal memory cache
assert await cache1.call(complex_function)(1, 2) == 3
assert await cache1.call(complex_function, c1)(1, 2) == 3
assert call_count == 1
await eventually(redis.exists, key, timeout=2)
await eventually(redis.hexists, redis_key, fn_key, timeout=2) # type: ignore
# should come from redis cache
assert len(cache2.local_cache) == 0
assert await cache2.call(complex_function)(1, 2) == 3
assert await cache2.call(complex_function, c1)(1, 2) == 3
assert call_count == 1
assert len(cache2.local_cache) == 1
# after ttl expires, the local cache is empty
await eventually(local_cache_is_empty, cache1, timeout=1)
await eventually(local_cache_is_empty, cache2, timeout=1)
# after redis ttl the cache is evicted
await eventually(redis.exists, key, fn=lambda x: not x, timeout=2)
await eventually(redis.hexists, redis_key, fn_key, fn=lambda x: not x, timeout=2) # type: ignore

# calling this method again should trigger a new call and a new cache entry
# we use a loner redis ttl to test the eviction of the redis cache
for a in range(100):
assert await cache1.call(complex_function, ttl_redis=t3)(a, 2) == a + 2
assert await cache1.call(complex_function, c1, ttl_redis=t3)(a, 2) == a + 2
assert call_count == 101
assert len(cache1.local_cache) == 100
await eventually(redis.hlen, redis_key, fn=lambda x: x == 100, timeout=2) # type: ignore

for a in range(100):
assert await cache1.call(complex_function)(a, 2) == a + 2
assert await cache2.call(complex_function)(a, 2) == a + 2
assert await cache1.call(complex_function, c1)(a, 2) == a + 2
assert await cache2.call(complex_function, c1)(a, 2) == a + 2
assert len(cache1.local_cache) == 100
assert len(cache2.local_cache) == 100
# no more calls are done
assert call_count == 101

# evict all entries should evict all messages in all caches
for a in range(100):
await cache1.evict_with(complex_function)(a, 2)
await cache1.evict(c1)
await eventually(local_cache_is_empty, cache1, timeout=1)
await eventually(local_cache_is_empty, cache2, timeout=1)
await eventually(redis.exists, redis_key, fn=lambda x: not x, timeout=2)