Skip to content

Commit

Permalink
Remove keys from cache on exception (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
laky55555 authored Jul 12, 2024
1 parent 4f7e63a commit 6aafceb
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
28 changes: 11 additions & 17 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,24 @@ def _task_done_callback(
) -> None:
self.__tasks.discard(task)

cache_item = self.__cache.get(key)
if self.__ttl is not None and cache_item is not None:
loop = asyncio.get_running_loop()
cache_item.later_call = loop.call_later(
self.__ttl, self.__cache.pop, key, None
)

if task.cancelled():
fut.cancel()
self.__cache.pop(key, None)
return

exc = task.exception()
if exc is not None:
fut.set_exception(exc)
self.__cache.pop(key, None)
return

cache_item = self.__cache.get(key)
if self.__ttl is not None and cache_item is not None:
loop = asyncio.get_running_loop()
cache_item.later_call = loop.call_later(
self.__ttl, self.__cache.pop, key, None
)

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
Expand All @@ -197,19 +199,11 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
cache_item = self.__cache.get(key)

if cache_item is not None:
self._cache_hit(key)
if not cache_item.fut.done():
self._cache_hit(key)
return await asyncio.shield(cache_item.fut)

exc = cache_item.fut._exception

if exc is None:
self._cache_hit(key)
return cache_item.fut.result()
else:
# exception here
cache_item = self.__cache.pop(key)
cache_item.cancel()
return cache_item.fut.result()

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_close.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ async def coro(val: int) -> int:

await close

check_lru(coro, hits=0, misses=5, cache=5, tasks=0)
check_lru(coro, hits=0, misses=5, cache=0, tasks=0)
assert coro.cache_parameters()["closed"]

with pytest.raises(asyncio.CancelledError):
await gather

check_lru(coro, hits=0, misses=5, cache=5, tasks=0)
check_lru(coro, hits=0, misses=5, cache=0, tasks=0)
assert coro.cache_parameters()["closed"]

# double call is no-op
Expand Down
33 changes: 31 additions & 2 deletions tests/test_exception.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import gc
import sys
from typing import Callable

import pytest
Expand All @@ -16,12 +18,39 @@ async def coro(val: int) -> None:

ret = await asyncio.gather(*coros, return_exceptions=True)

check_lru(coro, hits=2, misses=1, cache=1, tasks=0)
check_lru(coro, hits=2, misses=1, cache=0, tasks=0)

for item in ret:
assert isinstance(item, ZeroDivisionError)

with pytest.raises(ZeroDivisionError):
await coro(1)

check_lru(coro, hits=2, misses=2, cache=1, tasks=0)
check_lru(coro, hits=2, misses=2, cache=0, tasks=0)


@pytest.mark.xfail(
reason="Memory leak is not fixed for PyPy3.9",
condition=sys.implementation.name == "pypy",
)
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None:
class CustomClass:
...

@alru_cache()
async def coro(val: int) -> None:
_ = CustomClass() # object we are verifying not to leak
1 / 0

coros = [coro(v) for v in range(1000)]

await asyncio.gather(*coros, return_exceptions=True)

check_lru(coro, hits=0, misses=1000, cache=0, tasks=0)

await asyncio.sleep(0.00001)
gc.collect()

assert (
len([obj for obj in gc.get_objects() if isinstance(obj, CustomClass)]) == 0
), "Only objects in the cache should be left in memory."

0 comments on commit 6aafceb

Please sign in to comment.