From 905a17b4fa9dec0ed3f9ad08b26437b97ba39323 Mon Sep 17 00:00:00 2001 From: Avi Fenesh <55848801+avifenesh@users.noreply.github.com> Date: Tue, 2 Jul 2024 20:44:12 +0300 Subject: [PATCH] Cluster Scan - glide-core and python (#1623) Scan command for Glide-Core and Py --- .github/workflows/lint-rust/action.yml | 2 + CHANGELOG.md | 1 + glide-core/Cargo.toml | 1 + glide-core/src/client/mod.rs | 49 +- glide-core/src/cluster_scan_container.rs | 65 +++ glide-core/src/lib.rs | 1 + glide-core/src/protobuf/redis_request.proto | 11 +- glide-core/src/request_type.rs | 3 + glide-core/src/socket_listener.rs | 85 ++- python/DEVELOPER.md | 2 + python/python/glide/__init__.py | 3 +- .../glide/async_commands/cluster_commands.py | 73 ++- .../glide/async_commands/command_args.py | 36 ++ python/python/glide/async_commands/core.py | 12 +- .../async_commands/standalone_commands.py | 67 ++- python/python/glide/glide.pyi | 5 + python/python/glide/glide_client.py | 31 ++ python/python/tests/test_scan.py | 485 ++++++++++++++++++ python/src/lib.rs | 39 +- submodules/redis-rs | 2 +- 20 files changed, 948 insertions(+), 25 deletions(-) create mode 100644 glide-core/src/cluster_scan_container.rs create mode 100644 python/python/tests/test_scan.py diff --git a/.github/workflows/lint-rust/action.yml b/.github/workflows/lint-rust/action.yml index 531acae1be..8a7cdf185f 100644 --- a/.github/workflows/lint-rust/action.yml +++ b/.github/workflows/lint-rust/action.yml @@ -24,6 +24,8 @@ runs: github-token: ${{ inputs.github-token }} - uses: Swatinem/rust-cache@v2 + with: + github-token: ${{ inputs.github-token }} - run: cargo fmt --all -- --check working-directory: ${{ inputs.cargo-toml-folder }} diff --git a/CHANGELOG.md b/CHANGELOG.md index b3fb2ae6e4..8ae99d7c34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,7 @@ * Python: Added FCALL_RO command ([#1721](https://github.com/aws/glide-for-redis/pull/1721)) * Python: Added WATCH and UNWATCH command ([#1736](https://github.com/aws/glide-for-redis/pull/1736)) * Python: Added LPos command ([#1740](https://github.com/aws/glide-for-redis/pull/1740)) +* Python: Added SCAN command ([#1623](https://github.com/aws/glide-for-redis/pull/1623)) ### Breaking Changes * Node: Update XREAD to return a Map of Map ([#1494](https://github.com/aws/glide-for-redis/pull/1494)) diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index 66fae92c9d..0f4a56330d 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -26,6 +26,7 @@ directories = { version = "4.0", optional = true } once_cell = "1.18.0" arcstr = "1.1.5" sha1_smol = "1.0.0" +nanoid = "0.4.0" [features] socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util"] diff --git a/glide-core/src/client/mod.rs b/glide-core/src/client/mod.rs index 03f4d5b75c..610d228f93 100644 --- a/glide-core/src/client/mod.rs +++ b/glide-core/src/client/mod.rs @@ -3,14 +3,14 @@ */ mod types; +use crate::cluster_scan_container::insert_cluster_scan_cursor; use crate::scripts_container::get_script; use futures::FutureExt; use logger_core::log_info; use redis::aio::ConnectionLike; use redis::cluster_async::ClusterConnection; use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo}; -use redis::{Cmd, ErrorKind, PushInfo, Value}; -use redis::{RedisError, RedisResult}; +use redis::{Cmd, ErrorKind, ObjectType, PushInfo, RedisError, RedisResult, ScanStateRC, Value}; pub use standalone_client::StandaloneClient; use std::io; use std::time::Duration; @@ -28,6 +28,7 @@ pub const DEFAULT_RESPONSE_TIMEOUT: Duration = Duration::from_millis(250); pub const DEFAULT_CONNECTION_ATTEMPT_TIMEOUT: Duration = Duration::from_millis(250); pub const DEFAULT_PERIODIC_CHECKS_INTERVAL: Duration = Duration::from_secs(60); pub const INTERNAL_CONNECTION_TIMEOUT: Duration = Duration::from_millis(250); +pub const FINISHED_SCAN_CURSOR: &str = "finished"; pub(super) fn get_port(address: &NodeAddress) -> u16 { const DEFAULT_PORT: u16 = 6379; @@ -245,6 +246,50 @@ impl Client { .boxed() } + // Cluster scan is not passed to redis-rs as a regular command, so we need to handle it separately. + // We send the command to a specific function in the redis-rs cluster client, which internally handles the + // the complication of a command scan, and generate the command base on the logic in the redis-rs library. + // + // The function returns a tuple with the cursor and the keys found in the scan. + // The cursor is not a regular cursor, but an ARC to a struct that contains the cursor and the data needed + // to continue the scan called ScanState. + // In order to avoid passing Rust GC to clean the ScanState when the cursor (ref) is passed to the wrapper, + // which means that Rust layer is not aware of the cursor anymore, we need to keep the ScanState alive. + // We do that by storing the ScanState in a global container, and return a cursor-id of the cursor to the wrapper. + // + // The wrapper create an object contain the cursor-id with a drop function that will remove the cursor from the container. + // When the ref is removed from the hash-map, there's no more references to the ScanState, and the GC will clean it. + pub async fn cluster_scan<'a>( + &'a mut self, + scan_state_cursor: &'a ScanStateRC, + match_pattern: &'a Option<&str>, + count: Option, + object_type: Option, + ) -> RedisResult { + match self.internal_client { + ClientWrapper::Standalone(_) => { + unreachable!("Cluster scan is not supported in standalone mode") + } + ClientWrapper::Cluster { ref mut client } => { + let (cursor, keys) = client + .cluster_scan( + scan_state_cursor.clone(), + *match_pattern, + count, + object_type, + ) + .await?; + + let cluster_cursor_id = if cursor.is_finished() { + Value::BulkString(FINISHED_SCAN_CURSOR.into()) + } else { + Value::BulkString(insert_cluster_scan_cursor(cursor).into()) + }; + Ok(Value::Array(vec![cluster_cursor_id, Value::Array(keys)])) + } + } + } + fn get_transaction_values( pipeline: &redis::Pipeline, mut values: Vec, diff --git a/glide-core/src/cluster_scan_container.rs b/glide-core/src/cluster_scan_container.rs new file mode 100644 index 0000000000..20c464a186 --- /dev/null +++ b/glide-core/src/cluster_scan_container.rs @@ -0,0 +1,65 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ +use logger_core::log_debug; +use nanoid::nanoid; +use once_cell::sync::Lazy; +use redis::{RedisResult, ScanStateRC}; +use std::{collections::HashMap, sync::Mutex}; + +// This is a container for storing the cursor of a cluster scan. +// The cursor for a cluster scan is a ref to the actual ScanState struct in redis-rs. +// In order to avoid dropping it when it is passed between layers of the application, +// we store it in this container and only pass the id of the cursor. +// The cursor is stored in the container and can be retrieved using the id. +// In wrapper layer we wrap the id in an object, which, when dropped, trigger the removal of the cursor from the container. +// When the ref is removed from the container, the actual ScanState struct is dropped by Rust GC. + +static CONTAINER: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +pub fn insert_cluster_scan_cursor(scan_state: ScanStateRC) -> String { + let id = nanoid!(); + CONTAINER.lock().unwrap().insert(id.clone(), scan_state); + log_debug( + "scan_state_cursor insert", + format!( + "Inserted to container scan_state_cursor with id: `{:?}`", + id + ), + ); + id +} + +pub fn get_cluster_scan_cursor(id: String) -> RedisResult { + let scan_state_rc = CONTAINER.lock().unwrap().get(&id).cloned(); + log_debug( + "scan_state_cursor get", + format!( + "Retrieved from container scan_state_cursor with id: `{:?}`", + id + ), + ); + match scan_state_rc { + Some(scan_state_rc) => Ok(scan_state_rc), + None => Err(redis::RedisError::from(( + redis::ErrorKind::ResponseError, + "Invalid scan_state_cursor id", + format!( + "The scan_state_cursor sent with id: `{:?}` does not exist", + id + ), + ))), + } +} + +pub fn remove_scan_state_cursor(id: String) { + log_debug( + "scan_state_cursor remove", + format!( + "Removed from container scan_state_cursor with id: `{:?}`", + id + ), + ); + CONTAINER.lock().unwrap().remove(&id); +} diff --git a/glide-core/src/lib.rs b/glide-core/src/lib.rs index 5bbc431e82..8da08e99f9 100644 --- a/glide-core/src/lib.rs +++ b/glide-core/src/lib.rs @@ -15,4 +15,5 @@ pub use socket_listener::*; pub mod errors; pub mod scripts_container; pub use client::ConnectionRequest; +pub mod cluster_scan_container; pub mod request_type; diff --git a/glide-core/src/protobuf/redis_request.proto b/glide-core/src/protobuf/redis_request.proto index dc1df57495..367b51dac3 100644 --- a/glide-core/src/protobuf/redis_request.proto +++ b/glide-core/src/protobuf/redis_request.proto @@ -244,6 +244,7 @@ enum RequestType { XAutoClaim = 203; Wait = 208; XClaim = 209; + Scan = 210; } message Command { @@ -268,6 +269,13 @@ message Transaction { repeated Command commands = 1; } +message ClusterScan { + string cursor = 1; + optional string match_pattern = 2; + optional int64 count = 3; + optional string object_type = 4; +} + message RedisRequest { uint32 callback_idx = 1; @@ -275,6 +283,7 @@ message RedisRequest { Command single_command = 2; Transaction transaction = 3; ScriptInvocation script_invocation = 4; + ClusterScan cluster_scan = 5; } - Routes route = 5; + Routes route = 6; } diff --git a/glide-core/src/request_type.rs b/glide-core/src/request_type.rs index 7064e9a0ce..3602f9bc68 100644 --- a/glide-core/src/request_type.rs +++ b/glide-core/src/request_type.rs @@ -214,6 +214,7 @@ pub enum RequestType { XAutoClaim = 203, Wait = 208, XClaim = 209, + Scan = 210, } fn get_two_word_command(first: &str, second: &str) -> Cmd { @@ -431,6 +432,7 @@ impl From<::protobuf::EnumOrUnknown> for RequestType { ProtobufRequestType::XAutoClaim => RequestType::XAutoClaim, ProtobufRequestType::Wait => RequestType::Wait, ProtobufRequestType::XClaim => RequestType::XClaim, + ProtobufRequestType::Scan => RequestType::Scan, } } } @@ -646,6 +648,7 @@ impl RequestType { RequestType::XAutoClaim => Some(cmd("XAUTOCLAIM")), RequestType::Wait => Some(cmd("WAIT")), RequestType::XClaim => Some(cmd("XCLAIM")), + RequestType::Scan => Some(cmd("SCAN")), } } } diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index a34cafce73..24f5c425e9 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -3,10 +3,12 @@ */ use super::rotating_buffer::RotatingBuffer; use crate::client::Client; +use crate::cluster_scan_container::get_cluster_scan_cursor; use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::redis_request::{ - command, redis_request, Command, RedisRequest, Routes, ScriptInvocation, SlotTypes, Transaction, + command, redis_request, ClusterScan, Command, RedisRequest, Routes, ScriptInvocation, + SlotTypes, Transaction, }; use crate::response; use crate::response::Response; @@ -20,8 +22,7 @@ use redis::cluster_routing::{ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, }; use redis::cluster_routing::{ResponsePolicy, Routable}; -use redis::RedisError; -use redis::{Cmd, PushInfo, Value}; +use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value}; use std::cell::Cell; use std::rc::Rc; use std::{env, str}; @@ -46,6 +47,13 @@ const SOCKET_FILE_NAME: &str = "glide-socket"; /// strings instead of a pointer pub const MAX_REQUEST_ARGS_LENGTH: usize = 2_i32.pow(12) as usize; // TODO: find the right number +pub const STRING: &str = "string"; +pub const LIST: &str = "list"; +pub const SET: &str = "set"; +pub const ZSET: &str = "zset"; +pub const HASH: &str = "hash"; +pub const STREAM: &str = "stream"; + /// struct containing all objects needed to bind to a socket and clean it. struct SocketListener { socket_path: String, @@ -201,13 +209,13 @@ async fn write_result( None } } - Err(ClienUsageError::Internal(error_message)) => { + Err(ClientUsageError::Internal(error_message)) => { log_error("internal error", &error_message); Some(response::response::Value::ClosingError( error_message.into(), )) } - Err(ClienUsageError::User(error_message)) => { + Err(ClientUsageError::User(error_message)) => { log_error("user error", &error_message); let request_error = response::RequestError { type_: response::RequestErrorType::Unspecified.into(), @@ -216,7 +224,7 @@ async fn write_result( }; Some(response::response::Value::RequestError(request_error)) } - Err(ClienUsageError::Redis(err)) => { + Err(ClientUsageError::Redis(err)) => { let error_message = error_message(&err); log_warn("received error", error_message.as_str()); log_debug("received error", format!("for callback {}", callback_index)); @@ -264,9 +272,9 @@ fn get_command(request: &Command) -> Option { request_type.get_command() } -fn get_redis_command(command: &Command) -> Result { +fn get_redis_command(command: &Command) -> Result { let Some(mut cmd) = get_command(command) else { - return Err(ClienUsageError::Internal(format!( + return Err(ClientUsageError::Internal(format!( "Received invalid request type: {:?}", command.request_type ))); @@ -285,14 +293,14 @@ fn get_redis_command(command: &Command) -> Result { } } None => { - return Err(ClienUsageError::Internal( + return Err(ClientUsageError::Internal( "Failed to get request arguments, no arguments are set".to_string(), )); } }; if cmd.args_iter().next().is_none() { - return Err(ClienUsageError::User( + return Err(ClientUsageError::User( "Received command without a command name or arguments".into(), )); } @@ -311,6 +319,50 @@ async fn send_command( .map_err(|err| err.into()) } +// Parse the cluster scan command parameters from protobuf and send the command to redis-rs. +async fn cluster_scan(cluster_scan: ClusterScan, mut client: Client) -> ClientUsageResult { + // Since we don't send the cluster scan as a usual command, but through a special function in redis-rs library, + // we need to handle the command separately. + // Specifically, we need to handle the cursor, which is not the cursor returned from the server, + // but the ID of the ScanStateRC, stored in the cluster scan container. + // We need to get the ref from the table or create a new one if the cursor is empty. + let cursor: String = cluster_scan.cursor.into(); + let cluster_scan_cursor = if cursor.is_empty() { + ScanStateRC::new() + } else { + get_cluster_scan_cursor(cursor)? + }; + + let match_pattern_string = cluster_scan + .match_pattern + .map(|pattern| pattern.to_string()); + let match_pattern = match_pattern_string.as_deref(); + let count = cluster_scan.count.map(|count| count as usize); + + let object_type = match cluster_scan.object_type { + Some(char_object_type) => match char_object_type.to_string().to_lowercase().as_str() { + STRING => Some(redis::ObjectType::String), + LIST => Some(redis::ObjectType::List), + SET => Some(redis::ObjectType::Set), + ZSET => Some(redis::ObjectType::ZSet), + HASH => Some(redis::ObjectType::Hash), + STREAM => Some(redis::ObjectType::Stream), + _ => { + return Err(ClientUsageError::Internal(format!( + "Received invalid object type: {:?}", + char_object_type + ))) + } + }, + None => None, + }; + + client + .cluster_scan(&cluster_scan_cursor, &match_pattern, count, object_type) + .await + .map_err(|err| err.into()) +} + async fn invoke_script( script: ScriptInvocation, mut client: Client, @@ -349,7 +401,7 @@ fn get_slot_addr(slot_type: &protobuf::EnumOrUnknown) -> ClientUsageR SlotTypes::Primary => SlotAddr::Master, SlotTypes::Replica => SlotAddr::ReplicaRequired, }) - .map_err(|id| ClienUsageError::Internal(format!("Received unexpected slot id type {id}"))) + .map_err(|id| ClientUsageError::Internal(format!("Received unexpected slot id type {id}"))) } fn get_route( @@ -369,7 +421,7 @@ fn get_route( match route { Value::SimpleRoutes(simple_route) => { let simple_route = simple_route.enum_value().map_err(|id| { - ClienUsageError::Internal(format!("Received unexpected simple route type {id}")) + ClientUsageError::Internal(format!("Received unexpected simple route type {id}")) })?; match simple_route { crate::redis_request::SimpleRoutes::AllNodes => Ok(Some(RoutingInfo::MultiNode(( @@ -418,6 +470,9 @@ fn handle_request(request: RedisRequest, client: Client, writer: Rc) { task::spawn_local(async move { let result = match request.command { Some(action) => match action { + redis_request::Command::ClusterScan(cluster_scan_command) => { + cluster_scan(cluster_scan_command, client).await + } redis_request::Command::SingleCommand(command) => { match get_redis_command(&command) { Ok(cmd) => match get_route(request.route.0, Some(&cmd)) { @@ -448,7 +503,7 @@ fn handle_request(request: RedisRequest, client: Client, writer: Rc) { request.callback_idx ), ); - Err(ClienUsageError::Internal( + Err(ClientUsageError::Internal( "Received empty request".to_string(), )) } @@ -766,7 +821,7 @@ enum ClientCreationError { /// Enum describing errors received during client usage. #[derive(Debug, Error)] -enum ClienUsageError { +enum ClientUsageError { #[error("Redis error: {0}")] Redis(#[from] RedisError), /// An error that stems from wrong behavior of the client. @@ -777,7 +832,7 @@ enum ClienUsageError { User(String), } -type ClientUsageResult = Result; +type ClientUsageResult = Result; /// Defines errors caused the connection to close. #[derive(Debug, Clone)] diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index 58918ce352..6983d543fb 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -33,6 +33,8 @@ source "$HOME/.cargo/env" rustc --version # Install protobuf compiler PB_REL="https://github.com/protocolbuffers/protobuf/releases" +# For other arch type from x86 example below, the signature of the curl url should be protoc---.zip, +# e.g. protoc-3.20.3-linux-aarch_64.zip for ARM64. curl -LO $PB_REL/download/v3.20.3/protoc-3.20.3-linux-x86_64.zip unzip protoc-3.20.3-linux-x86_64.zip -d $HOME/.local export PATH="$PATH:$HOME/.local/bin" diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 2e89d3cc1a..79755af209 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -97,7 +97,7 @@ SlotType, ) -from .glide import Script +from .glide import ClusterScanCursor, Script __all__ = [ # Client @@ -175,6 +175,7 @@ "TrimByMaxLen", "TrimByMinId", "UpdateOptions", + "ClusterScanCursor" # Logger "Logger", "LogLevel", diff --git a/python/python/glide/async_commands/cluster_commands.py b/python/python/glide/async_commands/cluster_commands.py index 867e4910ce..e584ec155a 100644 --- a/python/python/glide/async_commands/cluster_commands.py +++ b/python/python/glide/async_commands/cluster_commands.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set, Union, cast -from glide.async_commands.command_args import Limit, OrderBy +from glide.async_commands.command_args import Limit, ObjectType, OrderBy from glide.async_commands.core import ( CoreCommands, FlushMode, @@ -23,6 +23,8 @@ from glide.protobuf.redis_request_pb2 import RequestType from glide.routes import Route +from ..glide import ClusterScanCursor + class ClusterCommands(CoreCommands): async def custom_command( @@ -945,3 +947,72 @@ async def unwatch(self, route: Optional[Route] = None) -> TOK: TOK, await self._execute_command(RequestType.UnWatch, [], route), ) + + async def scan( + self, + cursor: ClusterScanCursor, + match: Optional[TEncodable] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[ClusterScanCursor, List[bytes]]]: + """ + Incrementally iterates over the keys in the Cluster. + The method returns a list containing the next cursor and a list of keys. + + This command is similar to the SCAN command, but it is designed to work in a Cluster environment. + For each iteration the new cursor object should be used to continue the scan. + Using the same cursor object for multiple iterations will result in the same keys or unexpected behavior. + For more information about the Cluster Scan implementation, + see [Cluster Scan](https://github.com/aws/glide-for-redis/wiki/General-Concepts#cluster-scan). + + As the SCAN command, the method can be used to iterate over the keys in the database, + to return all keys the database have from the time the scan started till the scan ends. + The same key can be returned in multiple scans iteration. + + See https://valkey.io/commands/scan/ for more details. + + Args: + cursor (ClusterScanCursor): The cursor object that wraps the scan state. + To start a new scan, create a new empty ClusterScanCursor using ClusterScanCursor(). + match (Optional[TEncodable]): A pattern to match keys against. + count (Optional[int]): The number of keys to return in a single iteration. + The actual number returned can vary and is not guaranteed to match this count exactly. + This parameter serves as a hint to the server on the number of steps to perform in each iteration. + The default value is 10. + type (Optional[ObjectType]): The type of object to scan for. + + Returns: + List[Union[ClusterScanCursor, List[TEncodable]]]: A list containing the next cursor and a list of keys, + formatted as [ClusterScanCursor, [key1, key2, ...]]. + + Examples: + >>> # In the following example, we will iterate over the keys in the cluster. + await redis_client.mset({b'key1': b'value1', b'key2': b'value2', b'key3': b'value3'}) + cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await redis_client.scan(cursor, count=10) + all_keys.extend(keys) + print(all_keys) # [b'key1', b'key2', b'key3'] + >>> # In the following example, we will iterate over the keys in the cluster that match the pattern "*key*". + await redis_client.mset({b"key1": b"value1", b"key2": b"value2", b"not_my_key": b"value3", b"something_else": b"value4"}) + cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await redis_client.scan(cursor, match=b"*key*", count=10) + all_keys.extend(keys) + print(all_keys) # [b'my_key1', b'my_key2', b'not_my_key'] + >>> # In the following example, we will iterate over the keys in the cluster that are of type STRING. + await redis_client.mset({b'key1': b'value1', b'key2': b'value2', b'key3': b'value3'}) + await redis_client.sadd(b"this_is_a_set", [b"value4"]) + cursor = ClusterScanCursor() + all_keys = [] + while not cursor.is_finished(): + cursor, keys = await redis_client.scan(cursor, type=ObjectType.STRING) + all_keys.extend(keys) + print(all_keys) # [b'key1', b'key2', b'key3'] + """ + return cast( + List[Union[ClusterScanCursor, List[bytes]]], + await self._cluster_scan(cursor, match, count, type), + ) diff --git a/python/python/glide/async_commands/command_args.py b/python/python/glide/async_commands/command_args.py index ce76fd2d55..2b00665a7d 100644 --- a/python/python/glide/async_commands/command_args.py +++ b/python/python/glide/async_commands/command_args.py @@ -63,3 +63,39 @@ class ListDirection(Enum): """ RIGHT: Represents the option that elements should be popped from or added to the right side of a list. """ + + +class ObjectType(Enum): + """ + Enumeration representing the data types supported by the database. + """ + + STRING = "String" + """ + Represents a string data type. + """ + + LIST = "List" + """ + Represents a list data type. + """ + + SET = "Set" + """ + Represents a set data type. + """ + + ZSET = "ZSet" + """ + Represents a sorted set data type. + """ + + HASH = "Hash" + """ + Represents a hash data type. + """ + + STREAM = "Stream" + """ + Represents a stream data type. + """ diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index 431d2dc9ed..a42caea2d6 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -25,7 +25,7 @@ _create_bitfield_args, _create_bitfield_read_only_args, ) -from glide.async_commands.command_args import Limit, ListDirection, OrderBy +from glide.async_commands.command_args import Limit, ListDirection, ObjectType, OrderBy from glide.async_commands.sorted_set import ( AggregationType, GeoSearchByBox, @@ -58,7 +58,7 @@ from glide.protobuf.redis_request_pb2 import RequestType from glide.routes import Route -from ..glide import Script +from ..glide import ClusterScanCursor, Script class ConditionalChange(Enum): @@ -361,6 +361,14 @@ async def _execute_script( route: Optional[Route] = None, ) -> TResult: ... + async def _cluster_scan( + self, + cursor: ClusterScanCursor, + match: Optional[TEncodable] = ..., + count: Optional[int] = ..., + type: Optional[ObjectType] = ..., + ) -> TResult: ... + async def set( self, key: TEncodable, diff --git a/python/python/glide/async_commands/standalone_commands.py b/python/python/glide/async_commands/standalone_commands.py index f7cb65088a..2ca870000e 100644 --- a/python/python/glide/async_commands/standalone_commands.py +++ b/python/python/glide/async_commands/standalone_commands.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set, Union, cast -from glide.async_commands.command_args import Limit, OrderBy +from glide.async_commands.command_args import Limit, ObjectType, OrderBy from glide.async_commands.core import ( CoreCommands, FlushMode, @@ -765,3 +765,68 @@ async def unwatch(self) -> TOK: TOK, await self._execute_command(RequestType.UnWatch, []), ) + + async def scan( + self, + cursor: TEncodable, + match: Optional[TEncodable] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[bytes, List[bytes]]]: + """ + Incrementally iterate over a collection of keys. + SCAN is a cursor based iterator. This means that at every call of the command, + the server returns an updated cursor that the user needs to use as the cursor argument in the next call. + An iteration starts when the cursor is set to "0", and terminates when the cursor returned by the server is "0". + + A full iteration always retrieves all the elements that were present + in the collection from the start to the end of a full iteration. + Elements that were not constantly present in the collection during a full iteration, may be returned or not. + + See https://valkey.io/commands/scan for more details. + + Args: + cursor (TResult): The cursor used for iteration. For the first iteration, the cursor should be set to "0". + Using a non-zero cursor in the first iteration, + or an invalid cursor at any iteration, will lead to undefined results. + Using the same cursor in multiple iterations will, in case nothing changed between the iterations, + return the same elements multiple times. + If the the db has changed, it may result an undefined behavior. + match (Optional[TResult]): A pattern to match keys against. + count (Optional[int]): The number of keys to return per iteration. + The number of keys returned per iteration is not guaranteed to be the same as the count argument. + the argument is used as a hint for the server to know how many "steps" it can use to retrieve the keys. + The default value is 10. + type (ObjectType): The type of object to scan for. + + Returns: + List[Union[bytes, List[bytes]]]: A List containing the next cursor value and a list of keys, + formatted as [cursor, [key1, key2, ...]] + + Examples: + >>> result = await client.scan(b'0') + print(result) #[b'17', [b'key1', b'key2', b'key3', b'key4', b'key5', b'set1', b'set2', b'set3']] + first_cursor_result = result[0] + result = await client.scan(first_cursor_result) + print(result) #[b'349', [b'key4', b'key5', b'set1', b'hash1', b'zset1', b'list1', b'list2', + b'list3', b'zset2', b'zset3', b'zset4', b'zset5', b'zset6']] + result = await client.scan(result[0]) + print(result) #[b'0', [b'key6', b'key7']] + + >>> result = await client.scan(first_cursor_result, match=b'key*', count=2) + print(result) #[b'6', [b'key4', b'key5']] + + >>> result = await client.scan("0", type=ObjectType.Set) + print(result) #[b'362', [b'set1', b'set2', b'set3']] + """ + args = [cursor] + if match: + args.extend(["MATCH", match]) + if count: + args.extend(["COUNT", str(count)]) + if type: + args.extend(["TYPE", type.value]) + return cast( + List[Union[bytes, List[bytes]]], + await self._execute_command(RequestType.Scan, args), + ) diff --git a/python/python/glide/glide.pyi b/python/python/glide/glide.pyi index 8964329a2b..ee053b1e44 100644 --- a/python/python/glide/glide.pyi +++ b/python/python/glide/glide.pyi @@ -21,6 +21,11 @@ class Script: def get_hash(self) -> str: ... def __del__(self) -> None: ... +class ClusterScanCursor: + def __init__(self, cursor: Optional[str] = None) -> None: ... + def get_cursor(self) -> str: ... + def is_finished(self) -> bool: ... + def start_socket_listener_external(init_callback: Callable) -> None: ... def value_from_pointer(pointer: int) -> TResult: ... def create_leaked_value(message: str) -> int: ... diff --git a/python/python/glide/glide_client.py b/python/python/glide/glide_client.py index 57b36f4455..06740fd151 100644 --- a/python/python/glide/glide_client.py +++ b/python/python/glide/glide_client.py @@ -7,6 +7,7 @@ import async_timeout from glide.async_commands.cluster_commands import ClusterCommands +from glide.async_commands.command_args import ObjectType from glide.async_commands.core import CoreCommands from glide.async_commands.standalone_commands import StandaloneCommands from glide.config import BaseClientConfiguration @@ -31,6 +32,7 @@ from .glide import ( DEFAULT_TIMEOUT_IN_MILLISECONDS, MAX_REQUEST_ARGS_LEN, + ClusterScanCursor, create_leaked_bytes_vec, start_socket_listener_external, value_from_pointer, @@ -525,6 +527,35 @@ class GlideClusterClient(BaseClient, ClusterCommands): https://github.com/aws/babushka/wiki/Python-wrapper#redis-cluster """ + async def _cluster_scan( + self, + cursor: ClusterScanCursor, + match: Optional[TEncodable] = None, + count: Optional[int] = None, + type: Optional[ObjectType] = None, + ) -> List[Union[ClusterScanCursor, List[bytes]]]: + if self._is_closed: + raise ClosingError( + "Unable to execute requests; the client is closed. Please create a new client." + ) + request = RedisRequest() + request.callback_idx = self._get_callback_index() + # Take out the id string from the wrapping object + cursor_string = cursor.get_cursor() + request.cluster_scan.cursor = cursor_string + if match is not None: + request.cluster_scan.match_pattern = ( + match + if isinstance(match, str) + else match.decode() if isinstance(match, bytes) else match + ) + if count is not None: + request.cluster_scan.count = count + if type is not None: + request.cluster_scan.object_type = type.value + response = await self._write_request_await_response(request) + return [ClusterScanCursor(bytes(response[0]).decode()), response[1]] + def _get_protobuf_conn_request(self) -> ConnectionRequest: return self.config._create_a_protobuf_conn_request(cluster_mode=True) diff --git a/python/python/tests/test_scan.py b/python/python/tests/test_scan.py new file mode 100644 index 0000000000..7d79eb36a1 --- /dev/null +++ b/python/python/tests/test_scan.py @@ -0,0 +1,485 @@ +from __future__ import annotations + +from typing import cast + +import pytest +from glide import ClusterScanCursor +from glide.async_commands.command_args import ObjectType +from glide.config import ProtocolVersion +from glide.exceptions import RequestError +from glide.glide_client import GlideClient, GlideClusterClient +from tests.utils.utils import get_random_string + + +@pytest.mark.asyncio +class TestScan: + # Cluster scan tests + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_simple(self, redis_client: GlideClusterClient): + key = get_random_string(10) + expected_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + expected_keys_encoded = map(lambda k: k.encode(), expected_keys) + cursor = ClusterScanCursor() + keys: list[str] = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = cast(list[str], result[1]) + keys.extend(result_keys) + + assert set(expected_keys_encoded) == set(keys) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_object_type_and_pattern( + self, redis_client: GlideClusterClient + ): + key = get_random_string(10) + expected_keys = [f"key:{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + unexpected_type_keys = [f"{key}:{i}" for i in range(100, 200)] + for key in unexpected_type_keys: + await redis_client.sadd(key, ["value"]) + encoded_unexpected_type_keys = map(lambda k: k.encode(), unexpected_type_keys) + unexpected_pattern_keys = [f"{i}" for i in range(200, 300)] + await redis_client.mset({k: "value" for k in unexpected_pattern_keys}) + encoded_unexpected_pattern_keys = map( + lambda k: k.encode(), unexpected_pattern_keys + ) + keys: list[str] = [] + cursor = ClusterScanCursor() + while not cursor.is_finished(): + result = await redis_client.scan( + cursor, match=b"key:*", type=ObjectType.STRING + ) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = cast(list[str], result[1]) + keys.extend(result_keys) + + assert set(encoded_expected_keys) == set(keys) + assert not set(encoded_unexpected_type_keys).intersection(set(keys)) + assert not set(encoded_unexpected_pattern_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_count(self, redis_client: GlideClusterClient): + key = get_random_string(10) + expected_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + cursor = ClusterScanCursor() + keys: list[str] = [] + successful_compared_scans = 0 + while not cursor.is_finished(): + result_of_1 = await redis_client.scan(cursor, count=1) + cursor = cast(ClusterScanCursor, result_of_1[0]) + result_keys_of_1 = cast(list[str], result_of_1[1]) + keys.extend(result_keys_of_1) + if cursor.is_finished(): + break + result_of_100 = await redis_client.scan(cursor, count=100) + cursor = cast(ClusterScanCursor, result_of_100[0]) + result_keys_of_100 = cast(list[str], result_of_100[1]) + keys.extend(result_keys_of_100) + if len(result_keys_of_100) > len(result_keys_of_1): + successful_compared_scans += 1 + + assert set(encoded_expected_keys) == set(keys) + assert successful_compared_scans > 0 + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_with_match(self, redis_client: GlideClusterClient): + unexpected_keys = [f"{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in unexpected_keys}) + encoded_unexpected_keys = map(lambda k: k.encode(), unexpected_keys) + key = get_random_string(10) + expected_keys = [f"key:{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + cursor = ClusterScanCursor() + keys: list[str] = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, match="key:*") + cursor = cast(ClusterScanCursor, result[0]) + result_keys = cast(list[str], result[1]) + keys.extend(result_keys) + assert set(encoded_expected_keys) == set(keys) + assert not set(encoded_unexpected_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + # We test whether the cursor is cleaned up after it is deleted, which we expect to happen when th GC is called + async def test_cluster_scan_cleaning_cursor(self, redis_client: GlideClusterClient): + key = get_random_string(10) + await redis_client.mset( + {k: "value" for k in [f"{key}:{i}" for i in range(100)]} + ) + cursor = cast( + ClusterScanCursor, (await redis_client.scan(ClusterScanCursor()))[0] + ) + cursor_string = cursor.get_cursor() + print(cursor_string) + del cursor + new_cursor_with_same_id = ClusterScanCursor(cursor_string) + with pytest.raises(RequestError) as e_info: + await redis_client.scan(new_cursor_with_same_id) + print(new_cursor_with_same_id) + print(new_cursor_with_same_id.get_cursor()) + assert "Invalid scan_state_cursor id" in str(e_info.value) + + @pytest.mark.parametrize("cluster_mode", [True]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_cluster_scan_all_types(self, redis_client: GlideClusterClient): + # We test that the scan command work for all types of keys + key = get_random_string(10) + string_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in string_keys}) + encoded_string_keys = list(map(lambda k: k.encode(), string_keys)) + + set_key = get_random_string(10) + set_keys = [f"{set_key}:{i}" for i in range(100, 200)] + for key in set_keys: + await redis_client.sadd(key, ["value"]) + encoded_set_keys = list(map(lambda k: k.encode(), set_keys)) + + hash_key = get_random_string(10) + hash_keys = [f"{hash_key}:{i}" for i in range(200, 300)] + for key in hash_keys: + await redis_client.hset(key, {"field": "value"}) + encoded_hash_keys = list(map(lambda k: k.encode(), hash_keys)) + + list_key = get_random_string(10) + list_keys = [f"{list_key}:{i}" for i in range(300, 400)] + for key in list_keys: + await redis_client.lpush(key, ["value"]) + encoded_list_keys = list(map(lambda k: k.encode(), list_keys)) + + zset_key = get_random_string(10) + zset_keys = [f"{zset_key}:{i}" for i in range(400, 500)] + for key in zset_keys: + await redis_client.zadd(key, {"value": 1}) + encoded_zset_keys = list(map(lambda k: k.encode(), zset_keys)) + + stream_key = get_random_string(10) + stream_keys = [f"{stream_key}:{i}" for i in range(500, 600)] + for key in stream_keys: + await redis_client.xadd(key, [("field", "value")]) + encoded_stream_keys = list(map(lambda k: k.encode(), stream_keys)) + + cursor = ClusterScanCursor() + keys: list[bytes] = [] + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.STRING) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_string_keys) == set(keys) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + cursor = ClusterScanCursor() + keys.clear() + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.SET) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_set_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + cursor = ClusterScanCursor() + keys.clear() + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.HASH) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_hash_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + cursor = ClusterScanCursor() + keys.clear() + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.LIST) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_list_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + cursor = ClusterScanCursor() + keys.clear() + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.ZSET) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_zset_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + cursor = ClusterScanCursor() + keys.clear() + while not cursor.is_finished(): + result = await redis_client.scan(cursor, type=ObjectType.STREAM) + cursor = cast(ClusterScanCursor, result[0]) + result_keys = result[1] + keys.extend(cast(list[bytes], result_keys)) + assert set(encoded_stream_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + + # Standalone scan tests + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_standalone_scan_simple(self, redis_client: GlideClient): + key = get_random_string(10) + expected_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + keys: list[str] = [] + cursor = b"0" + while True: + result = await redis_client.scan(cursor) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes + new_keys = cast(list[str], result[1]) + keys.extend(new_keys) + if cursor == b"0": + break + assert set(encoded_expected_keys) == set(keys) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_standalone_scan_with_object_type_and_pattern( + self, redis_client: GlideClient + ): + key = get_random_string(10) + expected_keys = [f"key:{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + unexpected_type_keys = [f"key:{i}" for i in range(100, 200)] + for key in unexpected_type_keys: + await redis_client.sadd(key, ["value"]) + unexpected_pattern_keys = [f"{i}" for i in range(200, 300)] + for key in unexpected_pattern_keys: + await redis_client.set(key, "value") + keys: list[str] = [] + cursor = b"0" + while True: + result = await redis_client.scan( + cursor, match=b"key:*", type=ObjectType.STRING + ) + cursor = cast(bytes, result[0]) + keys.extend(list(map(lambda k: k.decode(), cast(list[bytes], result[1])))) + if cursor == b"0": + break + assert set(expected_keys) == set(keys) + assert not set(unexpected_type_keys).intersection(set(keys)) + assert not set(unexpected_pattern_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_standalone_scan_with_count(self, redis_client: GlideClient): + key = get_random_string(10) + expected_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + cursor = "0" + keys: list[str] = [] + successful_compared_scans = 0 + while True: + result_of_1 = await redis_client.scan(cursor, count=1) + cursor_bytes = cast(bytes, result_of_1[0]) + cursor = cursor_bytes.decode() + keys_of_1 = cast(list[str], result_of_1[1]) + keys.extend(keys_of_1) + result_of_100 = await redis_client.scan(cursor, count=100) + cursor_bytes = cast(bytes, result_of_100[0]) + cursor = cursor_bytes.decode() + keys_of_100 = cast(list[str], result_of_100[1]) + keys.extend(keys_of_100) + if len(keys_of_100) > len(keys_of_1): + successful_compared_scans += 1 + if cursor == "0": + break + assert set(encoded_expected_keys) == set(keys) + assert successful_compared_scans > 0 + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_standalone_scan_with_match(self, redis_client: GlideClient): + key = get_random_string(10) + expected_keys = [f"key:{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in expected_keys}) + encoded_expected_keys = map(lambda k: k.encode(), expected_keys) + unexpected_keys = [f"{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in [f"{i}" for i in range(100)]}) + encoded_unexpected_keys = map(lambda k: k.encode(), unexpected_keys) + cursor = "0" + keys: list[str] = [] + while True: + result = await redis_client.scan(cursor, match="key:*") + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = cast(list[str], result[1]) + keys.extend(new_keys) + if cursor == "0": + break + assert set(encoded_expected_keys) == set(keys) + assert not set(encoded_unexpected_keys).intersection(set(keys)) + + @pytest.mark.parametrize("cluster_mode", [False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_standalone_scan_all_types(self, redis_client: GlideClient): + # We test that the scan command work for all types of keys + key = get_random_string(10) + string_keys = [f"{key}:{i}" for i in range(100)] + await redis_client.mset({k: "value" for k in string_keys}) + encoded_string_keys = list(map(lambda k: k.encode(), string_keys)) + + set_keys = [f"{key}:{i}" for i in range(100, 200)] + for key in set_keys: + await redis_client.sadd(key, ["value"]) + encoded_set_keys = list(map(lambda k: k.encode(), set_keys)) + + hash_keys = [f"{key}:{i}" for i in range(200, 300)] + for key in hash_keys: + await redis_client.hset(key, {"field": "value"}) + encoded_hash_keys = list(map(lambda k: k.encode(), hash_keys)) + + list_keys = [f"{key}:{i}" for i in range(300, 400)] + for key in list_keys: + await redis_client.lpush(key, ["value"]) + encoded_list_keys = list(map(lambda k: k.encode(), list_keys)) + + zset_keys = [f"{key}:{i}" for i in range(400, 500)] + for key in zset_keys: + await redis_client.zadd(key, {"value": 1}) + encoded_zset_keys = list(map(lambda k: k.encode(), zset_keys)) + + stream_keys = [f"{key}:{i}" for i in range(500, 600)] + for key in stream_keys: + await redis_client.xadd(key, [("field", "value")]) + encoded_stream_keys = list(map(lambda k: k.encode(), stream_keys)) + + cursor = "0" + keys: list[bytes] = [] + while True: + result = await redis_client.scan(cursor, type=ObjectType.STRING) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_string_keys) == set(keys) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + keys.clear() + while True: + result = await redis_client.scan(cursor, type=ObjectType.SET) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_set_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + keys.clear() + while True: + result = await redis_client.scan(cursor, type=ObjectType.HASH) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_hash_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + keys.clear() + while True: + result = await redis_client.scan(cursor, type=ObjectType.LIST) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_list_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + keys.clear() + while True: + result = await redis_client.scan(cursor, type=ObjectType.ZSET) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_zset_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_stream_keys).intersection(set(keys)) + + keys.clear() + while True: + result = await redis_client.scan(cursor, type=ObjectType.STREAM) + cursor_bytes = cast(bytes, result[0]) + cursor = cursor_bytes.decode() + new_keys = result[1] + keys.extend(cast(list[bytes], new_keys)) + if cursor == "0": + break + assert set(encoded_stream_keys) == set(keys) + assert not set(encoded_string_keys).intersection(set(keys)) + assert not set(encoded_set_keys).intersection(set(keys)) + assert not set(encoded_hash_keys).intersection(set(keys)) + assert not set(encoded_list_keys).intersection(set(keys)) + assert not set(encoded_zset_keys).intersection(set(keys)) diff --git a/python/src/lib.rs b/python/src/lib.rs index cdcc651b2d..b37ff1def8 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use glide_core::client::FINISHED_SCAN_CURSOR; /** * Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */ @@ -8,7 +9,6 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::{PyAny, PyBool, PyBytes, PyDict, PyFloat, PyList, PySet}; use pyo3::Python; - use redis::Value; pub const DEFAULT_TIMEOUT_IN_MILLISECONDS: u32 = @@ -33,6 +33,42 @@ impl Level { } } +/// This struct is used to keep track of the cursor of a cluster scan. +/// We want to avoid passing the cursor between layers of the application, +/// So we keep the state in the container and only pass the id of the cursor. +/// The cursor is stored in the container and can be retrieved using the id. +/// The cursor is removed from the container when the object is deleted (dropped). +#[pyclass] +#[derive(Default)] +pub struct ClusterScanCursor { + cursor: String, +} + +#[pymethods] +impl ClusterScanCursor { + #[new] + fn new(new_cursor: Option) -> Self { + match new_cursor { + Some(cursor) => ClusterScanCursor { cursor }, + None => ClusterScanCursor::default(), + } + } + + fn get_cursor(&self) -> String { + self.cursor.clone() + } + + fn is_finished(&self) -> bool { + self.cursor == *FINISHED_SCAN_CURSOR.to_string() + } +} + +impl Drop for ClusterScanCursor { + fn drop(&mut self) { + glide_core::cluster_scan_container::remove_scan_state_cursor(self.cursor.clone()); + } +} + #[pyclass] pub struct Script { hash: String, @@ -69,6 +105,7 @@ impl Script { fn glide(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::