From 8c3917232166133ac484b5021bc4ddfa47de8107 Mon Sep 17 00:00:00 2001
From: Thomas Kemmer
Date: Sun, 3 Oct 2021 21:07:22 +0200
Subject: [PATCH] Fix #157: Add basic TLRU implementation.

---
 src/cachetools/__init__.py | 223 ++++++++++++++++++++++++++-----
 tests/test_tlru.py         | 259 +++++++++++++++++++++++++++++++++++++
 2 files changed, 450 insertions(+), 32 deletions(-)
 create mode 100644 tests/test_tlru.py

diff --git a/src/cachetools/__init__.py b/src/cachetools/__init__.py
index 42822f0..292b032 100644
--- a/src/cachetools/__init__.py
+++ b/src/cachetools/__init__.py
@@ -7,6 +7,7 @@
     "LRUCache",
     "MRUCache",
     "RRCache",
+    "TLRUCache",
     "TTLCache",
     "cached",
     "cachedmethod",
@@ -37,6 +38,53 @@ def pop(self, _):
         return 1
 
 
+class _Link:
+
+    __slots__ = ("key", "expire", "next", "prev")
+
+    def __init__(self, key=None, expire=None):
+        self.key = key
+        self.expire = expire
+
+    def __reduce__(self):
+        return _Link, (self.key, self.expire)
+
+    def unlink(self):
+        next = self.next
+        prev = self.prev
+        prev.next = next
+        next.prev = prev
+
+
+class _Timer:
+    def __init__(self, timer):
+        self.__timer = timer
+        self.__nesting = 0
+
+    def __call__(self):
+        if self.__nesting == 0:
+            return self.__timer()
+        else:
+            return self.__time
+
+    def __enter__(self):
+        if self.__nesting == 0:
+            self.__time = time = self.__timer()
+        else:
+            time = self.__time
+        self.__nesting += 1
+        return time
+
+    def __exit__(self, *exc):
+        self.__nesting -= 1
+
+    def __reduce__(self):
+        return _Timer, (self.__timer,)
+
+    def __getattr__(self, name):
+        return getattr(self.__timer, name)
+
+
 class Cache(collections.abc.MutableMapping):
     """Mutable mapping to serve as a simple cache or cache base class."""
 
@@ -294,51 +342,162 @@ def popitem(self):
             return (key, self.pop(key))
 
 
-class _Timer:
-    def __init__(self, timer):
-        self.__timer = timer
-        self.__nesting = 0
+class TLRUCache(Cache):
+    """LRU Cache implementation with per-item time-to-use (TTU) value."""
 
-    def __call__(self):
-        if self.__nesting == 0:
-            return self.__timer()
+    def __init__(self, maxsize, ttu, timer=time.monotonic, getsizeof=None):
+        Cache.__init__(self, maxsize, getsizeof)
+        self.__root = root = _Link()
+        root.prev = root.next = root
+        self.__links = collections.OrderedDict()
+        self.__timer = _Timer(timer)
+        self.__ttu = ttu
+
+    def __contains__(self, key):
+        try:
+            link = self.__links[key]  # no reordering
+        except KeyError:
+            return False
         else:
-            return self.__time
+            return not (link.expire < self.__timer())
 
-    def __enter__(self):
-        if self.__nesting == 0:
-            self.__time = time = self.__timer()
+    def __getitem__(self, key, cache_getitem=Cache.__getitem__):
+        try:
+            link = self.__getlink(key)
+        except KeyError:
+            expired = False
         else:
-            time = self.__time
-        self.__nesting += 1
-        return time
+            expired = link.expire < self.__timer()
+        if expired:
+            return self.__missing__(key)
+        else:
+            return cache_getitem(self, key)
 
-    def __exit__(self, *exc):
-        self.__nesting -= 1
+    def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
+        with self.__timer as time:
+            self.expire(time)
+            cache_setitem(self, key, value)
+        try:
+            link = self.__getlink(key)
+        except KeyError:
+            self.__links[key] = link = _Link(key)
+        else:
+            link.unlink()
+        link.expire = time + self.__ttu(value)
+        # TODO: insert in sorted order or change data structure
+        link.next = root = self.__root
+        link.prev = prev = root.prev
+        prev.next = root.prev = link
 
-    def __reduce__(self):
-        return _Timer, (self.__timer,)
+    def __delitem__(self, key, cache_delitem=Cache.__delitem__):
+        cache_delitem(self, key)
+        link = self.__links.pop(key)
+        link.unlink()
+        if link.expire < self.__timer():
+            raise KeyError(key)
 
