From f9223484685cdcb9e02702516b6746019bd1ce1a Mon Sep 17 00:00:00 2001 From: Thomas Kemmer Date: Sat, 18 Dec 2021 13:54:50 +0100 Subject: [PATCH] Fix #157: Use heapq for keeping track of TLRU expiration time. --- src/cachetools/__init__.py | 110 +++++++++++++++++-------------------- 1 file changed, 51 insertions(+), 59 deletions(-) diff --git a/src/cachetools/__init__.py b/src/cachetools/__init__.py index 2fd4dcd..a2aa9b6 100644 --- a/src/cachetools/__init__.py +++ b/src/cachetools/__init__.py @@ -18,6 +18,7 @@ import collections import collections.abc import functools +import heapq import random import time @@ -498,32 +499,45 @@ def __getlink(self, key): return value +@functools.total_ordering +class _TLRUItem: + + __slots__ = ("key", "expire", "removed") + + def __init__(self, key=None, expire=None): + self.key = key + self.expire = expire + self.removed = False + + def __lt__(self, other): + return self.expire < other.expire + + class TLRUCache(Cache): """LRU Cache implementation with per-item time-to-use (TTU) value.""" def __init__(self, maxsize, ttu, timer=time.monotonic, getsizeof=None): Cache.__init__(self, maxsize, getsizeof) - self.__root = root = _Link() - root.prev = root.next = root - self.__links = collections.OrderedDict() + self.__items = collections.OrderedDict() + self.__order = [] self.__timer = _Timer(timer) self.__ttu = ttu def __contains__(self, key): try: - link = self.__links[key] # no reordering + item = self.__items[key] # no reordering except KeyError: return False else: - return self.__timer() < link.expire + return self.__timer() < item.expire def __getitem__(self, key, cache_getitem=Cache.__getitem__): try: - link = self.__getlink(key) + item = self.__getitem(key) except KeyError: expired = False else: - expired = not (self.__timer() < link.expire) + expired = not (self.__timer() < item.expire) if expired: return self.__missing__(key) else: @@ -534,60 +548,36 @@ def __setitem__(self, key, value, cache_setitem=Cache.__setitem__): self.expire(time) cache_setitem(self, key, value) try: - link = self.__getlink(key) + # removing an existing item would break the heap + # structure, so only mark it as removed for now + self.__getitem(key).removed = True except KeyError: - self.__links[key] = link = _Link(key) - else: - link.unlink() - link.expire = time + self.__ttu(value) - # FIXME: insert in sorted expiration order, start at the end - # of the linked list since we expect newer items to expire - # later; this is O(n) and should be replaced with e.g. RBTree - root = self.__root - prev = root.prev - while prev is not root and link.expire < prev.expire: - prev = prev.prev - link.next = next = prev.next - link.prev = prev - prev.next = next.prev = link + pass + self.__items[key] = item = _TLRUItem(key, time + self.__ttu(value)) + heapq.heappush(self.__order, item) def __delitem__(self, key, cache_delitem=Cache.__delitem__): cache_delitem(self, key) - link = self.__links.pop(key) - link.unlink() - if not (self.__timer() < link.expire): + item = self.__items.pop(key) + item.removed = True + if not (self.__timer() < item.expire): raise KeyError(key) def __iter__(self): - root = self.__root - curr = root.next - while curr is not root: + for curr in self.__order: # "freeze" time for iterator access with self.__timer as time: - if time < curr.expire: + if time < curr.expire and not curr.removed: yield curr.key - curr = curr.next def __len__(self): - root = self.__root - curr = root.next time = self.__timer() - count = len(self.__links) - while curr is not root and not (time < curr.expire): - count -= 1 - curr = curr.next + count = 0 + for curr in self.__order: + if time < curr.expire and not curr.removed: + count += 1 return count - def __setstate__(self, state): - self.__dict__.update(state) - root = self.__root - root.prev = root.next = root - for link in sorted(self.__links.values(), key=lambda obj: obj.expire): - link.next = root - link.prev = prev = root.prev - prev.next = root.prev = link - self.expire(self.__timer()) - def __repr__(self, cache_repr=Cache.__repr__): with self.__timer as time: self.expire(time) @@ -608,16 +598,18 @@ def expire(self, time=None): """Remove expired items from the cache.""" if time is None: time = self.__timer() - root = self.__root - curr = root.next - links = self.__links + items = self.__items + order = self.__order + # clean up the heap if too many items are marked as removed + if len(order) > len(items) * 2: + self.__order = order = [item for item in order if not item.removed] + heapq.heapify(order) cache_delitem = Cache.__delitem__ - while curr is not root and not (time < curr.expire): - cache_delitem(self, curr.key) - del links[curr.key] - next = curr.next - curr.unlink() - curr = next + while order and (order[0].removed or time >= order[0].expire): + item = heapq.heappop(order) + if not item.removed: + cache_delitem(self, item.key) + del items[item.key] def clear(self): with self.__timer as time: @@ -644,15 +636,15 @@ def popitem(self): with self.__timer as time: self.expire(time) try: - key = next(iter(self.__links)) + key = next(iter(self.__items)) except StopIteration: raise KeyError("%s is empty" % self.__class__.__name__) from None else: return (key, self.pop(key)) - def __getlink(self, key): - value = self.__links[key] - self.__links.move_to_end(key) + def __getitem(self, key): + value = self.__items[key] + self.__items.move_to_end(key) return value