diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ed4e0b68e..bb59b8dcac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Node: Added ZRANGE command ([#1115](https://github.com/aws/glide-for-redis/pull/1115)) * Python: Added RENAME command ([#1252](https://github.com/aws/glide-for-redis/pull/1252)) * Python: Added APPEND command ([#1152](https://github.com/aws/glide-for-redis/pull/1152)) +* Python: Added GEOADD command ([#1259](https://github.com/aws/glide-for-redis/pull/1259)) #### Fixes * Python: Fix typing error "‘type’ object is not subscriptable" ([#1203](https://github.com/aws/glide-for-redis/pull/1203)) diff --git a/glide-core/src/protobuf/redis_request.proto b/glide-core/src/protobuf/redis_request.proto index 4369586428..00498119e1 100644 --- a/glide-core/src/protobuf/redis_request.proto +++ b/glide-core/src/protobuf/redis_request.proto @@ -160,6 +160,7 @@ enum RequestType { SMove = 117; SMIsMember = 118; LastSave = 120; + GeoAdd = 121; } message Command { diff --git a/glide-core/src/request_type.rs b/glide-core/src/request_type.rs index 5c1329fb2a..6bd14cd2f6 100644 --- a/glide-core/src/request_type.rs +++ b/glide-core/src/request_type.rs @@ -128,6 +128,7 @@ pub enum RequestType { SMove = 117, SMIsMember = 118, LastSave = 120, + GeoAdd = 121, } fn get_two_word_command(first: &str, second: &str) -> Cmd { @@ -259,6 +260,7 @@ impl From<::protobuf::EnumOrUnknown> for RequestType { ProtobufRequestType::SMove => RequestType::SMove, ProtobufRequestType::SMIsMember => RequestType::SMIsMember, ProtobufRequestType::LastSave => RequestType::LastSave, + ProtobufRequestType::GeoAdd => RequestType::GeoAdd, } } } @@ -386,6 +388,7 @@ impl RequestType { RequestType::SMove => Some(cmd("SMOVE")), RequestType::SMIsMember => Some(cmd("SMISMEMBER")), RequestType::LastSave => Some(cmd("LASTSAVE")), + RequestType::GeoAdd => Some(cmd("GEOADD")), } } } diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 2750f617fd..83828fea88 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -5,6 +5,7 @@ ExpireOptions, ExpirySet, ExpiryType, + GeospatialData, InfoSection, UpdateOptions, ) @@ -56,6 +57,7 @@ "RedisClientConfiguration", "ScoreBoundary", "ConditionalChange", + "GeospatialData", "ExpireOptions", "ExpirySet", "ExpiryType", diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 581948347e..39bc1c2604 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -33,7 +33,7 @@ class ConditionalChange(Enum): """ - A condition to the "SET" and "ZADD" commands. + A condition to the `SET`, `ZADD` and `GEOADD` commands. - ONLY_IF_EXISTS - Only update key / elements that already exist. Equivalent to `XX` in the Redis API - ONLY_IF_DOES_NOT_EXIST - Only set key / add elements that does not already exist. Equivalent to `NX` in the Redis API """ @@ -131,6 +131,23 @@ class UpdateOptions(Enum): GREATER_THAN = "GT" +class GeospatialData: + def __init__(self, longitude: float, latitude: float): + """ + Represents a geographic position defined by longitude and latitude. + + The exact limits, as specified by EPSG:900913 / EPSG:3785 / OSGEO:41001 are the following: + - Valid longitudes are from -180 to 180 degrees. + - Valid latitudes are from -85.05112878 to 85.05112878 degrees. + + Args: + longitude (float): The longitude coordinate. + latitude (float): The latitude coordinate. + """ + self.longitude = longitude + self.latitude = latitude + + class ExpirySet: """SET option: Represents the expiry type and value to be executed with "SET" command.""" @@ -1522,6 +1539,57 @@ async def type(self, key: str) -> str: """ return cast(str, await self._execute_command(RequestType.Type, [key])) + async def geoadd( + self, + key: str, + members_geospatialdata: Mapping[str, GeospatialData], + existing_options: Optional[ConditionalChange] = None, + changed: bool = False, + ) -> int: + """ + Adds geospatial members with their positions to the specified sorted set stored at `key`. + If a member is already a part of the sorted set, its position is updated. + + See https://valkey.io/commands/geoadd for more details. + + Args: + key (str): The key of the sorted set. + members_geospatialdata (Mapping[str, GeospatialData]): A mapping of member names to their corresponding positions. See `GeospatialData`. + The command will report an error when the user attempts to index coordinates outside the specified ranges. + existing_options (Optional[ConditionalChange]): Options for handling existing members. + - NX: Only add new elements. + - XX: Only update existing elements. + changed (bool): Modify the return value to return the number of changed elements, instead of the number of new elements added. + + Returns: + int: The number of elements added to the sorted set. + If `changed` is set, returns the number of elements updated in the sorted set. + + Examples: + >>> await client.geoadd("my_sorted_set", {"Palermo": GeospatialData(13.361389, 38.115556), "Catania": GeospatialData(15.087269, 37.502669)}) + 2 # Indicates that two elements have been added to the sorted set "my_sorted_set". + >>> await client.geoadd("my_sorted_set", {"Palermo": GeospatialData(14.361389, 38.115556)}, existing_options=ConditionalChange.XX, changed=True) + 1 # Updates the position of an existing member in the sorted set "my_sorted_set". + """ + args = [key] + if existing_options: + args.append(existing_options.value) + + if changed: + args.append("CH") + + members_geospatialdata_list = [ + coord + for member, position in members_geospatialdata.items() + for coord in [str(position.longitude), str(position.latitude), member] + ] + args += members_geospatialdata_list + + return cast( + int, + await self._execute_command(RequestType.GeoAdd, args), + ) + async def zadd( self, key: str, diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index adc979ec79..7445966e10 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -7,6 +7,7 @@ ConditionalChange, ExpireOptions, ExpirySet, + GeospatialData, InfoSection, UpdateOptions, ) @@ -1164,6 +1165,48 @@ def type(self: TTransaction, key: str) -> TTransaction: """ return self.append_command(RequestType.Type, [key]) + def geoadd( + self: TTransaction, + key: str, + members_geospatialdata: Mapping[str, GeospatialData], + existing_options: Optional[ConditionalChange] = None, + changed: bool = False, + ) -> TTransaction: + """ + Adds geospatial members with their positions to the specified sorted set stored at `key`. + If a member is already a part of the sorted set, its position is updated. + + See https://valkey.io/commands/geoadd for more details. + + Args: + key (str): The key of the sorted set. + members_geospatialdata (Mapping[str, GeospatialData]): A mapping of member names to their corresponding positions. See `GeospatialData`. + The command will report an error when the user attempts to index coordinates outside the specified ranges. + existing_options (Optional[ConditionalChange]): Options for handling existing members. + - NX: Only add new elements. + - XX: Only update existing elements. + changed (bool): Modify the return value to return the number of changed elements, instead of the number of new elements added. + + Commands response: + int: The number of elements added to the sorted set. + If `changed` is set, returns the number of elements updated in the sorted set. + """ + args = [key] + if existing_options: + args.append(existing_options.value) + + if changed: + args.append("CH") + + members_geospatialdata_list = [ + coord + for member, position in members_geospatialdata.items() + for coord in [str(position.longitude), str(position.latitude), member] + ] + args += members_geospatialdata_list + + return self.append_command(RequestType.GeoAdd, args) + def zadd( self: TTransaction, key: str, diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index ea4ec3ff95..a6fe0dda62 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -16,6 +16,7 @@ ExpireOptions, ExpirySet, ExpiryType, + GeospatialData, InfBound, InfoSection, UpdateOptions, @@ -1213,6 +1214,67 @@ async def test_persist(self, redis_client: TRedisClient): assert await redis_client.expire(key, 10) assert await redis_client.persist(key) + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_geoadd(self, redis_client: TRedisClient): + key, key2 = get_random_string(10), get_random_string(10) + members_coordinates = { + "Palermo": GeospatialData(13.361389, 38.115556), + "Catania": GeospatialData(15.087269, 37.502669), + } + assert await redis_client.geoadd(key, members_coordinates) == 2 + members_coordinates["Catania"].latitude = 39 + assert ( + await redis_client.geoadd( + key, + members_coordinates, + existing_options=ConditionalChange.ONLY_IF_DOES_NOT_EXIST, + ) + == 0 + ) + assert ( + await redis_client.geoadd( + key, + members_coordinates, + existing_options=ConditionalChange.ONLY_IF_EXISTS, + ) + == 0 + ) + members_coordinates["Catania"].latitude = 40 + members_coordinates.update({"Tel-Aviv": GeospatialData(32.0853, 34.7818)}) + assert ( + await redis_client.geoadd( + key, + members_coordinates, + changed=True, + ) + == 2 + ) + + assert await redis_client.set(key2, "value") == OK + with pytest.raises(RequestError): + await redis_client.geoadd(key2, members_coordinates) + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_geoadd_invalid_args(self, redis_client: TRedisClient): + key = get_random_string(10) + + with pytest.raises(RequestError): + await redis_client.geoadd(key, {}) + + with pytest.raises(RequestError): + await redis_client.geoadd(key, {"Place": GeospatialData(-181, 0)}) + + with pytest.raises(RequestError): + await redis_client.geoadd(key, {"Place": GeospatialData(181, 0)}) + + with pytest.raises(RequestError): + await redis_client.geoadd(key, {"Place": GeospatialData(0, 86)}) + + with pytest.raises(RequestError): + await redis_client.geoadd(key, {"Place": GeospatialData(0, -86)}) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_zadd_zaddincr(self, redis_client: TRedisClient): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index 0fab703f50..ca60e05779 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -5,6 +5,7 @@ import pytest from glide import RequestError +from glide.async_commands.core import GeospatialData from glide.async_commands.sorted_set import InfBound, RangeByIndex, ScoreBoundary from glide.async_commands.transaction import ( BaseTransaction, @@ -198,6 +199,15 @@ async def transaction_test( args.append({"four": 4}) transaction.zremrangebyscore(key8, InfBound.NEG_INF, InfBound.POS_INF) args.append(1) + + transaction.geoadd( + key9, + { + "Palermo": GeospatialData(13.361389, 38.115556), + "Catania": GeospatialData(15.087269, 37.502669), + }, + ) + args.append(2) return args