-    def __getattr__(self, name):
-        return getattr(self.__timer, name)
+    def __iter__(self):
+        root = self.__root
+        curr = root.next
+        while curr is not root:
+            # "freeze" time for iterator access
+            with self.__timer as time:
+                if not (curr.expire < time):
+                    yield curr.key
+            curr = curr.next
 
+    def __len__(self):
+        root = self.__root
+        curr = root.next
+        time = self.__timer()
+        count = len(self.__links)
+        # TODO: prevent iterating over all elements
+        while curr is not root:
+            if curr.expire < time:
+                count -= 1
+            curr = curr.next
+        return count
 
-class _Link:
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        root = self.__root
+        root.prev = root.next = root
+        for link in sorted(self.__links.values(), key=lambda obj: obj.expire):
+            link.next = root
+            link.prev = prev = root.prev
+            prev.next = root.prev = link
+        self.expire(self.__timer())
 
-    __slots__ = ("key", "expire", "next", "prev")
+    def __repr__(self, cache_repr=Cache.__repr__):
+        with self.__timer as time:
+            self.expire(time)
+            return cache_repr(self)
 
-    def __init__(self, key=None, expire=None):
-        self.key = key
-        self.expire = expire
+    @property
+    def currsize(self):
+        with self.__timer as time:
+            self.expire(time)
+            return super().currsize
 
-    def __reduce__(self):
-        return _Link, (self.key, self.expire)
+    @property
+    def timer(self):
+        """The timer function used by the cache."""
+        return self.__timer
 
-    def unlink(self):
-        next = self.next
-        prev = self.prev
-        prev.next = next
-        next.prev = prev
+    def expire(self, time=None):
+        """Remove expired items from the cache."""
+        if time is None:
+            time = self.__timer()
+        root = self.__root
+        curr = root.next
+        links = self.__links
+        cache_delitem = Cache.__delitem__
+        # TODO: prevent iterating over all elements
+        while curr is not root:
+            if curr.expire < time:
+                cache_delitem(self, curr.key)
+                del links[curr.key]
+                next = curr.next
+                curr.unlink()
+                curr = next
+            else:
+                curr = curr.next
+
+    def clear(self):
+        with self.__timer as time:
+            self.expire(time)
+            Cache.clear(self)
+
+    def get(self, *args, **kwargs):
+        with self.__timer:
+            return Cache.get(self, *args, **kwargs)
+
+    def pop(self, *args, **kwargs):
+        with self.__timer:
+            return Cache.pop(self, *args, **kwargs)
+
+    def setdefault(self, *args, **kwargs):
+        with self.__timer:
+            return Cache.setdefault(self, *args, **kwargs)
+
+    def popitem(self):
+        """Remove and return the `(key, value)` pair least recently used that
+        has not already expired.
+
+        """
+        with self.__timer as time:
+            self.expire(time)
+            try:
+                key = next(iter(self.__links))
+            except StopIteration:
+                raise KeyError("%s is empty" % self.__class__.__name__) from None
+            else:
+                return (key, self.pop(key))
+
+    def __getlink(self, key):
+        value = self.__links[key]
+        self.__links.move_to_end(key)
+        return value
 
 
 class TTLCache(Cache):
