Skip to content

Commit

Permalink
SADDEX - implement set with support for expiring members (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla authored Dec 18, 2024
1 parent 0afe9a5 commit f006b55
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 116 deletions.
1 change: 1 addition & 0 deletions docs/about/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ toc_depth: 2
### 🚀 Features

- Add support disable_decoding in async read_response #349
- Implement support for `SADDEX`, using a new set implementation with support for expiring members #350

## v2.26.2

Expand Down
7 changes: 3 additions & 4 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import redis
from redis.connection import DefaultParser

from fakeredis.model import XStream
from fakeredis.model import ZSet
from fakeredis.model import XStream, ZSet, Hash, ExpiringMembersSet
from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem
from ._helpers import (
SimpleError,
valid_response_type,
Expand Down Expand Up @@ -392,7 +391,7 @@ def _key_value_type(key: CommandItem) -> SimpleString:
return SimpleString(b"string")
elif isinstance(key.value, list):
return SimpleString(b"list")
elif isinstance(key.value, set):
elif isinstance(key.value, ExpiringMembersSet):
return SimpleString(b"set")
elif isinstance(key.value, ZSet):
return SimpleString(b"zset")
Expand Down
80 changes: 2 additions & 78 deletions fakeredis/_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import re
import sys
import time
from typing import Iterable, Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection
from typing import Tuple, Union, Optional, Any, Type, List, Callable, Sequence, Dict, Set, Collection

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError, Database, current_time
from ._helpers import null_terminate, SimpleError, Database

MAX_STRING_SIZE = 512 * 1024 * 1024
SUPPORTED_COMMANDS: Dict[str, "Signature"] = dict() # Dictionary of supported commands name => Signature
Expand Down Expand Up @@ -107,82 +107,6 @@ def __bool__(self) -> bool:
__nonzero__ = __bool__ # For Python 2


class Hash:
DECODE_ERROR = msgs.INVALID_HASH_MSG
redis_type = b"hash"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._expirations: Dict[bytes, int] = dict()
self._values: Dict[bytes, Any] = dict()

def _expire_keys(self) -> None:
removed = []
now = current_time()
for k in self._expirations:
if self._expirations[k] < now:
self._values.pop(k, None)
removed.append(k)
for k in removed:
self._expirations.pop(k, None)

def set_key_expireat(self, key: bytes, when_ms: int) -> int:
now = current_time()
if when_ms <= now:
self._values.pop(key, None)
self._expirations.pop(key, None)
return 2
self._expirations[key] = when_ms
return 1

def clear_key_expireat(self, key: bytes) -> bool:
return self._expirations.pop(key, None) is not None

def get_key_expireat(self, key: bytes) -> Optional[int]:
self._expire_keys()
return self._expirations.get(key, None)

def __getitem__(self, key: bytes) -> Any:
self._expire_keys()
return self._values.get(key)

def __contains__(self, key: bytes) -> bool:
self._expire_keys()
return self._values.__contains__(key)

def __setitem__(self, key: bytes, value: Any) -> None:
self._expirations.pop(key, None)
self._values[key] = value

def __delitem__(self, key: bytes) -> None:
self._values.pop(key, None)
self._expirations.pop(key, None)

def __len__(self) -> int:
return len(self._values)

def __iter__(self) -> Iterable[bytes]:
return iter(self._values)

def get(self, key: bytes, default: Any = None) -> Any:
return self._values.get(key, default)

def keys(self) -> Iterable[bytes]:
self._expire_keys()
return self._values.keys()

def values(self) -> Iterable[Any]:
return [v for k, v in self.items()]

def items(self) -> Iterable[Tuple[bytes, Any]]:
self._expire_keys()
return self._values.items()

def update(self, values: Dict[bytes, Any]) -> None:
self._expire_keys()
self._values.update(values)


class RedisType:
@classmethod
def decode(cls, *args, **kwargs): # type:ignore
Expand Down
5 changes: 2 additions & 3 deletions fakeredis/commands_mixins/generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
CommandItem,
SortFloat,
delete_keys,
Hash,
)
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database, SimpleString
from fakeredis.model import ZSet
from fakeredis.model import ZSet, Hash, ExpiringMembersSet


class GenericCommandsMixin:
Expand Down Expand Up @@ -224,7 +223,7 @@ def scan(self, cursor, *args):

@command(name="SORT", fixed=(Key(),), repeat=(bytes,))
def sort(self, key, *args):
if key.value is not None and not isinstance(key.value, (set, list, ZSet)):
if key.value is not None and not isinstance(key.value, (ExpiringMembersSet, list, ZSet)):
raise SimpleError(msgs.WRONGTYPE_MSG)
(
asc,
Expand Down
3 changes: 2 additions & 1 deletion fakeredis/commands_mixins/hash_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Hash, Int, Float, CommandItem
from fakeredis._commands import command, Key, Int, Float, CommandItem
from fakeredis._helpers import SimpleError, OK, casematch, SimpleString
from fakeredis._helpers import current_time
from fakeredis.model import Hash


class HashCommandsMixin:
Expand Down
45 changes: 23 additions & 22 deletions fakeredis/commands_mixins/set_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
from fakeredis import _msgs as msgs
from fakeredis._commands import command, Key, Int, CommandItem
from fakeredis._helpers import OK, SimpleError, casematch, Database, SimpleString
from fakeredis.model import ExpiringMembersSet


def _calc_setop(op: Callable[..., Any], stop_if_missing: bool, key: CommandItem, *keys: CommandItem) -> Any:
if stop_if_missing and not key.value:
return set()
value = key.value
if not isinstance(value, set):
if not isinstance(value, ExpiringMembersSet):
raise SimpleError(msgs.WRONGTYPE_MSG)
ans = value.copy()
for other in keys:
value = other.value if other.value is not None else set()
if not isinstance(value, set):
value = other.value if other.value is not None else ExpiringMembersSet()
if not isinstance(value, ExpiringMembersSet):
raise SimpleError(msgs.WRONGTYPE_MSG)
if stop_if_missing and not value:
return set()
Expand Down Expand Up @@ -48,26 +49,26 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.version: Tuple[int]
self._db: Database

@command((Key(set), bytes), (bytes,))
@command((Key(ExpiringMembersSet), bytes), (bytes,))
def sadd(self, key: CommandItem, *members: bytes) -> int:
old_size = len(key.value)
key.value.update(members)
key.updated()
return len(key.value) - old_size

@command((Key(set),))
@command((Key(ExpiringMembersSet),))
def scard(self, key: CommandItem) -> int:
return len(key.value)

@command((Key(set),), (Key(set),))
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
def sdiff(self, *keys: CommandItem) -> Any:
return _setop(lambda a, b: a - b, False, None, *keys)

@command((Key(), Key(set)), (Key(set),))
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
def sdiffstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
return _setop(lambda a, b: a - b, False, dst, *keys)

@command((Key(set),), (Key(set),))
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
def sinter(self, *keys: CommandItem) -> Any:
res = _setop(lambda a, b: a & b, True, None, *keys)
return res
Expand All @@ -89,23 +90,23 @@ def sintercard(self, numkeys: int, *args: bytes) -> int:
res = _setop(lambda a, b: a & b, False, None, *keys)
return len(res) if limit == 0 else min(limit, len(res))

@command((Key(), Key(set)), (Key(set),))
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
def sinterstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
return _setop(lambda a, b: a & b, True, dst, *keys)

@command((Key(set), bytes))
@command((Key(ExpiringMembersSet), bytes))
def sismember(self, key: CommandItem, member: bytes) -> int:
return int(member in key.value)

@command((Key(set), bytes), (bytes,))
@command((Key(ExpiringMembersSet), bytes), (bytes,))
def smismember(self, key: CommandItem, *members: bytes) -> List[int]:
return [self.sismember(key, member) for member in members]

@command((Key(set),))
@command((Key(ExpiringMembersSet),))
def smembers(self, key: CommandItem) -> List[bytes]:
return list(key.value)

@command((Key(set, 0), Key(set), bytes))
@command((Key(ExpiringMembersSet, 0), Key(ExpiringMembersSet), bytes))
def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
try:
src.value.remove(member)
Expand All @@ -117,7 +118,7 @@ def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
dst.updated() # TODO: is it updated if member was already present?
return 1

@command((Key(set),), (Int,))
@command((Key(ExpiringMembersSet),), (Int,))
def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
if count is None:
if not key.value:
Expand All @@ -135,7 +136,7 @@ def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, Li
key.updated() # Inside the loop because redis special-cases count=0
return items

@command((Key(set),), (Int,))
@command((Key(ExpiringMembersSet),), (Int,))
def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
if count is None:
if not key.value:
Expand All @@ -149,7 +150,7 @@ def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[by
items = list(key.value)
return [random.choice(items) for _ in range(-count)]

@command((Key(set), bytes), (bytes,))
@command((Key(ExpiringMembersSet), bytes), (bytes,))
def srem(self, key: CommandItem, *members: bytes) -> int:
old_size = len(key.value)
for member in members:
Expand All @@ -159,15 +160,15 @@ def srem(self, key: CommandItem, *members: bytes) -> int:
key.updated()
return deleted

@command((Key(set), Int), (bytes, bytes))
@command((Key(ExpiringMembersSet), Int), (bytes, bytes))
def sscan(self, key: CommandItem, cursor: int, *args: bytes) -> Any:
return self._scan(key.value, cursor, *args)

@command((Key(set),), (Key(set),))
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
def sunion(self, *keys: CommandItem) -> Any:
return _setop(lambda a, b: a | b, False, None, *keys)

@command((Key(), Key(set)), (Key(set),))
@command((Key(), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
return _setop(lambda a, b: a | b, False, dst, *keys)

Expand All @@ -176,19 +177,19 @@ def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
# approximate and store the results in a string. Instead, it is implemented
# on top of sets.

@command((Key(set),), (bytes,))
@command((Key(ExpiringMembersSet),), (bytes,))
def pfadd(self, key: CommandItem, *elements: bytes) -> int:
result = self.sadd(key, *elements)
# Per the documentation:
# - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise.
return 1 if result > 0 else 0

@command((Key(set),), (Key(set),))
@command((Key(ExpiringMembersSet),), (Key(ExpiringMembersSet),))
def pfcount(self, *keys: CommandItem) -> int:
"""Return the approximated cardinality of the set observed by the HyperLogLog at key(s)."""
return len(self.sunion(*keys))

@command((Key(set), Key(set)), (Key(set),))
@command((Key(ExpiringMembersSet), Key(ExpiringMembersSet)), (Key(ExpiringMembersSet),))
def pfmerge(self, dest: CommandItem, *sources: CommandItem) -> SimpleString:
"""Merge N different HyperLogLogs into a single one."""
self.sunionstore(dest, *sources)
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/commands_mixins/sortedset_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
null_terminate,
Database,
)
from fakeredis.model import ZSet
from fakeredis.model import ZSet, ExpiringMembersSet

SORTED_SET_METHODS = {
"ZUNIONSTORE": lambda s1, s2: s1 | s2,
Expand Down Expand Up @@ -391,7 +391,7 @@ def zscore(self, key, member):

@staticmethod
def _get_zset(value):
if isinstance(value, set):
if isinstance(value, ExpiringMembersSet):
zset = ZSet()
for item in value:
zset[item] = 1.0
Expand Down
4 changes: 4 additions & 0 deletions fakeredis/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ._expiring_members_set import ExpiringMembersSet
from ._hash import Hash
from ._stream import XStream, StreamEntryKey, StreamGroup, StreamRangeTest
from ._timeseries_model import TimeSeries, TimeSeriesRule, AGGREGATORS
from ._topk import HeavyKeeper
Expand All @@ -13,4 +15,6 @@
"TimeSeriesRule",
"AGGREGATORS",
"HeavyKeeper",
"Hash",
"ExpiringMembersSet",
]
Loading

0 comments on commit f006b55

Please sign in to comment.