diff --git a/async_lru/__init__.py b/async_lru/__init__.py index e04ca33..228ebad 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -168,22 +168,24 @@ def _task_done_callback( ) -> None: self.__tasks.discard(task) - cache_item = self.__cache.get(key) - if self.__ttl is not None and cache_item is not None: - loop = asyncio.get_running_loop() - cache_item.later_call = loop.call_later( - self.__ttl, self.__cache.pop, key, None - ) - if task.cancelled(): fut.cancel() + self.__cache.pop(key, None) return exc = task.exception() if exc is not None: fut.set_exception(exc) + self.__cache.pop(key, None) return + cache_item = self.__cache.get(key) + if self.__ttl is not None and cache_item is not None: + loop = asyncio.get_running_loop() + cache_item.later_call = loop.call_later( + self.__ttl, self.__cache.pop, key, None + ) + fut.set_result(task.result()) async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: @@ -197,19 +199,11 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: cache_item = self.__cache.get(key) if cache_item is not None: + self._cache_hit(key) if not cache_item.fut.done(): - self._cache_hit(key) return await asyncio.shield(cache_item.fut) - exc = cache_item.fut._exception - - if exc is None: - self._cache_hit(key) - return cache_item.fut.result() - else: - # exception here - cache_item = self.__cache.pop(key) - cache_item.cancel() + return cache_item.fut.result() fut = loop.create_future() coro = self.__wrapped__(*fn_args, **fn_kwargs) diff --git a/tests/test_close.py b/tests/test_close.py index 86e5aaa..abf46d3 100644 --- a/tests/test_close.py +++ b/tests/test_close.py @@ -31,13 +31,13 @@ async def coro(val: int) -> int: await close - check_lru(coro, hits=0, misses=5, cache=5, tasks=0) + check_lru(coro, hits=0, misses=5, cache=0, tasks=0) assert coro.cache_parameters()["closed"] with pytest.raises(asyncio.CancelledError): await gather - check_lru(coro, hits=0, misses=5, cache=5, tasks=0) + check_lru(coro, hits=0, misses=5, cache=0, tasks=0) assert coro.cache_parameters()["closed"] # double call is no-op diff --git a/tests/test_exception.py b/tests/test_exception.py index 8fd4eb5..054ea3a 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -1,4 +1,6 @@ import asyncio +import gc +import sys from typing import Callable import pytest @@ -16,7 +18,7 @@ async def coro(val: int) -> None: ret = await asyncio.gather(*coros, return_exceptions=True) - check_lru(coro, hits=2, misses=1, cache=1, tasks=0) + check_lru(coro, hits=2, misses=1, cache=0, tasks=0) for item in ret: assert isinstance(item, ZeroDivisionError) @@ -24,4 +26,31 @@ async def coro(val: int) -> None: with pytest.raises(ZeroDivisionError): await coro(1) - check_lru(coro, hits=2, misses=2, cache=1, tasks=0) + check_lru(coro, hits=2, misses=2, cache=0, tasks=0) + + +@pytest.mark.xfail( + reason="Memory leak is not fixed for PyPy3.9", + condition=sys.implementation.name == "pypy", +) +async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: + class CustomClass: + ... + + @alru_cache() + async def coro(val: int) -> None: + _ = CustomClass() # object we are verifying not to leak + 1 / 0 + + coros = [coro(v) for v in range(1000)] + + await asyncio.gather(*coros, return_exceptions=True) + + check_lru(coro, hits=0, misses=1000, cache=0, tasks=0) + + await asyncio.sleep(0.00001) + gc.collect() + + assert ( + len([obj for obj in gc.get_objects() if isinstance(obj, CustomClass)]) == 0 + ), "Only objects in the cache should be left in memory."