From f006b551d3e057d26f2db8031c303a92244554be Mon Sep 17 00:00:00 2001 From: Daniel M Date: Tue, 17 Dec 2024 21:48:12 -0500 Subject: [PATCH] SADDEX - implement set with support for expiring members (#350) --- docs/about/changelog.md | 1 + fakeredis/_basefakesocket.py | 7 +- fakeredis/_commands.py | 80 +---------------- fakeredis/commands_mixins/generic_mixin.py | 5 +- fakeredis/commands_mixins/hash_mixin.py | 3 +- fakeredis/commands_mixins/set_mixin.py | 45 +++++----- fakeredis/commands_mixins/sortedset_mixin.py | 4 +- fakeredis/model/__init__.py | 4 + fakeredis/model/_expiring_members_set.py | 89 +++++++++++++++++++ fakeredis/model/_hash.py | 80 +++++++++++++++++ .../dragonfly_mixin.py | 16 ++-- test/test_mixins/test_set_commands.py | 11 +++ 12 files changed, 229 insertions(+), 116 deletions(-) create mode 100644 fakeredis/model/_expiring_members_set.py create mode 100644 fakeredis/model/_hash.py diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 42f32eed..779ba8b6 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -12,6 +12,7 @@ toc_depth: 2 ### 🚀 Features - Add support disable_decoding in async read_response #349 +- Implement support for `SADDEX`, using a new set implementation with support for expiring members #350 ## v2.26.2 diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index b6f0a44e..81c9d005 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -8,11 +8,10 @@ import redis from redis.connection import DefaultParser -from fakeredis.model import XStream -from fakeredis.model import ZSet +from fakeredis.model import XStream, ZSet, Hash, ExpiringMembersSet from . import _msgs as msgs from ._command_args_parsing import extract_args -from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash +from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem from ._helpers import ( SimpleError, valid_response_type, @@ -392,7 +391,7 @@ def _key_value_type(key: CommandItem) -> SimpleString: return SimpleString(b"string") elif isinstance(key.value, list): return SimpleString(b"list") - elif isinstance(key.value, set): + elif isinstance(key.value, ExpiringMembersSet): return SimpleString(b"set") elif isinstance(key.value, ZSet): return SimpleString(b"zset") diff --git a/fakeredis/_commands.py b/fakeredis/_commands.py index 91a0ea92..e6da9e78 100644 --- a/fakeredis/_commands.py +++ b/fakeredis/_commands.py @@ -8,10 +8,10 @@ import re import sys import time -from typing import Iterable, Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection +from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection from . import _msgs as msgs -from ._helpers import null_terminate, SimpleError, Database, current_time +from ._helpers import null_terminate, SimpleError, Database MAX_STRING_SIZE = 512 * 1024 * 1024 SUPPORTED_COMMANDS: Dict[str, "Signature"] = dict() # Dictionary of supported commands name => Signature @@ -107,82 +107,6 @@ def __bool__(self) -> bool: __nonzero__ = __bool__ # For Python 2 -class Hash: - DECODE_ERROR = msgs.INVALID_HASH_MSG - redis_type = b"hash" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._expirations: Dict[bytes, int] = dict() - self._values: Dict[bytes, Any] = dict() - - def _expire_keys(self) -> None: - removed = [] - now = current_time() - for k in self._expirations: - if self._expirations[k] < now: - self._values.pop(k, None) - removed.append(k) - for k in removed: - self._expirations.pop(k, None) - - def set_key_expireat(self, key: bytes, when_ms: int) -> int: - now = current_time() - if when_ms <= now: - self._values.pop(key, None) - self._expirations.pop(key, None) - return 2 - self._expirations[key] = when_ms - return 1 - - def clear_key_expireat(self, key: bytes) -> bool: - return self._expirations.pop(key, None) is not None - - def get_key_expireat(self, key: bytes) -> Optional[int]: - self._expire_keys() - return self._expirations.get(key, None) - - def __getitem__(self, key: bytes) -> Any: - self._expire_keys() - return self._values.get(key) - - def __contains__(self, key: bytes) -> bool: - self._expire_keys() - return self._values.__contains__(key) - - def __setitem__(self, key: bytes, value: Any) -> None: - self._expirations.pop(key, None) - self._values[key] = value - - def __delitem__(self, key: bytes) -> None: - self._values.pop(key, None) - self._expirations.pop(key, None) - - def __len__(self) -> int: - return len(self._values) - - def __iter__(self) -> Iterable[bytes]: - return iter(self._values) - - def get(self, key: bytes, default: Any = None) -> Any: - return self._values.get(key, default) - - def keys(self) -> Iterable[bytes]: - self._expire_keys() - return self._values.keys() - - def values(self) -> Iterable[Any]: - return [v for k, v in self.items()] - - def items(self) -> Iterable[Tuple[bytes, Any]]: - self._expire_keys() - return self._values.items() - - def update(self, values: Dict[bytes, Any]) -> None: - self._expire_keys() - self._values.update(values) - - class RedisType: @classmethod def decode(cls, *args, **kwargs): # type:ignore diff --git a/fakeredis/commands_mixins/generic_mixin.py b/fakeredis/commands_mixins/generic_mixin.py index 7af9f3e0..81e16835 100644 --- a/fakeredis/commands_mixins/generic_mixin.py +++ b/fakeredis/commands_mixins/generic_mixin.py @@ -14,10 +14,9 @@ CommandItem, SortFloat, delete_keys, - Hash, ) from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database, SimpleString -from fakeredis.model import ZSet +from fakeredis.model import ZSet, Hash, ExpiringMembersSet class GenericCommandsMixin: @@ -224,7 +223,7 @@ def scan(self, cursor, *args): @command(name="SORT", fixed=(Key(),), repeat=(bytes,)) def sort(self, key, *args): - if key.value is not None and not isinstance(key.value, (set, list, ZSet)): + if key.value is not None and not isinstance(key.value, (ExpiringMembersSet, list, ZSet)): raise SimpleError(msgs.WRONGTYPE_MSG) ( asc, diff --git a/fakeredis/commands_mixins/hash_mixin.py b/fakeredis/commands_mixins/hash_mixin.py index bb7019e2..082e789f 100644 --- a/fakeredis/commands_mixins/hash_mixin.py +++ b/fakeredis/commands_mixins/hash_mixin.py @@ -5,9 +5,10 @@ from fakeredis import _msgs as msgs from fakeredis._command_args_parsing import extract_args -from fakeredis._commands import command, Key, Hash, Int, Float, CommandItem +from fakeredis._commands import command, Key, Int, Float, CommandItem from fakeredis._helpers import SimpleError, OK, casematch, SimpleString from fakeredis._helpers import current_time +from fakeredis.model import Hash class HashCommandsMixin: diff --git a/fakeredis/commands_mixins/set_mixin.py b/fakeredis/commands_mixins/set_mixin.py index fa0a84d9..a0b65ca5 100644 --- a/fakeredis/commands_mixins/set_mixin.py +++ b/fakeredis/commands_mixins/set_mixin.py @@ -4,18 +4,19 @@ from fakeredis import _msgs as msgs from fakeredis._commands import command, Key, Int, CommandItem from fakeredis._helpers import OK, SimpleError, casematch, Database, SimpleString +from fakeredis.model import ExpiringMembersSet def _calc_setop(op: Callable[..., Any], stop_if_missing: bool, key: CommandItem, *keys: CommandItem) -> Any: if stop_if_missing and not key.value: return set() value = key.value - if not isinstance(value, set): + if not isinstance(value, ExpiringMembersSet): raise SimpleError(msgs.WRONGTYPE_MSG) ans = value.copy() for other in keys: - value = other.value if other.value is not None else set() - if not isinstance(value, set): + value = other.value if other.value is not None else ExpiringMembersSet() + if not isinstance(value, ExpiringMembersSet): raise SimpleError(msgs.WRONGTYPE_MSG) if stop_if_missing and not value: return set() @@ -48,26 +49,26 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.version: Tuple[int] self._db: Database - @command((Key(set), bytes), (bytes,)) + @command((Key(ExpiringMembersSet), bytes), (bytes,)) def sadd(self, key: CommandItem, *members: bytes) -> int: old_size = len(key.value) key.value.update(members) key.updated() return len(key.value) - old_size - @command((Key(set),)) + @command((Key(ExpiringMembersSet),)) def scard(self, key: CommandItem) -> int: return len(key.value) - @command((Key(set),), (Key(set),)) + @command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),)) def sdiff(self, *keys: CommandItem) -> Any: return _setop(lambda a, b: a - b, False, None, *keys) - @command((Key(), Key(set)), (Key(set),)) + @command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),)) def sdiffstore(self, dst: CommandItem, *keys: CommandItem) -> Any: return _setop(lambda a, b: a - b, False, dst, *keys) - @command((Key(set),), (Key(set),)) + @command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),)) def sinter(self, *keys: CommandItem) -> Any: res = _setop(lambda a, b: a & b, True, None, *keys) return res @@ -89,23 +90,23 @@ def sintercard(self, numkeys: int, *args: bytes) -> int: res = _setop(lambda a, b: a & b, False, None, *keys) return len(res) if limit == 0 else min(limit, len(res)) - @command((Key(), Key(set)), (Key(set),)) + @command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),)) def sinterstore(self, dst: CommandItem, *keys: CommandItem) -> Any: return _setop(lambda a, b: a & b, True, dst, *keys) - @command((Key(set), bytes)) + @command((Key(ExpiringMembersSet), bytes)) def sismember(self, key: CommandItem, member: bytes) -> int: return int(member in key.value) - @command((Key(set), bytes), (bytes,)) + @command((Key(ExpiringMembersSet), bytes), (bytes,)) def smismember(self, key: CommandItem, *members: bytes) -> List[int]: return [self.sismember(key, member) for member in members] - @command((Key(set),)) + @command((Key(ExpiringMembersSet),)) def smembers(self, key: CommandItem) -> List[bytes]: return list(key.value) - @command((Key(set, 0), Key(set), bytes)) + @command((Key(ExpiringMembersSet, 0), Key(ExpiringMembersSet), bytes)) def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int: try: src.value.remove(member) @@ -117,7 +118,7 @@ def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int: dst.updated() # TODO: is it updated if member was already present? return 1 - @command((Key(set),), (Int,)) + @command((Key(ExpiringMembersSet),), (Int,)) def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]: if count is None: if not key.value: @@ -135,7 +136,7 @@ def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, Li key.updated() # Inside the loop because redis special-cases count=0 return items - @command((Key(set),), (Int,)) + @command((Key(ExpiringMembersSet),), (Int,)) def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]: if count is None: if not key.value: @@ -149,7 +150,7 @@ def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[by items = list(key.value) return [random.choice(items) for _ in range(-count)] - @command((Key(set), bytes), (bytes,)) + @command((Key(ExpiringMembersSet), bytes), (bytes,)) def srem(self, key: CommandItem, *members: bytes) -> int: old_size = len(key.value) for member in members: @@ -159,15 +160,15 @@ def srem(self, key: CommandItem, *members: bytes) -> int: key.updated() return deleted - @command((Key(set), Int), (bytes, bytes)) + @command((Key(ExpiringMembersSet), Int), (bytes, bytes)) def sscan(self, key: CommandItem, cursor: int, *args: bytes) -> Any: return self._scan(key.value, cursor, *args) - @command((Key(set),), (Key(set),)) + @command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),)) def sunion(self, *keys: CommandItem) -> Any: return _setop(lambda a, b: a | b, False, None, *keys) - @command((Key(), Key(set)), (Key(set),)) + @command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),)) def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any: return _setop(lambda a, b: a | b, False, dst, *keys) @@ -176,19 +177,19 @@ def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any: # approximate and store the results in a string. Instead, it is implemented # on top of sets. - @command((Key(set),), (bytes,)) + @command((Key(ExpiringMembersSet),), (bytes,)) def pfadd(self, key: CommandItem, *elements: bytes) -> int: result = self.sadd(key, *elements) # Per the documentation: # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise. return 1 if result > 0 else 0 - @command((Key(set),), (Key(set),)) + @command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),)) def pfcount(self, *keys: CommandItem) -> int: """Return the approximated cardinality of the set observed by the HyperLogLog at key(s).""" return len(self.sunion(*keys)) - @command((Key(set), Key(set)), (Key(set),)) + @command((Key(ExpiringMembersSet), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),)) def pfmerge(self, dest: CommandItem, *sources: CommandItem) -> SimpleString: """Merge N different HyperLogLogs into a single one.""" self.sunionstore(dest, *sources) diff --git a/fakeredis/commands_mixins/sortedset_mixin.py b/fakeredis/commands_mixins/sortedset_mixin.py index 6cb8da13..cdb0a912 100644 --- a/fakeredis/commands_mixins/sortedset_mixin.py +++ b/fakeredis/commands_mixins/sortedset_mixin.py @@ -26,7 +26,7 @@ null_terminate, Database, ) -from fakeredis.model import ZSet +from fakeredis.model import ZSet, ExpiringMembersSet SORTED_SET_METHODS = { "ZUNIONSTORE": lambda s1, s2: s1 | s2, @@ -391,7 +391,7 @@ def zscore(self, key, member): @staticmethod def _get_zset(value): - if isinstance(value, set): + if isinstance(value, ExpiringMembersSet): zset = ZSet() for item in value: zset[item] = 1.0 diff --git a/fakeredis/model/__init__.py b/fakeredis/model/__init__.py index d7e1d45b..1c79c346 100644 --- a/fakeredis/model/__init__.py +++ b/fakeredis/model/__init__.py @@ -1,3 +1,5 @@ +from ._expiring_members_set import ExpiringMembersSet +from ._hash import Hash from ._stream import XStream, StreamEntryKey, StreamGroup, StreamRangeTest from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS from ._topk import HeavyKeeper @@ -13,4 +15,6 @@ "TimeSeriesRule", "AGGREGATORS", "HeavyKeeper", + "Hash", + "ExpiringMembersSet", ] diff --git a/fakeredis/model/_expiring_members_set.py b/fakeredis/model/_expiring_members_set.py new file mode 100644 index 00000000..78aefd20 --- /dev/null +++ b/fakeredis/model/_expiring_members_set.py @@ -0,0 +1,89 @@ +import sys +from typing import Iterable, Optional, Any, Dict, Union, Set + +from fakeredis import _msgs as msgs +from fakeredis._helpers import current_time + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + +class ExpiringMembersSet: + DECODE_ERROR = msgs.INVALID_HASH_MSG + redis_type = b"set" + + def __init__(self, values: Dict[bytes, Optional[int]] = None, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._values: Dict[bytes, Optional[int]] = values or dict() + + def _expire_members(self) -> None: + removed = [] + now = current_time() + for k in self._values: + if self._values[k] is not None and self._values[k] < now: + self._values.pop(k, None) + removed.append(k) + + def set_member_expireat(self, key: bytes, when_ms: int) -> int: + now = current_time() + if when_ms <= now: + self._values.pop(key, None) + return 2 + self._values[key] = when_ms + return 1 + + def clear_key_expireat(self, key: bytes) -> bool: + return self._values.pop(key, None) is not None + + def get_key_expireat(self, key: bytes) -> Optional[int]: + self._expire_members() + return self._values.get(key, None) + + def __contains__(self, key: bytes) -> bool: + self._expire_members() + return self._values.__contains__(key) + + def __delitem__(self, key: bytes) -> None: + self._values.pop(key, None) + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self) -> Iterable[bytes]: + return iter({k for k in self._values if self._values[k] is None or self._values[k] >= current_time()}) + + def __get__(self, instance, owner=None) -> Set[bytes]: + self._expire_members() + return set(self._values.keys()) + + def __sub__(self, other: Self) -> Self: + return ExpiringMembersSet({k: v for k, v in self._values.items() if k not in other._values}) + + def __and__(self, other: Self) -> Self: + return ExpiringMembersSet({k: v for k, v in self._values.items() if k in other._values}) + + def __or__(self, other: Self) -> Self: + return ExpiringMembersSet({k: v for k, v in self._values.items()}).update(other) + + def update(self, other: Union[Self, Iterable[bytes]]) -> Self: + self._expire_members() + if isinstance(other, ExpiringMembersSet): + self._values.update(other._values) + return self + for value in other: + self._values[value] = None + return self + + def discard(self, key: bytes) -> None: + self._values.pop(key, None) + + def remove(self, key: bytes) -> None: + self._values.pop(key) + + def add(self, key: bytes) -> None: + self._values[key] = None + + def copy(self) -> Self: + return ExpiringMembersSet(self._values.copy()) diff --git a/fakeredis/model/_hash.py b/fakeredis/model/_hash.py new file mode 100644 index 00000000..c4492cb1 --- /dev/null +++ b/fakeredis/model/_hash.py @@ -0,0 +1,80 @@ +from typing import Iterable, Tuple, Optional, Any, Dict + +from fakeredis import _msgs as msgs +from fakeredis._helpers import current_time + + +class Hash: + DECODE_ERROR = msgs.INVALID_HASH_MSG + redis_type = b"hash" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._expirations: Dict[bytes, int] = dict() + self._values: Dict[bytes, Any] = dict() + + def _expire_keys(self) -> None: + removed = [] + now = current_time() + for k in self._expirations: + if self._expirations[k] < now: + self._values.pop(k, None) + removed.append(k) + for k in removed: + self._expirations.pop(k, None) + + def set_key_expireat(self, key: bytes, when_ms: int) -> int: + now = current_time() + if when_ms <= now: + self._values.pop(key, None) + self._expirations.pop(key, None) + return 2 + self._expirations[key] = when_ms + return 1 + + def clear_key_expireat(self, key: bytes) -> bool: + return self._expirations.pop(key, None) is not None + + def get_key_expireat(self, key: bytes) -> Optional[int]: + self._expire_keys() + return self._expirations.get(key, None) + + def __getitem__(self, key: bytes) -> Any: + self._expire_keys() + return self._values.get(key) + + def __contains__(self, key: bytes) -> bool: + self._expire_keys() + return self._values.__contains__(key) + + def __setitem__(self, key: bytes, value: Any) -> None: + self._expirations.pop(key, None) + self._values[key] = value + + def __delitem__(self, key: bytes) -> None: + self._values.pop(key, None) + self._expirations.pop(key, None) + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self) -> Iterable[bytes]: + return iter(self._values) + + def get(self, key: bytes, default: Any = None) -> Any: + return self._values.get(key, default) + + def keys(self) -> Iterable[bytes]: + self._expire_keys() + return self._values.keys() + + def values(self) -> Iterable[Any]: + return [v for k, v in self.items()] + + def items(self) -> Iterable[Tuple[bytes, Any]]: + self._expire_keys() + return self._values.items() + + def update(self, values: Dict[bytes, Any]) -> None: + self._expire_keys() + self._values.update(values) diff --git a/fakeredis/server_specific_commands/dragonfly_mixin.py b/fakeredis/server_specific_commands/dragonfly_mixin.py index 93807488..d873fae0 100644 --- a/fakeredis/server_specific_commands/dragonfly_mixin.py +++ b/fakeredis/server_specific_commands/dragonfly_mixin.py @@ -1,7 +1,8 @@ from typing import Callable from fakeredis._commands import command, Key, Int, CommandItem -from fakeredis._helpers import Database +from fakeredis._helpers import Database, current_time +from fakeredis.model import ExpiringMembersSet class DragonflyCommandsMixin(object): @@ -11,10 +12,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._db: Database - @command(name="SADDEX", fixed=(Key(set), Int, bytes), repeat=(bytes,), server_types=("dragonfly",)) + @command(name="SADDEX", fixed=(Key(ExpiringMembersSet), Int, bytes), repeat=(bytes,), server_types=("dragonfly",)) def saddex(self, key: CommandItem, seconds: int, *members: bytes) -> int: - old_size = len(key.value) - key.value.update(members) + val = key.value + old_size = len(val) + new_members = set(members) - set(val) + expire_at_ms = current_time() + seconds * 1000 + for member in new_members: + val.set_member_expireat(member, expire_at_ms) key.updated() - self._expireat(key, self._db.time + seconds) - return len(key.value) - old_size + return len(val) - old_size diff --git a/test/test_mixins/test_set_commands.py b/test/test_mixins/test_set_commands.py index 65221c47..0b3242e3 100644 --- a/test/test_mixins/test_set_commands.py +++ b/test/test_mixins/test_set_commands.py @@ -21,6 +21,17 @@ def test_saddex(r: redis.Redis): assert set(r.smembers("foo")) == set() +@pytest.mark.slow +@pytest.mark.unsupported_server_types("redis", "valkey") +def test_saddex_expire_members(r: redis.Redis): + set_name = "foo" + assert testtools.raw_command(r, "saddex", set_name, 1, "m1", "m2") == 2 + assert r.sadd(set_name, "m3", "m4") == 2 + assert testtools.raw_command(r, "saddex", set_name, 1, "m3") == 0 + sleep(1.1) + assert set(r.smembers("foo")) == {b"m3", b"m4"} + + @testtools.run_test_if_redispy_ver("gte", "5.1") def test_sadd(r: redis.Redis): assert r.sadd("foo", "member1") == 1