diff --git a/tests/test_tlru.py b/tests/test_tlru.py
new file mode 100644
index 0000000..b46ee31
--- /dev/null
+++ b/tests/test_tlru.py
@@ -0,0 +1,259 @@
+import unittest
+
+from cachetools import TLRUCache
+
+from . import CacheTestMixin
+
+
+class Timer:
+    def __init__(self, auto=False):
+        self.auto = auto
+        self.time = 0
+
+    def __call__(self):
+        if self.auto:
+            self.time += 1
+        return self.time
+
+    def tick(self):
+        self.time += 1
+
+
+class TLRUTestCache(TLRUCache):
+    def default_ttu(_):
+        return 0
+
+    def __init__(self, maxsize, ttu=default_ttu, **kwargs):
+        TLRUCache.__init__(self, maxsize, ttu, timer=Timer(), **kwargs)
+
+
+class TLRUCacheTest(unittest.TestCase, CacheTestMixin):
+
+    Cache = TLRUTestCache
+
+    def test_ttu(self):
+        cache = TLRUCache(maxsize=6, ttu=lambda v: v, timer=Timer())
+        self.assertEqual(0, cache.timer())
+
+        cache[1] = 1
+        self.assertEqual({1}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertEqual(1, cache[1])
+
+        cache.timer.tick()
+        self.assertEqual({1}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertEqual(1, cache[1])
+
+        cache[2] = 2
+        self.assertEqual({1, 2}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertEqual(2, cache[2])
+
+        cache.timer.tick()
+        self.assertEqual({2}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+
+        cache[3] = 3
+        self.assertEqual({2, 3}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual({2, 3}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache[1] = 1
+        self.assertEqual({1, 2, 3}, set(cache))
+        self.assertEqual(3, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual({1, 3}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertNotIn(2, cache)
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual({3}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual(set(), set(cache))
+        self.assertEqual(0, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertNotIn(3, cache)
+
+        with self.assertRaises(KeyError):
+            del cache[1]
+        with self.assertRaises(KeyError):
+            cache.pop(2)
+        with self.assertRaises(KeyError):
+            del cache[3]
+
+    def test_ttu_fixed(self):
+        cache = TLRUCache(maxsize=2, ttu=lambda _: 1, timer=Timer())
+        self.assertEqual(0, cache.timer())
+
+        cache[1] = 1
+        self.assertEqual({1}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertEqual(1, cache[1])
+
+        cache.timer.tick()
+        self.assertEqual({1}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertEqual(1, cache[1])
+
+        cache[2] = 2
+        self.assertEqual({1, 2}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertEqual(2, cache[2])
+
+        cache.timer.tick()
+        self.assertEqual({2}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+
+        cache[3] = 3
+        self.assertEqual({2, 3}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual({3}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertEqual(3, cache[3])
+
+        cache.timer.tick()
+        self.assertEqual(set(), set(cache))
+        self.assertEqual(0, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertNotIn(3, cache)
+
+        with self.assertRaises(KeyError):
+            del cache[1]
+        with self.assertRaises(KeyError):
+            cache.pop(2)
+        with self.assertRaises(KeyError):
+            del cache[3]
+
+    def test_ttu_lru(self):
+        cache = TLRUCache(maxsize=2, ttu=lambda _: 0, timer=Timer())
+
+        cache[1] = 1
+        cache[2] = 2
+        cache[3] = 3
+
+        self.assertEqual(len(cache), 2)
+        self.assertNotIn(1, cache)
+        self.assertEqual(cache[2], 2)
+        self.assertEqual(cache[3], 3)
+
+        cache[2]
+        cache[4] = 4
+        self.assertEqual(len(cache), 2)
+        self.assertNotIn(1, cache)
+        self.assertEqual(cache[2], 2)
+        self.assertNotIn(3, cache)
+        self.assertEqual(cache[4], 4)
+
+        cache[5] = 5
+        self.assertEqual(len(cache), 2)
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertNotIn(3, cache)
+        self.assertEqual(cache[4], 4)
+        self.assertEqual(cache[5], 5)
+
+    def test_ttu_expire(self):
+        cache = TLRUCache(maxsize=3, ttu=lambda _: 2, timer=Timer())
+        with cache.timer as time:
+            self.assertEqual(time, cache.timer())
+
+        cache[1] = 1
+        cache.timer.tick()
+        cache[2] = 2
+        cache.timer.tick()
+        cache[3] = 3
+        self.assertEqual(2, cache.timer())
+
+        self.assertEqual({1, 2, 3}, set(cache))
+        self.assertEqual(3, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.expire()
+        self.assertEqual({1, 2, 3}, set(cache))
+        self.assertEqual(3, len(cache))
+        self.assertEqual(1, cache[1])
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.expire(3)
+        self.assertEqual({2, 3}, set(cache))
+        self.assertEqual(2, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertEqual(2, cache[2])
+        self.assertEqual(3, cache[3])
+
+        cache.expire(4)
+        self.assertEqual({3}, set(cache))
+        self.assertEqual(1, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertEqual(3, cache[3])
+
+        cache.expire(5)
+        self.assertEqual(set(), set(cache))
+        self.assertEqual(0, len(cache))
+        self.assertNotIn(1, cache)
+        self.assertNotIn(2, cache)
+        self.assertNotIn(3, cache)
+
+    def test_ttu_atomic(self):
+        cache = TLRUCache(maxsize=1, ttu=lambda _: 1, timer=Timer(auto=True))
+        cache[1] = 1
+        self.assertEqual(1, cache[1])
+        cache[1] = 1
+        self.assertEqual(1, cache.get(1))
+        cache[1] = 1
+        self.assertEqual(1, cache.pop(1))
+        cache[1] = 1
+        self.assertEqual(1, cache.setdefault(1))
+        cache[1] = 1
+        cache.clear()
+        self.assertEqual(0, len(cache))
+
+    def test_ttu_tuple_key(self):
+        cache = TLRUCache(maxsize=1, ttu=lambda _: 0, timer=Timer())
+
+        cache[(1, 2, 3)] = 42
+        self.assertEqual(42, cache[(1, 2, 3)])
+        cache.timer.tick()
+        with self.assertRaises(KeyError):
+            cache[(1, 2, 3)]
+        self.assertNotIn((1, 2, 3), cache)