diff --git a/CHANGELOG.md b/CHANGELOG.md index 0751d08c4d..186e18b7e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ * Python: Added XLEN command ([#1503](https://github.com/aws/glide-for-redis/pull/1503)) * Python: Added LASTSAVE command ([#1509](https://github.com/aws/glide-for-redis/pull/1509)) * Python: Added GETDEL command ([#1514](https://github.com/aws/glide-for-redis/pull/1514)) +* Python: Added ZINTER, ZUNION commands ([#1478](https://github.com/aws/glide-for-redis/pull/1478)) ## 0.4.1 (2024-02-06) diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 1b4aa48a0e..1e1ec39cf6 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -25,7 +25,7 @@ RangeByScore, ScoreBoundary, ScoreFilter, - _create_z_cmd_store_args, + _create_zinter_zunion_cmd_args, _create_zrange_args, ) from glide.constants import TOK, TResult @@ -3228,6 +3228,75 @@ async def zdiffstore(self, destination: str, keys: List[str]) -> int: ), ) + async def zinter( + self, + keys: List[str], + ) -> List[str]: + """ + Computes the intersection of sorted sets given by the specified `keys` and returns a list of intersecting elements. + To get the scores as well, see `zinter_withscores`. + To store the result in a key as a sorted set, see `zinterstore`. + + When in cluster mode, all keys in `keys` must map to the same hash slot. + + See https://valkey.io/commands/zinter/ for more details. + + Args: + keys (List[str]): The keys of the sorted sets. + + Returns: + List[str]: The resulting array of intersecting elements. + + Examples: + >>> await client.zadd("key1", {"member1": 10.5, "member2": 8.2}) + >>> await client.zadd("key2", {"member1": 9.5}) + >>> await client.zinter(["key1", "key2"]) + ['member1'] + """ + return cast( + List[str], + await self._execute_command(RequestType.ZInter, [str(len(keys))] + keys), + ) + + async def zinter_withscores( + self, + keys: Union[List[str], List[Tuple[str, float]]], + aggregation_type: Optional[AggregationType] = None, + ) -> Mapping[str, float]: + """ + Computes the intersection of sorted sets given by the specified `keys` and returns a sorted set of intersecting elements with scores. + To get the elements only, see `zinter`. + To store the result in a key as a sorted set, see `zinterstore`. + + When in cluster mode, all keys in `keys` must map to the same hash slot. + + See https://valkey.io/commands/zinter/ for more details. + + Args: + keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: + List[str] - for keys only. + List[Tuple[str, float]] - for weighted keys with score multipliers. + aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply + when combining the scores of elements. See `AggregationType`. + + Returns: + Mapping[str, float]: The resulting sorted set with scores. + + Examples: + >>> await client.zadd("key1", {"member1": 10.5, "member2": 8.2}) + >>> await client.zadd("key2", {"member1": 9.5}) + >>> await client.zinter_withscores(["key1", "key2"]) + {'member1': 20} # "member1" with score of 20 is the result + >>> await client.zinter_withscores(["key1", "key2"], AggregationType.MAX) + {'member1': 10.5} # "member1" with score of 10.5 is the result. + """ + args = _create_zinter_zunion_cmd_args(keys, aggregation_type) + args.append("WITHSCORES") + return cast( + Mapping[str, float], + await self._execute_command(RequestType.ZInter, args), + ) + async def zinterstore( self, destination: str, @@ -3237,6 +3306,7 @@ async def zinterstore( """ Computes the intersection of sorted sets given by the specified `keys` and stores the result in `destination`. If `destination` already exists, it is overwritten. Otherwise, a new sorted set will be created. + To get the result directly, see `zinter_withscores`. When in cluster mode, `destination` and all keys in `keys` must map to the same hash slot. @@ -3246,7 +3316,7 @@ async def zinterstore( destination (str): The key of the destination sorted set. keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: List[str] - for keys only. - List[Tuple[str, float]]] - for weighted keys with score multipliers. + List[Tuple[str, float]] - for weighted keys with score multipliers. aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply when combining the scores of elements. See `AggregationType`. @@ -3259,18 +3329,87 @@ async def zinterstore( >>> await client.zinterstore("my_sorted_set", ["key1", "key2"]) 1 # Indicates that the sorted set "my_sorted_set" contains one element. >>> await client.zrange_withscores("my_sorted_set", RangeByIndex(0, -1)) - {'member1': 20} # "member1" is now stored in "my_sorted_set" with score of 20. - >>> await client.zinterstore("my_sorted_set", ["key1", "key2"] , AggregationType.MAX ) - 1 # Indicates that the sorted set "my_sorted_set" contains one element, and it's score is the maximum score between the sets. + {'member1': 20} # "member1" is now stored in "my_sorted_set" with score of 20. + >>> await client.zinterstore("my_sorted_set", ["key1", "key2"], AggregationType.MAX) + 1 # Indicates that the sorted set "my_sorted_set" contains one element, and its score is the maximum score between the sets. >>> await client.zrange_withscores("my_sorted_set", RangeByIndex(0, -1)) - {'member1': 10.5} # "member1" is now stored in "my_sorted_set" with score of 10.5. + {'member1': 10.5} # "member1" is now stored in "my_sorted_set" with score of 10.5. """ - args = _create_z_cmd_store_args(destination, keys, aggregation_type) + args = _create_zinter_zunion_cmd_args(keys, aggregation_type, destination) return cast( int, await self._execute_command(RequestType.ZInterStore, args), ) + async def zunion( + self, + keys: List[str], + ) -> List[str]: + """ + Computes the union of sorted sets given by the specified `keys` and returns a list of union elements. + To get the scores as well, see `zunion_withscores`. + To store the result in a key as a sorted set, see `zunionstore`. + + When in cluster mode, all keys in `keys` must map to the same hash slot. + + See https://valkey.io/commands/zunion/ for more details. + + Args: + keys (List[str]): The keys of the sorted sets. + + Returns: + List[str]: The resulting array of union elements. + + Examples: + >>> await client.zadd("key1", {"member1": 10.5, "member2": 8.2}) + >>> await client.zadd("key2", {"member1": 9.5}) + >>> await client.zunion(["key1", "key2"]) + ['member1', 'member2'] + """ + return cast( + List[str], + await self._execute_command(RequestType.ZUnion, [str(len(keys))] + keys), + ) + + async def zunion_withscores( + self, + keys: Union[List[str], List[Tuple[str, float]]], + aggregation_type: Optional[AggregationType] = None, + ) -> Mapping[str, float]: + """ + Computes the union of sorted sets given by the specified `keys` and returns a sorted set of union elements with scores. + To get the elements only, see `zunion`. + To store the result in a key as a sorted set, see `zunionstore`. + + When in cluster mode, all keys in `keys` must map to the same hash slot. + + See https://valkey.io/commands/zunion/ for more details. + + Args: + keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: + List[str] - for keys only. + List[Tuple[str, float]] - for weighted keys with score multipliers. + aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply + when combining the scores of elements. See `AggregationType`. + + Returns: + Mapping[str, float]: The resulting sorted set with scores. + + Examples: + >>> await client.zadd("key1", {"member1": 10.5, "member2": 8.2}) + >>> await client.zadd("key2", {"member1": 9.5}) + >>> await client.zunion_withscores(["key1", "key2"]) + {'member1': 20, 'member2': 8.2} + >>> await client.zunion_withscores(["key1", "key2"], AggregationType.MAX) + {'member1': 10.5, 'member2': 8.2} + """ + args = _create_zinter_zunion_cmd_args(keys, aggregation_type) + args.append("WITHSCORES") + return cast( + Mapping[str, float], + await self._execute_command(RequestType.ZUnion, args), + ) + async def zunionstore( self, destination: str, @@ -3280,16 +3419,17 @@ async def zunionstore( """ Computes the union of sorted sets given by the specified `keys` and stores the result in `destination`. If `destination` already exists, it is overwritten. Otherwise, a new sorted set will be created. + To get the result directly, see `zunion_withscores`. When in cluster mode, `destination` and all keys in `keys` must map to the same hash slot. - see https://valkey.io/commands/zunionstore/ for more details. + See https://valkey.io/commands/zunionstore/ for more details. Args: destination (str): The key of the destination sorted set. keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: List[str] - for keys only. - List[Tuple[str, float]]] - for weighted keys with score multipliers. + List[Tuple[str, float]] - for weighted keys with score multipliers. aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply when combining the scores of elements. See `AggregationType`. @@ -3300,15 +3440,15 @@ async def zunionstore( >>> await client.zadd("key1", {"member1": 10.5, "member2": 8.2}) >>> await client.zadd("key2", {"member1": 9.5}) >>> await client.zunionstore("my_sorted_set", ["key1", "key2"]) - 2 # Indicates that the sorted set "my_sorted_set" contains two element. + 2 # Indicates that the sorted set "my_sorted_set" contains two elements. >>> await client.zrange_withscores("my_sorted_set", RangeByIndex(0, -1)) {'member1': 20, 'member2': 8.2} - >>> await client.zunionstore("my_sorted_set", ["key1", "key2"] , AggregationType.MAX ) - 2 # Indicates that the sorted set "my_sorted_set" contains two element, and each score is the maximum score between the sets. + >>> await client.zunionstore("my_sorted_set", ["key1", "key2"], AggregationType.MAX) + 2 # Indicates that the sorted set "my_sorted_set" contains two elements, and each score is the maximum score between the sets. >>> await client.zrange_withscores("my_sorted_set", RangeByIndex(0, -1)) {'member1': 10.5, 'member2': 8.2} """ - args = _create_z_cmd_store_args(destination, keys, aggregation_type) + args = _create_zinter_zunion_cmd_args(keys, aggregation_type, destination) return cast( int, await self._execute_command(RequestType.ZUnionStore, args), diff --git a/python/python/glide/async_commands/sorted_set.py b/python/python/glide/async_commands/sorted_set.py index 9fff352159..7ac92f1e99 100644 --- a/python/python/glide/async_commands/sorted_set.py +++ b/python/python/glide/async_commands/sorted_set.py @@ -220,12 +220,17 @@ def separate_keys( return key_list, weight_list -def _create_z_cmd_store_args( - destination: str, +def _create_zinter_zunion_cmd_args( keys: Union[List[str], List[Tuple[str, float]]], aggregation_type: Optional[AggregationType] = None, + destination: Optional[str] = None, ) -> List[str]: - args = [destination, str(len(keys))] + args = [] + + if destination: + args.append(destination) + + args.append(str(len(keys))) only_keys, weights = separate_keys(keys) diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index b2679fb722..5e0cf43dec 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -24,7 +24,7 @@ RangeByScore, ScoreBoundary, ScoreFilter, - _create_z_cmd_store_args, + _create_zinter_zunion_cmd_args, _create_zrange_args, ) from glide.protobuf.redis_request_pb2 import RequestType @@ -2272,6 +2272,47 @@ def zdiffstore( RequestType.ZDiffStore, [destination, str(len(keys))] + keys ) + def zinter( + self: TTransaction, + keys: List[str], + ) -> TTransaction: + """ + Computes the intersection of sorted sets given by the specified `keys` and returns a list of intersecting elements. + + See https://valkey.io/commands/zinter/ for more details. + + Args: + keys (List[str]): The keys of the sorted sets. + + Command response: + List[str]: The resulting array of intersecting elements. + """ + return self.append_command(RequestType.ZInter, [str(len(keys))] + keys) + + def zinter_withscores( + self: TTransaction, + keys: Union[List[str], List[Tuple[str, float]]], + aggregation_type: Optional[AggregationType] = None, + ) -> TTransaction: + """ + Computes the intersection of sorted sets given by the specified `keys` and returns a sorted set of intersecting elements with scores. + + See https://valkey.io/commands/zinter/ for more details. + + Args: + keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: + List[str] - for keys only. + List[Tuple[str, float]] - for weighted keys with score multipliers. + aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply + when combining the scores of elements. See `AggregationType`. + + Command response: + Mapping[str, float]: The resulting sorted set with scores. + """ + args = _create_zinter_zunion_cmd_args(keys, aggregation_type) + args.append("WITHSCORES") + return self.append_command(RequestType.ZInter, args) + def zinterstore( self: TTransaction, destination: str, @@ -2297,9 +2338,50 @@ def zinterstore( Command response: int: The number of elements in the resulting sorted set stored at `destination`. """ - args = _create_z_cmd_store_args(destination, keys, aggregation_type) + args = _create_zinter_zunion_cmd_args(keys, aggregation_type, destination) return self.append_command(RequestType.ZInterStore, args) + def zunion( + self: TTransaction, + keys: List[str], + ) -> TTransaction: + """ + Computes the union of sorted sets given by the specified `keys` and returns a list of union elements. + + See https://valkey.io/commands/zunion/ for more details. + + Args: + keys (List[str]): The keys of the sorted sets. + + Command response: + List[str]: The resulting array of union elements. + """ + return self.append_command(RequestType.ZUnion, [str(len(keys))] + keys) + + def zunion_withscores( + self: TTransaction, + keys: Union[List[str], List[Tuple[str, float]]], + aggregation_type: Optional[AggregationType] = None, + ) -> TTransaction: + """ + Computes the union of sorted sets given by the specified `keys` and returns a sorted set of union elements with scores. + + See https://valkey.io/commands/zunion/ for more details. + + Args: + keys (Union[List[str], List[Tuple[str, float]]]): The keys of the sorted sets with possible formats: + List[str] - for keys only. + List[Tuple[str, float]] - for weighted keys with score multipliers. + aggregation_type (Optional[AggregationType]): Specifies the aggregation strategy to apply + when combining the scores of elements. See `AggregationType`. + + Command response: + Mapping[str, float]: The resulting sorted set with scores. + """ + args = _create_zinter_zunion_cmd_args(keys, aggregation_type) + args.append("WITHSCORES") + return self.append_command(RequestType.ZUnion, args) + def zunionstore( self: TTransaction, destination: str, @@ -2325,7 +2407,7 @@ def zunionstore( Command response: int: The number of elements in the resulting sorted set stored at `destination`. """ - args = _create_z_cmd_store_args(destination, keys, aggregation_type) + args = _create_zinter_zunion_cmd_args(keys, aggregation_type, destination) return self.append_command(RequestType.ZUnionStore, args) def zrandmember(self: TTransaction, key: str) -> TTransaction: diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 1049e31d4e..bae5cbaf80 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -2157,7 +2157,7 @@ async def test_zmscore(self, redis_client: TRedisClient): @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_zinterstore(self, redis_client: TRedisClient): + async def test_zinter_commands(self, redis_client: TRedisClient): key1 = "{testKey}:1-" + get_random_string(10) key2 = "{testKey}:2-" + get_random_string(10) key3 = "{testKey}:3-" + get_random_string(10) @@ -2168,48 +2168,74 @@ async def test_zinterstore(self, redis_client: TRedisClient): assert await redis_client.zadd(key1, members_scores1) == 2 assert await redis_client.zadd(key2, members_scores2) == 3 + # zinter tests + zinter_map = await redis_client.zinter([key1, key2]) + expected_zinter_map = ["one", "two"] + assert zinter_map == expected_zinter_map + + # zinterstore tests assert await redis_client.zinterstore(key3, [key1, key2]) == 2 zinterstore_map = await redis_client.zrange_withscores(key3, range) - expected_map = { + expected_zinter_map_withscores = { "one": 2.5, "two": 4.5, } - assert compare_maps(zinterstore_map, expected_map) is True + assert compare_maps(zinterstore_map, expected_zinter_map_withscores) is True + + # zinter_withscores tests + zinter_withscores_map = await redis_client.zinter_withscores([key1, key2]) + assert ( + compare_maps(zinter_withscores_map, expected_zinter_map_withscores) is True + ) - # Intersection results are aggregated by the MAX score of elements + # MAX aggregation tests assert ( await redis_client.zinterstore(key3, [key1, key2], AggregationType.MAX) == 2 ) zinterstore_map_max = await redis_client.zrange_withscores(key3, range) - expected_map_max = { + expected_zinter_map_max = { "one": 1.5, "two": 2.5, } - assert compare_maps(zinterstore_map_max, expected_map_max) is True + assert compare_maps(zinterstore_map_max, expected_zinter_map_max) is True - # Intersection results are aggregated by the MIN score of elements + zinter_withscores_map_max = await redis_client.zinter_withscores( + [key1, key2], AggregationType.MAX + ) + assert compare_maps(zinter_withscores_map_max, expected_zinter_map_max) is True + + # MIN aggregation tests assert ( await redis_client.zinterstore(key3, [key1, key2], AggregationType.MIN) == 2 ) zinterstore_map_min = await redis_client.zrange_withscores(key3, range) - expected_map_min = { + expected_zinter_map_min = { "one": 1.0, "two": 2.0, } - assert compare_maps(zinterstore_map_min, expected_map_min) is True + assert compare_maps(zinterstore_map_min, expected_zinter_map_min) is True + + zinter_withscores_map_min = await redis_client.zinter_withscores( + [key1, key2], AggregationType.MIN + ) + assert compare_maps(zinter_withscores_map_min, expected_zinter_map_min) is True - # Intersection results are aggregated by the SUM score of elements + # SUM aggregation tests assert ( await redis_client.zinterstore(key3, [key1, key2], AggregationType.SUM) == 2 ) zinterstore_map_sum = await redis_client.zrange_withscores(key3, range) - expected_map_sum = { - "one": 2.5, - "two": 4.5, - } - assert compare_maps(zinterstore_map_sum, expected_map_sum) is True + assert compare_maps(zinterstore_map_sum, expected_zinter_map_withscores) is True + + zinter_withscores_map_sum = await redis_client.zinter_withscores( + [key1, key2], AggregationType.SUM + ) + assert ( + compare_maps(zinter_withscores_map_sum, expected_zinter_map_withscores) + is True + ) - # Scores are multiplied by 2.0 for key1 and key2 during aggregation. + # Multiplying scores during aggregation tests assert ( await redis_client.zinterstore( key3, [(key1, 2.0), (key2, 2.0)], AggregationType.SUM @@ -2217,25 +2243,51 @@ async def test_zinterstore(self, redis_client: TRedisClient): == 2 ) zinterstore_map_multiplied = await redis_client.zrange_withscores(key3, range) - expected_map_multiplied = { + expected_zinter_map_multiplied = { "one": 5.0, "two": 9.0, } - assert compare_maps(zinterstore_map_multiplied, expected_map_multiplied) is True + assert ( + compare_maps(zinterstore_map_multiplied, expected_zinter_map_multiplied) + is True + ) + + zinter_withscores_map_multiplied = await redis_client.zinter_withscores( + [(key1, 2.0), (key2, 2.0)], AggregationType.SUM + ) + assert ( + compare_maps( + zinter_withscores_map_multiplied, expected_zinter_map_multiplied + ) + is True + ) + # Non-existing key test assert ( await redis_client.zinterstore(key3, [key1, "{testKey}-non_existing_key"]) == 0 ) + zinter_withscores_non_existing = await redis_client.zinter_withscores( + [key1, "{testKey}-non_existing_key"] + ) + assert zinter_withscores_non_existing == {} # Empty list check with pytest.raises(RequestError) as e: await redis_client.zinterstore("{xyz}", []) assert "wrong number of arguments" in str(e) + with pytest.raises(RequestError) as e: + await redis_client.zinter([]) + assert "wrong number of arguments" in str(e) + + with pytest.raises(RequestError) as e: + await redis_client.zinter_withscores([]) + assert "at least 1 input key is needed" in str(e) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) - async def test_zunionstore(self, redis_client: TRedisClient): + async def test_zunion_commands(self, redis_client: TRedisClient): key1 = "{testKey}:1-" + get_random_string(10) key2 = "{testKey}:2-" + get_random_string(10) key3 = "{testKey}:3-" + get_random_string(10) @@ -2246,76 +2298,146 @@ async def test_zunionstore(self, redis_client: TRedisClient): assert await redis_client.zadd(key1, members_scores1) == 2 assert await redis_client.zadd(key2, members_scores2) == 3 + # zunion tests + zunion_map = await redis_client.zunion([key1, key2]) + expected_zunion_map = ["one", "three", "two"] + assert zunion_map == expected_zunion_map + + # zunionstore tests assert await redis_client.zunionstore(key3, [key1, key2]) == 3 zunionstore_map = await redis_client.zrange_withscores(key3, range) - expected_map = { + expected_zunion_map_withscores = { "one": 2.5, "three": 3.5, "two": 4.5, } - assert compare_maps(zunionstore_map, expected_map) is True + assert compare_maps(zunionstore_map, expected_zunion_map_withscores) is True + + # zunion_withscores tests + zunion_withscores_map = await redis_client.zunion_withscores([key1, key2]) + assert ( + compare_maps(zunion_withscores_map, expected_zunion_map_withscores) is True + ) - # Intersection results are aggregated by the MAX score of elements + # MAX aggregation tests assert ( await redis_client.zunionstore(key3, [key1, key2], AggregationType.MAX) == 3 ) zunionstore_map_max = await redis_client.zrange_withscores(key3, range) - expected_map_max = { + expected_zunion_map_max = { "one": 1.5, "two": 2.5, "three": 3.5, } - assert compare_maps(zunionstore_map_max, expected_map_max) is True + assert compare_maps(zunionstore_map_max, expected_zunion_map_max) is True + + zunion_withscores_map_max = await redis_client.zunion_withscores( + [key1, key2], AggregationType.MAX + ) + assert compare_maps(zunion_withscores_map_max, expected_zunion_map_max) is True - # Intersection results are aggregated by the MIN score of elements + # MIN aggregation tests assert ( await redis_client.zunionstore(key3, [key1, key2], AggregationType.MIN) == 3 ) zunionstore_map_min = await redis_client.zrange_withscores(key3, range) - expected_map_min = { + expected_zunion_map_min = { "one": 1.0, "two": 2.0, "three": 3.5, } - assert compare_maps(zunionstore_map_min, expected_map_min) is True + assert compare_maps(zunionstore_map_min, expected_zunion_map_min) is True + + zunion_withscores_map_min = await redis_client.zunion_withscores( + [key1, key2], AggregationType.MIN + ) + assert compare_maps(zunion_withscores_map_min, expected_zunion_map_min) is True - # Intersection results are aggregated by the SUM score of elements + # SUM aggregation tests assert ( await redis_client.zunionstore(key3, [key1, key2], AggregationType.SUM) == 3 ) zunionstore_map_sum = await redis_client.zrange_withscores(key3, range) - expected_map_sum = { - "one": 2.5, - "three": 3.5, - "two": 4.5, - } - assert compare_maps(zunionstore_map_sum, expected_map_sum) is True + assert compare_maps(zunionstore_map_sum, expected_zunion_map_withscores) is True - # Scores are multiplied by 2.0 for key1 and key2 during aggregation. + zunion_withscores_map_sum = await redis_client.zunion_withscores( + [key1, key2], AggregationType.SUM + ) + assert ( + compare_maps(zunion_withscores_map_sum, expected_zunion_map_withscores) + is True + ) + + # Multiplying scores during aggregation tests assert ( await redis_client.zunionstore( key3, [(key1, 2.0), (key2, 2.0)], AggregationType.SUM ) == 3 ) - zunionstore_map = await redis_client.zrange_withscores(key3, range) - expected_map = { + zunionstore_map_multiplied = await redis_client.zrange_withscores(key3, range) + expected_zunion_map_multiplied = { "one": 5.0, "three": 7.0, "two": 9.0, } - assert compare_maps(zunionstore_map, expected_map) is True + assert ( + compare_maps(zunionstore_map_multiplied, expected_zunion_map_multiplied) + is True + ) + + zunion_withscores_map_multiplied = await redis_client.zunion_withscores( + [(key1, 2.0), (key2, 2.0)], AggregationType.SUM + ) + assert ( + compare_maps( + zunion_withscores_map_multiplied, expected_zunion_map_multiplied + ) + is True + ) + # Non-existing key test assert ( await redis_client.zunionstore(key3, [key1, "{testKey}-non_existing_key"]) == 2 ) + zunionstore_map_nonexistingkey = await redis_client.zrange_withscores( + key3, range + ) + expected_zunion_map_nonexistingkey = { + "one": 1.0, + "two": 2.0, + } + assert ( + compare_maps( + zunionstore_map_nonexistingkey, expected_zunion_map_nonexistingkey + ) + is True + ) + + zunion_withscores_non_existing = await redis_client.zunion_withscores( + [key1, "{testKey}-non_existing_key"] + ) + assert ( + compare_maps( + zunion_withscores_non_existing, expected_zunion_map_nonexistingkey + ) + is True + ) # Empty list check with pytest.raises(RequestError) as e: await redis_client.zunionstore("{xyz}", []) assert "wrong number of arguments" in str(e) + with pytest.raises(RequestError) as e: + await redis_client.zunion([]) + assert "wrong number of arguments" in str(e) + + with pytest.raises(RequestError) as e: + await redis_client.zunion_withscores([]) + assert "at least 1 input key is needed" in str(e) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_zpopmin(self, redis_client: TRedisClient): @@ -3618,6 +3740,10 @@ async def test_multi_key_command_returns_cross_slot_error( redis_client.renamenx("abc", "def"), redis_client.pfcount(["def", "ghi"]), redis_client.pfmerge("abc", ["def", "ghi"]), + redis_client.zinter(["def", "ghi"]), + redis_client.zinter_withscores(["def", "ghi"]), + redis_client.zunion(["def", "ghi"]), + redis_client.zunion_withscores(["def", "ghi"]), ] if not await check_if_server_version_lt(redis_client, "7.0.0"): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index c98d4b09df..373412707b 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -305,8 +305,16 @@ async def transaction_test( args.append(2) transaction.zadd(key15, {"one": 1.0, "two": 2.0, "three": 3.5}) args.append(3) + transaction.zinter([key14, key15]) + args.append(["one", "two"]) + transaction.zinter_withscores([key14, key15]) + args.append({"one": 2.0, "two": 4.0}) transaction.zinterstore(key8, [key14, key15]) args.append(2) + transaction.zunion([key14, key15]) + args.append(["one", "three", "two"]) + transaction.zunion_withscores([key14, key15]) + args.append({"one": 2.0, "two": 4.0, "three": 3.5}) transaction.zunionstore(key8, [key14, key15], AggregationType.MAX) args.append(3)