From 18e6a9b2695d44d56f681dbccce756fa8c8c36e9 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Tue, 18 Jun 2024 22:13:07 +0300 Subject: [PATCH] Pubsub implementation in glide-core and with Python wrapper. Works both in standalone and cluster modes. Pubsub configuration is provided in client creation params. --- benchmarks/rust/src/main.rs | 2 +- csharp/lib/src/lib.rs | 2 +- glide-core/Cargo.toml | 4 +- glide-core/benches/connections_benchmark.rs | 4 +- glide-core/benches/memory_benchmark.rs | 2 +- glide-core/src/client/mod.rs | 29 +++- .../src/client/reconnecting_connection.rs | 22 ++- glide-core/src/client/standalone_client.rs | 21 ++- glide-core/src/client/types.rs | 37 +++++ .../src/protobuf/connection_request.proto | 17 ++ glide-core/src/protobuf/redis_request.proto | 2 + glide-core/src/protobuf/response.proto | 1 + glide-core/src/request_type.rs | 6 + glide-core/src/socket_listener.rs | 50 +++++- glide-core/tests/test_client.rs | 1 + glide-core/tests/test_standalone_client.rs | 6 +- glide-core/tests/utilities/cluster.rs | 2 +- glide-core/tests/utilities/mod.rs | 18 +- go/src/lib.rs | 2 +- node/rust-client/src/lib.rs | 3 +- .../glide/async_commands/cluster_commands.py | 26 +++ .../async_commands/standalone_commands.py | 19 +++ python/python/glide/config.py | 56 ++++++- python/python/glide/redis_client.py | 157 +++++++++++++++--- python/python/tests/conftest.py | 8 + python/python/tests/test_async_client.py | 87 +++++++++- python/src/lib.rs | 8 +- submodules/redis-rs | 2 +- utils/cluster_manager.py | 2 +- 29 files changed, 525 insertions(+), 71 deletions(-) diff --git a/benchmarks/rust/src/main.rs b/benchmarks/rust/src/main.rs index 8503375195..c5098e13d1 100644 --- a/benchmarks/rust/src/main.rs +++ b/benchmarks/rust/src/main.rs @@ -236,7 +236,7 @@ async fn get_connection(args: &Args) -> Client { ..Default::default() }; - glide_core::client::Client::new(connection_request) + glide_core::client::Client::new(connection_request, None) .await .unwrap() } diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index fce015a376..32ca91f2a2 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -59,7 +59,7 @@ fn create_client_internal( .thread_name("GLIDE for Redis C# thread") .build()?; let _runtime_handle = runtime.enter(); - let client = runtime.block_on(GlideClient::new(request)).unwrap(); // TODO - handle errors. + let client = runtime.block_on(GlideClient::new(request, None)).unwrap(); // TODO - handle errors. Ok(Client { client, success_callback, diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index dc623c6714..1d934fabc3 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -20,7 +20,7 @@ tokio-retry = "0.3.0" protobuf = { version= "3", features = ["bytes", "with-bytes"], optional = true } integer-encoding = { version = "4.0.0", optional = true } thiserror = "1" -rand = { version = "0.8.5", optional = true } +rand = { version = "0.8.5" } futures-intrusive = "0.5.0" directories = { version = "4.0", optional = true } once_cell = "1.18.0" @@ -28,7 +28,7 @@ arcstr = "1.1.5" sha1_smol = "1.0.0" [features] -socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util", "bytes", "rand"] +socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util", "bytes"] [dev-dependencies] rsevents = "0.3.1" diff --git a/glide-core/benches/connections_benchmark.rs b/glide-core/benches/connections_benchmark.rs index fc98933de8..5930fd4b44 100644 --- a/glide-core/benches/connections_benchmark.rs +++ b/glide-core/benches/connections_benchmark.rs @@ -83,7 +83,7 @@ fn get_connection_info(address: ConnectionAddr) -> redis::ConnectionInfo { fn multiplexer_benchmark(c: &mut Criterion, address: ConnectionAddr, group: &str) { benchmark(c, address, "multiplexer", group, |address, runtime| { let client = redis::Client::open(get_connection_info(address)).unwrap(); - runtime.block_on(async { client.get_multiplexed_tokio_connection().await.unwrap() }) + runtime.block_on(async { client.get_multiplexed_tokio_connection(None).await.unwrap() }) }); } @@ -120,7 +120,7 @@ fn cluster_connection_benchmark( builder = builder.read_from_replicas(); } let client = builder.build().unwrap(); - client.get_async_connection().await + client.get_async_connection(None).await }) .unwrap() }); diff --git a/glide-core/benches/memory_benchmark.rs b/glide-core/benches/memory_benchmark.rs index c6e307bae2..7f81e3fddb 100644 --- a/glide-core/benches/memory_benchmark.rs +++ b/glide-core/benches/memory_benchmark.rs @@ -26,7 +26,7 @@ where { let runtime = Builder::new_current_thread().enable_all().build().unwrap(); runtime.block_on(async { - let client = Client::new(create_connection_request().into()) + let client = Client::new(create_connection_request().into(), None) .await .unwrap(); f(client).await; diff --git a/glide-core/src/client/mod.rs b/glide-core/src/client/mod.rs index 3197c82a23..2abcd06aa4 100644 --- a/glide-core/src/client/mod.rs +++ b/glide-core/src/client/mod.rs @@ -9,7 +9,7 @@ use logger_core::log_info; use redis::aio::ConnectionLike; use redis::cluster_async::ClusterConnection; use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo}; -use redis::{Cmd, ErrorKind, Value}; +use redis::{Cmd, ErrorKind, PushInfo, Value}; use redis::{RedisError, RedisResult}; pub use standalone_client::StandaloneClient; use std::io; @@ -21,6 +21,7 @@ use self::value_conversion::{convert_to_expected_type, expected_type_for_cmd, ge mod reconnecting_connection; mod standalone_client; mod value_conversion; +use tokio::sync::mpsc; pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1); @@ -44,6 +45,7 @@ pub(super) fn get_redis_connection_info( let protocol = connection_request.protocol.unwrap_or_default(); let db = connection_request.database_id; let client_name = connection_request.client_name.clone(); + let pubsub_subscriptions = connection_request.pubsub_subscriptions.clone(); match &connection_request.authentication_info { Some(info) => redis::RedisConnectionInfo { db, @@ -51,11 +53,13 @@ pub(super) fn get_redis_connection_info( password: info.password.clone(), protocol, client_name, + pubsub_subscriptions, }, None => redis::RedisConnectionInfo { db, protocol, client_name, + pubsub_subscriptions, ..Default::default() }, } @@ -373,6 +377,7 @@ fn to_duration(time_in_millis: Option, default: Duration) -> Duration { async fn create_cluster_client( request: ConnectionRequest, + push_sender: Option>, ) -> RedisResult { // TODO - implement timeout for each connection attempt let tls_mode = request.tls_mode.unwrap_or_default(); @@ -410,8 +415,11 @@ async fn create_cluster_client( }; builder = builder.tls(tls); } + if let Some(pubsub_subscriptions) = redis_connection_info.pubsub_subscriptions { + builder = builder.pubsub_subscriptions(pubsub_subscriptions); + } let client = builder.build()?; - client.get_async_connection().await + client.get_async_connection(push_sender).await } #[derive(thiserror::Error)] @@ -520,13 +528,22 @@ fn sanitized_request_string(request: &ConnectionRequest) -> String { String::new() }; + let pubsub_subscriptions = request + .pubsub_subscriptions + .as_ref() + .map(|pubsub_subscriptions| format!("\nPubsub subscriptions: {pubsub_subscriptions:?}")) + .unwrap_or_default(); + format!( - "\nAddresses: {addresses}{tls_mode}{cluster_mode}{request_timeout}{rfr_strategy}{connection_retry_strategy}{database_id}{protocol}{client_name}{periodic_checks}", + "\nAddresses: {addresses}{tls_mode}{cluster_mode}{request_timeout}{rfr_strategy}{connection_retry_strategy}{database_id}{protocol}{client_name}{periodic_checks}{pubsub_subscriptions}", ) } impl Client { - pub async fn new(request: ConnectionRequest) -> Result { + pub async fn new( + request: ConnectionRequest, + push_sender: Option>, + ) -> Result { const DEFAULT_CLIENT_CREATION_TIMEOUT: Duration = Duration::from_secs(10); log_info( @@ -536,13 +553,13 @@ impl Client { let request_timeout = to_duration(request.request_timeout, DEFAULT_RESPONSE_TIMEOUT); tokio::time::timeout(DEFAULT_CLIENT_CREATION_TIMEOUT, async move { let internal_client = if request.cluster_mode_enabled { - let client = create_cluster_client(request) + let client = create_cluster_client(request, push_sender) .await .map_err(ConnectionError::Cluster)?; ClientWrapper::Cluster { client } } else { ClientWrapper::Standalone( - StandaloneClient::create_client(request) + StandaloneClient::create_client(request, push_sender) .await .map_err(ConnectionError::Standalone)?, ) diff --git a/glide-core/src/client/reconnecting_connection.rs b/glide-core/src/client/reconnecting_connection.rs index ac33f6c005..a592a357a6 100644 --- a/glide-core/src/client/reconnecting_connection.rs +++ b/glide-core/src/client/reconnecting_connection.rs @@ -6,12 +6,13 @@ use crate::retry_strategies::RetryStrategy; use futures_intrusive::sync::ManualResetEvent; use logger_core::{log_debug, log_trace, log_warn}; use redis::aio::MultiplexedConnection; -use redis::{RedisConnectionInfo, RedisError, RedisResult}; +use redis::{PushInfo, RedisConnectionInfo, RedisError, RedisResult}; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; +use tokio::sync::mpsc; use tokio::task; use tokio_retry::Retry; @@ -45,6 +46,7 @@ struct InnerReconnectingConnection { #[derive(Clone)] pub(super) struct ReconnectingConnection { inner: Arc, + push_sender: Option>, } impl fmt::Debug for ReconnectingConnection { @@ -53,10 +55,13 @@ impl fmt::Debug for ReconnectingConnection { } } -async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult { +async fn get_multiplexed_connection( + client: &redis::Client, + push_sender: Option>, +) -> RedisResult { run_with_timeout( Some(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT), - client.get_multiplexed_async_connection(), + client.get_multiplexed_async_connection(push_sender), ) .await } @@ -64,9 +69,10 @@ async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult>, ) -> Result { let client = &connection_backend.connection_info; - let action = || get_multiplexed_connection(client); + let action = || get_multiplexed_connection(client, push_sender.clone()); match Retry::spawn(retry_strategy.get_iterator(), action).await { Ok(connection) => { @@ -85,6 +91,7 @@ async fn create_connection( state: Mutex::new(ConnectionState::Connected(connection)), backend: connection_backend, }), + push_sender, }) } Err(err) => { @@ -103,6 +110,7 @@ async fn create_connection( state: Mutex::new(ConnectionState::InitializedDisconnected), backend: connection_backend, }), + push_sender, }; connection.reconnect(); Err((connection, err)) @@ -141,6 +149,7 @@ impl ReconnectingConnection { connection_retry_strategy: RetryStrategy, redis_connection_info: RedisConnectionInfo, tls_mode: TlsMode, + push_sender: Option>, ) -> Result { log_debug( "connection creation", @@ -153,7 +162,7 @@ impl ReconnectingConnection { connection_available_signal: ManualResetEvent::new(true), client_dropped_flagged: AtomicBool::new(false), }; - create_connection(backend, connection_retry_strategy).await + create_connection(backend, connection_retry_strategy, push_sender).await } fn node_address(&self) -> String { @@ -211,6 +220,7 @@ impl ReconnectingConnection { log_debug("reconnect", "starting"); let connection_clone = self.clone(); + let push_sender = self.push_sender.clone(); // The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the // background, regardless of whether the calling task is dropped or not. task::spawn(async move { @@ -224,7 +234,7 @@ impl ReconnectingConnection { // Client was dropped, reconnection attempts can stop return; } - match get_multiplexed_connection(client).await { + match get_multiplexed_connection(client, push_sender.clone()).await { Ok(mut connection) => { if connection .send_packed_command(&redis::cmd("PING")) diff --git a/glide-core/src/client/standalone_client.rs b/glide-core/src/client/standalone_client.rs index 736155bbf0..6796f525a6 100644 --- a/glide-core/src/client/standalone_client.rs +++ b/glide-core/src/client/standalone_client.rs @@ -9,10 +9,12 @@ use futures::{future, stream, StreamExt}; #[cfg(standalone_heartbeat)] use logger_core::log_debug; use logger_core::log_warn; +use rand::Rng; use redis::cluster_routing::{self, is_readonly_cmd, ResponsePolicy, Routable, RoutingInfo}; -use redis::{RedisError, RedisResult, Value}; +use redis::{PushInfo, RedisError, RedisResult, Value}; use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use tokio::sync::mpsc; #[cfg(standalone_heartbeat)] use tokio::task; @@ -96,22 +98,33 @@ impl std::fmt::Debug for StandaloneClientConnectionError { impl StandaloneClient { pub async fn create_client( connection_request: ConnectionRequest, + push_sender: Option>, ) -> Result { if connection_request.addresses.is_empty() { return Err(StandaloneClientConnectionError::NoAddressesProvided); } - let redis_connection_info = get_redis_connection_info(&connection_request); + let mut redis_connection_info = get_redis_connection_info(&connection_request); + let pubsub_connection_info = redis_connection_info.clone(); + redis_connection_info.pubsub_subscriptions = None; let retry_strategy = RetryStrategy::new(connection_request.connection_retry_strategy); let tls_mode = connection_request.tls_mode; let node_count = connection_request.addresses.len(); + // randomize pubsub nodes, maybe a batter option is to always use the primary + let pubsub_node_index = rand::thread_rng().gen_range(0..node_count); + let pubsub_addr = &connection_request.addresses[pubsub_node_index]; let mut stream = stream::iter(connection_request.addresses.iter()) .map(|address| async { get_connection_and_replication_info( address, &retry_strategy, - &redis_connection_info, + if address.to_string() != pubsub_addr.to_string() { + &redis_connection_info + } else { + &pubsub_connection_info + }, tls_mode.unwrap_or(TlsMode::NoTls), + &push_sender, ) .await .map_err(|err| (format!("{}:{}", address.host, address.port), err)) @@ -392,12 +405,14 @@ async fn get_connection_and_replication_info( retry_strategy: &RetryStrategy, connection_info: &redis::RedisConnectionInfo, tls_mode: TlsMode, + push_sender: &Option>, ) -> Result<(ReconnectingConnection, Value), (ReconnectingConnection, RedisError)> { let result = ReconnectingConnection::new( address, retry_strategy.clone(), connection_info.clone(), tls_mode, + push_sender.clone(), ) .await; let reconnecting_connection = match result { diff --git a/glide-core/src/client/types.rs b/glide-core/src/client/types.rs index f942f64174..2422ed3d0c 100644 --- a/glide-core/src/client/types.rs +++ b/glide-core/src/client/types.rs @@ -2,6 +2,8 @@ * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +use logger_core::log_warn; +use std::collections::HashSet; use std::time::Duration; #[cfg(feature = "socket-layer")] @@ -20,6 +22,7 @@ pub struct ConnectionRequest { pub request_timeout: Option, pub connection_retry_strategy: Option, pub periodic_checks: Option, + pub pubsub_subscriptions: Option, } pub struct AuthenticationInfo { @@ -150,6 +153,39 @@ impl From for ConnectionRequest { PeriodicCheck::Disabled } }); + let mut pubsub_subscriptions: Option = None; + if let Some(protobuf_pubsub) = value.pubsub_subscriptions.0 { + let mut redis_pubsub = redis::PubSubSubscriptionInfo::new(); + for (pubsub_type, channels_patterns) in + protobuf_pubsub.channels_or_patterns_by_type.iter() + { + let kind = match *pubsub_type { + 0 => redis::PubSubSubscriptionKind::Exact, + 1 => redis::PubSubSubscriptionKind::Pattern, + 2 => redis::PubSubSubscriptionKind::Sharded, + 3_u32..=u32::MAX => { + log_warn( + "client creation", + format!( + "Omitting pubsub subscription on an unknown type: {:?}", + *pubsub_type + ), + ); + continue; + } + }; + + for channel_pattern in channels_patterns.channels_or_patterns.iter() { + redis_pubsub + .entry(kind) + .and_modify(|channels_patterns| { + channels_patterns.insert(channel_pattern.to_vec()); + }) + .or_insert(HashSet::from([channel_pattern.to_vec()])); + } + } + pubsub_subscriptions = Some(redis_pubsub); + } ConnectionRequest { read_from, @@ -163,6 +199,7 @@ impl From for ConnectionRequest { request_timeout, connection_retry_strategy, periodic_checks, + pubsub_subscriptions, } } } diff --git a/glide-core/src/protobuf/connection_request.proto b/glide-core/src/protobuf/connection_request.proto index ecdeeae1b2..a186a1f41f 100644 --- a/glide-core/src/protobuf/connection_request.proto +++ b/glide-core/src/protobuf/connection_request.proto @@ -36,6 +36,22 @@ message PeriodicChecksManualInterval { message PeriodicChecksDisabled { } +enum PubSubChannelType { + Exact = 0; + Pattern = 1; + Sharded = 2; +} + +message PubSubChannelsOrPatterns +{ + repeated bytes channels_or_patterns = 1; +} + +message PubSubSubscriptions +{ + map channels_or_patterns_by_type = 1; +} + // IMPORTANT - if you add fields here, you probably need to add them also in client/mod.rs:`sanitized_request_string`. message ConnectionRequest { repeated NodeAddress addresses = 1; @@ -52,6 +68,7 @@ message ConnectionRequest { PeriodicChecksManualInterval periodic_checks_manual_interval = 11; PeriodicChecksDisabled periodic_checks_disabled = 12; } + PubSubSubscriptions pubsub_subscriptions = 13; } message ConnectionRetryStrategy { diff --git a/glide-core/src/protobuf/redis_request.proto b/glide-core/src/protobuf/redis_request.proto index 984484d210..0b798437eb 100644 --- a/glide-core/src/protobuf/redis_request.proto +++ b/glide-core/src/protobuf/redis_request.proto @@ -224,6 +224,8 @@ enum RequestType { UnWatch = 184; GeoSearchStore = 185; SUnion = 186; + Publish = 187; + SPublish = 188; } message Command { diff --git a/glide-core/src/protobuf/response.proto b/glide-core/src/protobuf/response.proto index 33591112ba..871d38e476 100644 --- a/glide-core/src/protobuf/response.proto +++ b/glide-core/src/protobuf/response.proto @@ -21,6 +21,7 @@ message Response { RequestError request_error = 4; string closing_error = 5; } + bool is_push = 6; } enum ConstantResponse { diff --git a/glide-core/src/request_type.rs b/glide-core/src/request_type.rs index 19487468b0..46c4a0bd63 100644 --- a/glide-core/src/request_type.rs +++ b/glide-core/src/request_type.rs @@ -194,6 +194,8 @@ pub enum RequestType { UnWatch = 184, GeoSearchStore = 185, SUnion = 186, + Publish = 187, + SPublish = 188, } fn get_two_word_command(first: &str, second: &str) -> Cmd { @@ -391,6 +393,8 @@ impl From<::protobuf::EnumOrUnknown> for RequestType { ProtobufRequestType::Watch => RequestType::Watch, ProtobufRequestType::UnWatch => RequestType::UnWatch, ProtobufRequestType::GeoSearchStore => RequestType::GeoSearchStore, + ProtobufRequestType::Publish => RequestType::Publish, + ProtobufRequestType::SPublish => RequestType::SPublish, } } } @@ -584,6 +588,8 @@ impl RequestType { RequestType::Watch => Some(cmd("WATCH")), RequestType::UnWatch => Some(cmd("UNWATCH")), RequestType::GeoSearchStore => Some(cmd("GEOSEARCHSTORE")), + RequestType::Publish => Some(cmd("PUBLISH")), + RequestType::SPublish => Some(cmd("SPUBLISH")), } } } diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 82b8df6a69..2c9f91d753 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -21,7 +21,7 @@ use redis::cluster_routing::{ }; use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::RedisError; -use redis::{Cmd, Value}; +use redis::{Cmd, PushInfo, Value}; use std::cell::Cell; use std::rc::Rc; use std::{env, str}; @@ -30,6 +30,7 @@ use thiserror::Error; use tokio::io::ErrorKind::AddrInUse; use tokio::net::{UnixListener, UnixStream}; use tokio::runtime::Builder; +use tokio::sync::mpsc; use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::task; @@ -184,6 +185,7 @@ async fn write_result( ) -> Result<(), io::Error> { let mut response = Response::new(); response.callback_idx = callback_index; + response.is_push = false; response.value = match resp_result { Ok(Value::Okay) => Some(response::response::Value::ConstantResponse( response::ConstantResponse::OK.into(), @@ -473,8 +475,9 @@ pub fn close_socket(socket_path: &String) { async fn create_client( writer: &Rc, request: ConnectionRequest, + push_tx: Option>, ) -> Result { - let client = match Client::new(request.into()).await { + let client = match Client::new(request.into(), push_tx).await { Ok(client) => client, Err(err) => return Err(ClientCreationError::ConnectionError(err)), }; @@ -485,13 +488,14 @@ async fn create_client( async fn wait_for_connection_configuration_and_create_client( client_listener: &mut UnixStreamListener, writer: &Rc, + push_tx: Option>, ) -> Result { // Wait for the server's address match client_listener.next_values::().await { Closed(reason) => Err(ClientCreationError::SocketListenerClosed(reason)), ReceivedValues(mut received_requests) => { if let Some(request) = received_requests.pop() { - create_client(writer, request).await + create_client(writer, request, push_tx).await } else { Err(ClientCreationError::UnhandledError( "No received requests".to_string(), @@ -518,6 +522,35 @@ async fn read_values_loop( } } +async fn push_manager_loop(mut push_rx: mpsc::UnboundedReceiver, writer: Rc) { + loop { + let result = push_rx.recv().await; + match result { + None => { + log_trace("push manager loop", "got None as from push manager"); + return; + } + Some(push_msg) => { + log_debug("push manager loop", format!("got PushInfo: {:?}", push_msg)); + let mut response = Response::new(); + response.callback_idx = 0; // callback_idx is not used with push notifications + response.is_push = true; + response.value = { + let push_val = Value::Push { + kind: (push_msg.kind), + data: (push_msg.data), + }; + let pointer = Box::leak(Box::new(push_val)); + let raw_pointer = pointer as *mut redis::Value; + Some(response::response::Value::RespPointer(raw_pointer as u64)) + }; + + _ = write_to_writer(response, &writer).await; + } + } + } +} + async fn listen_on_client_stream(socket: UnixStream) { let socket = Rc::new(socket); // Spawn a new task to listen on this client's stream @@ -525,14 +558,18 @@ async fn listen_on_client_stream(socket: UnixStream) { let mut client_listener = UnixStreamListener::new(socket.clone()); let accumulated_outputs = Cell::new(Vec::new()); let (sender, mut receiver) = channel(1); + let (push_tx, push_rx) = tokio::sync::mpsc::unbounded_channel(); let writer = Rc::new(Writer { socket, lock: write_lock, accumulated_outputs, closing_sender: sender, }); - let client_creation = - wait_for_connection_configuration_and_create_client(&mut client_listener, &writer); + let client_creation = wait_for_connection_configuration_and_create_client( + &mut client_listener, + &writer, + Some(push_tx), + ); let client = match client_creation.await { Ok(conn) => conn, Err(ClientCreationError::SocketListenerClosed(ClosingReason::ReadSocketClosed)) => { @@ -583,6 +620,9 @@ async fn listen_on_client_stream(socket: UnixStream) { } else { log_trace("client closing", "writer closed"); } + }, + _ = push_manager_loop(push_rx, writer.clone()) => { + log_trace("client closing", "push manager closed"); } } log_trace("client closing", "closing connection"); diff --git a/glide-core/tests/test_client.rs b/glide-core/tests/test_client.rs index 682b5de9b9..961bcd30f0 100644 --- a/glide-core/tests/test_client.rs +++ b/glide-core/tests/test_client.rs @@ -35,6 +35,7 @@ pub(crate) mod shared_client_tests { Client::new( create_connection_request(&[connection_addr.clone()], &configuration) .into(), + None, ) .await .ok() diff --git a/glide-core/tests/test_standalone_client.rs b/glide-core/tests/test_standalone_client.rs index aa7f3b6609..1073ad24fb 100644 --- a/glide-core/tests/test_standalone_client.rs +++ b/glide-core/tests/test_standalone_client.rs @@ -199,7 +199,7 @@ mod standalone_client_tests { connection_request.read_from = config.read_from.into(); block_on_all(async { - let mut client = StandaloneClient::create_client(connection_request.into()) + let mut client = StandaloneClient::create_client(connection_request.into(), None) .await .unwrap(); for mock in mocks.drain(1..config.number_of_replicas_dropped_after_connection + 1) { @@ -305,7 +305,7 @@ mod standalone_client_tests { let connection_request = create_connection_request(addresses.as_slice(), &Default::default()); block_on_all(async { - let client_res = StandaloneClient::create_client(connection_request.into()) + let client_res = StandaloneClient::create_client(connection_request.into(), None) .await .map_err(ConnectionError::Standalone); assert!(client_res.is_err()); @@ -344,7 +344,7 @@ mod standalone_client_tests { create_connection_request(addresses.as_slice(), &Default::default()); block_on_all(async { - let mut client = StandaloneClient::create_client(connection_request.into()) + let mut client = StandaloneClient::create_client(connection_request.into(), None) .await .unwrap(); diff --git a/glide-core/tests/utilities/cluster.rs b/glide-core/tests/utilities/cluster.rs index 6ff69a932e..6ed1a07fb4 100644 --- a/glide-core/tests/utilities/cluster.rs +++ b/glide-core/tests/utilities/cluster.rs @@ -249,7 +249,7 @@ pub async fn create_cluster_client( configuration.request_timeout = configuration.request_timeout.or(Some(10000)); let connection_request = create_connection_request(&addresses, &configuration); - Client::new(connection_request.into()).await.unwrap() + Client::new(connection_request.into(), None).await.unwrap() } pub async fn setup_test_basics_internal(configuration: TestConfiguration) -> ClusterTestBasics { diff --git a/glide-core/tests/utilities/mod.rs b/glide-core/tests/utilities/mod.rs index 04bd727a1d..0bebec2c82 100644 --- a/glide-core/tests/utilities/mod.rs +++ b/glide-core/tests/utilities/mod.rs @@ -12,7 +12,7 @@ use once_cell::sync::Lazy; use rand::{distributions::Alphanumeric, Rng}; use redis::{ cluster_routing::{MultipleNodeRoutingInfo, RoutingInfo}, - ConnectionAddr, RedisConnectionInfo, RedisResult, Value, + ConnectionAddr, PushInfo, RedisConnectionInfo, RedisResult, Value, }; use socket2::{Domain, Socket, Type}; use std::{ @@ -20,6 +20,7 @@ use std::{ sync::Mutex, time::Duration, }; use tempfile::TempDir; +use tokio::sync::mpsc; pub mod cluster; pub mod mocks; @@ -456,7 +457,7 @@ pub async fn wait_for_server_to_become_ready(server_address: &ConnectionAddr) { }) .unwrap(); loop { - match client.get_multiplexed_async_connection().await { + match client.get_multiplexed_async_connection(None).await { Err(err) => { if err.is_connection_refusal() { tokio::time::sleep(millisecond).await; @@ -546,6 +547,7 @@ pub async fn send_set_and_get(mut client: Client, key: String) { pub struct TestBasics { pub server: Option, pub client: StandaloneClient, + pub push_receiver: mpsc::UnboundedReceiver, } fn convert_to_protobuf_protocol( @@ -592,7 +594,8 @@ pub async fn setup_acl(addr: &ConnectionAddr, connection_info: &RedisConnectionI }) .unwrap(); let mut connection = - repeat_try_create(|| async { client.get_multiplexed_async_connection().await.ok() }).await; + repeat_try_create(|| async { client.get_multiplexed_async_connection(None).await.ok() }) + .await; let password = connection_info.password.clone().unwrap(); let username = connection_info @@ -689,11 +692,16 @@ pub(crate) async fn setup_test_basics_internal(configuration: &TestConfiguration let mut connection_request = create_connection_request(&[connection_addr], configuration); connection_request.cluster_mode_enabled = false; connection_request.protocol = configuration.protocol.into(); - let client = StandaloneClient::create_client(connection_request.into()) + let (push_sender, push_receiver) = tokio::sync::mpsc::unbounded_channel(); + let client = StandaloneClient::create_client(connection_request.into(), Some(push_sender)) .await .unwrap(); - TestBasics { server, client } + TestBasics { + server, + client, + push_receiver, + } } pub async fn setup_test_basics(use_tls: bool) -> TestBasics { diff --git a/go/src/lib.rs b/go/src/lib.rs index 72ffeca427..28ac6d0080 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -81,7 +81,7 @@ fn create_client_internal( errors::error_message(&redis_error) })?; let client = runtime - .block_on(GlideClient::new(ConnectionRequest::from(request))) + .block_on(GlideClient::new(ConnectionRequest::from(request), None)) .map_err(|err| err.to_string())?; Ok(ClientAdapter { client, diff --git a/node/rust-client/src/lib.rs b/node/rust-client/src/lib.rs index 0de3d2bae8..b83c38949e 100644 --- a/node/rust-client/src/lib.rs +++ b/node/rust-client/src/lib.rs @@ -67,7 +67,8 @@ impl AsyncClient { .build()?; let _runtime_handle = runtime.enter(); let client = to_js_result(redis::Client::open(connection_address))?; - let connection = to_js_result(runtime.block_on(client.get_multiplexed_async_connection()))?; + let connection = + to_js_result(runtime.block_on(client.get_multiplexed_async_connection(None)))?; Ok(AsyncClient { connection, runtime, diff --git a/python/python/glide/async_commands/cluster_commands.py b/python/python/glide/async_commands/cluster_commands.py index e010f1f54b..6796c9608f 100644 --- a/python/python/glide/async_commands/cluster_commands.py +++ b/python/python/glide/async_commands/cluster_commands.py @@ -452,3 +452,29 @@ async def sort_store( args = _build_sort_args(key, None, limit, None, order, alpha, store=destination) result = await self._execute_command(RequestType.Sort, args) return cast(int, result) + + async def publish(self, message: str, channel: str, sharded: bool = False) -> int: + """ + Publish message on pubsub channel. + This command aggregates PUBLISH and SPUBLISH commands functionalities. + The mode is selected using the 'sharded' parameter + See https://valkey.io/commands/publish and https://valkey.io/commands/spublish for more details. + + Args: + message: Message to publish + channel: Channel to publish the message on. + sharded: Use sharded pubsub mode. + + Returns: + int: Number of clients that received the message. + + Examples: + >>> await client.publish("Hi all!", "global-channel", False) + 1 # Publishes "Hi all!" message on global-channel channel using non-sharded mode + >>> await client.publish("Hi to sharded channel1!", "channel1, True) + 2 # Publishes "Hi to sharded channel1!" message on channel1 using sharded mode + """ + result = await self._execute_command( + RequestType.SPublish if sharded else RequestType.Publish, [channel, message] + ) + return cast(int, result) diff --git a/python/python/glide/async_commands/standalone_commands.py b/python/python/glide/async_commands/standalone_commands.py index f8776aadbb..74fff200a1 100644 --- a/python/python/glide/async_commands/standalone_commands.py +++ b/python/python/glide/async_commands/standalone_commands.py @@ -410,3 +410,22 @@ async def sort_store( ) result = await self._execute_command(RequestType.Sort, args) return cast(int, result) + + async def publish(self, message: str, channel: str) -> int: + """ + Publish message on pubsub channel. + See https://valkey.io/commands/publish for more details. + + Args: + message: Message to publish + channel: Channel to publish the message on. + + Returns: + int: Number of clients that received the message. + + Examples: + >>> await client.publish("Hi all!", "global-channel") + 1 # Publishes "Hi all!" message on global-channel channel + """ + result = await self._execute_command(RequestType.Publish, [channel, message]) + return cast(int, result) diff --git a/python/python/glide/config.py b/python/python/glide/config.py index 5c6ba07969..05117d54a2 100644 --- a/python/python/glide/config.py +++ b/python/python/glide/config.py @@ -1,7 +1,7 @@ # Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 -from enum import Enum -from typing import List, Optional, Union +from enum import Enum, IntEnum +from typing import Dict, List, Optional, Set, Union from glide.protobuf.connection_request_pb2 import ConnectionRequest from glide.protobuf.connection_request_pb2 import ProtocolVersion as SentProtocolVersion @@ -221,8 +221,23 @@ class RedisClientConfiguration(BaseClientConfiguration): database_id (Optional[int]): index of the logical database to connect to. client_name (Optional[str]): Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment. protocol (ProtocolVersion): The version of the Redis RESP protocol to communicate with the server. + pubsub_subscriptions (Optional[RedisClientConfiguration.PubSubSubscriptions]): Pubsub subscriptions to be used for the client. + Will be applied via SUBSCRIBE/PSUBSCRIBE commands during connection establishment. """ + class PubSubChannelModes(IntEnum): + """ + Describes pubsub subsciption modes. + See https://valkey.io/docs/topics/pubsub/ for more details + """ + + Exact = 0 + """ Use exact channel names """ + Pattern = 1 + """ Use channel name patterns """ + + PubSubSubscriptions = Dict[PubSubChannelModes, Set[str]] + def __init__( self, addresses: List[NodeAddress], @@ -234,6 +249,7 @@ def __init__( database_id: Optional[int] = None, client_name: Optional[str] = None, protocol: ProtocolVersion = ProtocolVersion.RESP3, + pubsub_subscriptions: Optional[PubSubSubscriptions] = None, ): super().__init__( addresses=addresses, @@ -246,6 +262,7 @@ def __init__( ) self.reconnect_strategy = reconnect_strategy self.database_id = database_id + self.pubsub_subscriptions = pubsub_subscriptions def _create_a_protobuf_conn_request( self, cluster_mode: bool = False @@ -263,6 +280,14 @@ def _create_a_protobuf_conn_request( if self.database_id: request.database_id = self.database_id + if self.pubsub_subscriptions: + for channel_type, channels_patterns in self.pubsub_subscriptions.items(): + entry = request.pubsub_subscriptions.channels_or_patterns_by_type[ + int(channel_type) + ] + for channel_pattern in channels_patterns: + entry.channels_or_patterns.append(str.encode(channel_pattern)) + return request @@ -290,12 +315,29 @@ class ClusterClientConfiguration(BaseClientConfiguration): These checks evaluate changes in the cluster's topology, triggering a slot refresh when detected. Periodic checks ensure a quick and efficient process by querying a limited number of nodes. Defaults to PeriodicChecksStatus.ENABLED_DEFAULT_CONFIGS. + pubsub_subscriptions (Optional[ClusterClientConfiguration.PubSubSubscriptions]): Pubsub subscriptions to be used for the client. + Will be applied via SUBSCRIBE/PSUBSCRIBE/SSUBSCRIBE commands during connection establishment. Notes: Currently, the reconnection strategy in cluster mode is not configurable, and exponential backoff with fixed values is used. """ + class PubSubChannelModes(IntEnum): + """ + Describes pubsub subsciption modes. + See https://valkey.io/docs/topics/pubsub/ for more details + """ + + Exact = 0 + """ Use exact channel names """ + Pattern = 1 + """ Use channel name patterns """ + Sharded = 2 + """ Use sharded pubsub """ + + PubSubSubscriptions = Dict[PubSubChannelModes, Set[str]] + def __init__( self, addresses: List[NodeAddress], @@ -308,6 +350,7 @@ def __init__( periodic_checks: Union[ PeriodicChecksStatus, PeriodicChecksManualInterval ] = PeriodicChecksStatus.ENABLED_DEFAULT_CONFIGS, + pubsub_subscriptions: Optional[PubSubSubscriptions] = None, ): super().__init__( addresses=addresses, @@ -319,6 +362,7 @@ def __init__( protocol=protocol, ) self.periodic_checks = periodic_checks + self.pubsub_subscriptions = pubsub_subscriptions def _create_a_protobuf_conn_request( self, cluster_mode: bool = False @@ -332,4 +376,12 @@ def _create_a_protobuf_conn_request( elif self.periodic_checks == PeriodicChecksStatus.DISABLED: request.periodic_checks_disabled.SetInParent() + if self.pubsub_subscriptions: + for channel_type, channels_patterns in self.pubsub_subscriptions.items(): + entry = request.pubsub_subscriptions.channels_or_patterns_by_type[ + int(channel_type) + ] + for channel_pattern in channels_patterns: + entry.channels_or_patterns.append(str.encode(channel_pattern)) + return request diff --git a/python/python/glide/redis_client.py b/python/python/glide/redis_client.py index 3d61d12b49..8f482c4871 100644 --- a/python/python/glide/redis_client.py +++ b/python/python/glide/redis_client.py @@ -2,7 +2,8 @@ import asyncio import threading -from typing import List, Optional, Tuple, Type, Union, cast +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type, Union, cast import async_timeout from glide.async_commands.cluster_commands import ClusterCommands @@ -60,6 +61,9 @@ def __init__(self, config: BaseClientConfiguration): self.socket_path: Optional[str] = None self._reader_task: Optional[asyncio.Task] = None self._is_closed: bool = False + self._pubsub_futures: List[asyncio.Future] = [] + self._pubsub_lock = threading.Lock() + self._pending_push_notifications: List[Response] = list() @classmethod async def create(cls, config: BaseClientConfiguration) -> Self: @@ -140,6 +144,14 @@ async def close(self, err_message: Optional[str] = None) -> None: if not response_future.done(): err_message = "" if err_message is None else err_message response_future.set_exception(ClosingError(err_message)) + try: + self._pubsub_lock.acquire() + for pubsub_future in self._pubsub_futures: + if not response_future.done() and not pubsub_future.cancelled(): + pubsub_future.set_exception(ClosingError("")) + finally: + self._pubsub_lock.release() + self._writer.close() await self._writer.wait_closed() self.__del__() @@ -243,6 +255,80 @@ async def _execute_script( set_protobuf_route(request, route) return await self._write_request_await_response(request) + @dataclass + class PubSubMsg: + message: str + channel: str + pattern: Optional[str] + + async def get_pubsub_message(self) -> PubSubMsg: + if self._is_closed: + raise ClosingError( + "Unable to execute requests; the client is closed. Please create a new client." + ) + # locking might not be required + response_future: asyncio.Future = asyncio.Future() + try: + self._pubsub_lock.acquire() + self._pubsub_futures.append(response_future) + self._push_pubsub_messages_safe() + finally: + self._pubsub_lock.release() + return await response_future + + def _cancel_pubsub_futures_with_exception_safe(self, exception: ConnectionError): + while len(self._pubsub_futures): + next_future = self._pubsub_futures.pop(0) + if not next_future.cancelled(): + next_future.set_exception(exception) + + def _push_pubsub_messages_safe(self): + while len(self._pending_push_notifications) and len(self._pubsub_futures): + next_push_notification = self._pending_push_notifications.pop(0) + next_push_notification: Dict = value_from_pointer( + next_push_notification.resp_pointer + ) + message_kind = next_push_notification["kind"] + if message_kind == "Disconnect": + # cancel all futures since we dont know how many (if any) messages wont arrive + # TODO: consider cancelling a single future + self._cancel_pubsub_futures_with_exception_safe( + ConnectionError( + "Warning, transport disconnect occured, messages might be lost" + ) + ) + elif ( + message_kind == "Message" + or message_kind == "PMessage" + or message_kind == "SMessage" + ): + next_future = self._pubsub_futures.pop(0) + values: List = next_push_notification["values"] + if message_kind == "PMessage": + msg = BaseRedisClient.PubSubMsg( + message=values[2], channel=values[1], pattern=values[0] + ) + else: + msg = BaseRedisClient.PubSubMsg( + message=values[1], channel=values[0], pattern=None + ) + next_future.set_result(msg) + elif ( + message_kind == "PSubscribe" + or message_kind == "Subscribe" + or message_kind == "SSubscribe" + or message_kind == "Unsubscribe" + ): + pass + else: + err_msg = f"Unsupported push message: '{message_kind}'" + ClientLogger.log(LogLevel.ERROR, "pubsub message", err_msg) + # cancel all futures since its a serious + # TODO: consider cancelling a single future + self._cancel_pubsub_futures_with_exception_safe( + ConnectionError(err_msg) + ) + async def _write_request_await_response(self, request: RedisRequest): # Create a response future for this request and add it to the available # futures map @@ -258,6 +344,47 @@ def _get_callback_index(self) -> int: # The list is empty return len(self._available_futures) + async def _process_response(self, response: Response) -> None: + res_future = self._available_futures.pop(response.callback_idx, None) + if not res_future or response.HasField("closing_error"): + err_msg = ( + response.closing_error + if response.HasField("closing_error") + else f"Client Error - closing due to unknown error. callback index: {response.callback_idx}" + ) + if res_future is not None: + res_future.set_exception(ClosingError(err_msg)) + await self.close(err_msg) + raise ClosingError(err_msg) + else: + self._available_callback_indexes.append(response.callback_idx) + if response.HasField("request_error"): + error_type = get_request_error_class(response.request_error.type) + res_future.set_exception(error_type(response.request_error.message)) + elif response.HasField("resp_pointer"): + res_future.set_result(value_from_pointer(response.resp_pointer)) + elif response.HasField("constant_response"): + res_future.set_result(OK) + else: + res_future.set_result(None) + + async def _process_push(self, response: Response) -> None: + if response.HasField("closing_error") or not response.HasField("resp_pointer"): + err_msg = ( + response.closing_error + if response.HasField("closing_error") + else "Client Error - push notification without resp_pointer" + ) + await self.close(err_msg) + raise ClosingError(err_msg) + + try: + self._pubsub_lock.acquire() + self._pending_push_notifications.append(response) + self._push_pubsub_messages_safe() + finally: + self._pubsub_lock.release() + async def _reader_loop(self) -> None: # Socket reader loop remaining_read_bytes = bytearray() @@ -280,32 +407,10 @@ async def _reader_loop(self) -> None: remaining_read_bytes = read_bytes[offset:] break response = cast(Response, response) - res_future = self._available_futures.pop(response.callback_idx, None) - if not res_future or response.HasField("closing_error"): - err_msg = ( - response.closing_error - if response.HasField("closing_error") - else f"Client Error - closing due to unknown error. callback index: {response.callback_idx}" - ) - if res_future is not None: - res_future.set_exception(ClosingError(err_msg)) - await self.close(err_msg) - raise ClosingError(err_msg) + if response.is_push: + await self._process_push(response=response) else: - self._available_callback_indexes.append(response.callback_idx) - if response.HasField("request_error"): - error_type = get_request_error_class( - response.request_error.type - ) - res_future.set_exception( - error_type(response.request_error.message) - ) - elif response.HasField("resp_pointer"): - res_future.set_result(value_from_pointer(response.resp_pointer)) - elif response.HasField("constant_response"): - res_future.set_result(OK) - else: - res_future.set_result(None) + await self._process_response(response=response) class RedisClusterClient(BaseRedisClient, ClusterCommands): diff --git a/python/python/tests/conftest.py b/python/python/tests/conftest.py index 7462cf9565..a1723e7a97 100644 --- a/python/python/tests/conftest.py +++ b/python/python/tests/conftest.py @@ -223,6 +223,12 @@ async def create_client( client_name: Optional[str] = None, protocol: ProtocolVersion = ProtocolVersion.RESP3, timeout: Optional[int] = None, + cluster_mode_pubsub: Optional[ + ClusterClientConfiguration.PubSubSubscriptions + ] = None, + standalone_mode_pubsub: Optional[ + RedisClientConfiguration.PubSubSubscriptions + ] = None, ) -> Union[RedisClient, RedisClusterClient]: # Create async socket client use_tls = request.config.getoption("--tls") @@ -238,6 +244,7 @@ async def create_client( client_name=client_name, protocol=protocol, request_timeout=timeout, + pubsub_subscriptions=cluster_mode_pubsub, ) return await RedisClusterClient.create(cluster_config) else: @@ -252,6 +259,7 @@ async def create_client( client_name=client_name, protocol=protocol, request_timeout=timeout, + pubsub_subscriptions=standalone_mode_pubsub, ) return await RedisClient.create(config) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index d1b2510655..fce1beff9b 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -18,7 +18,6 @@ ExpireOptions, ExpirySet, ExpiryType, - InfBound, InfoSection, InsertPosition, StreamAddOptions, @@ -43,7 +42,7 @@ ScoreBoundary, ScoreFilter, ) -from glide.config import ProtocolVersion, RedisCredentials +from glide.config import ClusterClientConfiguration, ProtocolVersion, RedisCredentials from glide.constants import OK, TResult from glide.redis_client import RedisClient, RedisClusterClient, TRedisClient from glide.routes import ( @@ -5328,3 +5327,87 @@ async def test_script(self, redis_client: TRedisClient): script = Script("return redis.call('GET', KEYS[1])") assert await redis_client.invoke_script(script, keys=[key1]) == "value1" assert await redis_client.invoke_script(script, keys=[key2]) == "value2" + + +@pytest.mark.asyncio +class TestPubSub: + + async def test_pubsub_basic_standalone(self, request): + CHANNEL_NAME = "test-channel" + MESSAGE = "test-message" + PATTERN = "*" + + publishing_client: RedisClusterClient = await create_client( + request, cluster_mode=False + ) + + standalone_mode_pubsub: ClusterClientConfiguration.PubSubSubscriptions = {} + standalone_mode_pubsub[ClusterClientConfiguration.PubSubChannelModes.Exact] = { + CHANNEL_NAME + } + standalone_mode_pubsub[ + ClusterClientConfiguration.PubSubChannelModes.Pattern + ] = {PATTERN} + + listening_client = await create_client( + request, cluster_mode=False, standalone_mode_pubsub=standalone_mode_pubsub + ) + + await publishing_client.publish(MESSAGE, CHANNEL_NAME) + # allow the message to propagate + await asyncio.sleep(1) + + pattern_cnt = 0 + pattern = None + for _ in range(2): + pubsub_msg = await listening_client.get_pubsub_message() + assert pubsub_msg.channel == CHANNEL_NAME + assert pubsub_msg.message == MESSAGE + if pubsub_msg.pattern: + pattern_cnt += 1 + pattern = pubsub_msg.pattern + + assert pattern == PATTERN + assert pattern_cnt == 1 + + async def test_pubsub_basic_clustermode(self, request): + CHANNEL_NAME = "test-channel" + SHARDED_CHANNEL_NAME = "test-channel-sharded" + MESSAGE = "test-message" + + publishing_client: RedisClusterClient = await create_client( + request, cluster_mode=True + ) + test_sharded = not await check_if_server_version_lt(publishing_client, "7.0.0") + + cluster_mode_pubsub: ClusterClientConfiguration.PubSubSubscriptions = {} + cluster_mode_pubsub[ClusterClientConfiguration.PubSubChannelModes.Exact] = { + CHANNEL_NAME + } + if test_sharded: + cluster_mode_pubsub[ + ClusterClientConfiguration.PubSubChannelModes.Sharded + ] = {SHARDED_CHANNEL_NAME} + + listening_client = await create_client( + request, cluster_mode=True, cluster_mode_pubsub=cluster_mode_pubsub + ) + + await publishing_client.publish(MESSAGE, CHANNEL_NAME) + # allow the message to propagate + await asyncio.sleep(1) + + pubsub_msg = await listening_client.get_pubsub_message() + assert pubsub_msg.channel == CHANNEL_NAME + assert pubsub_msg.message == MESSAGE + assert pubsub_msg.pattern is None + + if test_sharded: + await publishing_client.publish(MESSAGE, SHARDED_CHANNEL_NAME, sharded=True) + # allow the message to propagate + await asyncio.sleep(1) + + pubsub_msg = await listening_client.get_pubsub_message() + assert pubsub_msg.channel == SHARDED_CHANNEL_NAME + assert pubsub_msg.message == MESSAGE + assert pubsub_msg.pattern is None diff --git a/python/src/lib.rs b/python/src/lib.rs index e1a799a0dd..4380b064c9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -145,7 +145,13 @@ fn glide(_py: Python, m: &PyModule) -> PyResult<()> { Value::Boolean(boolean) => Ok(PyBool::new(py, boolean).into_py(py)), Value::VerbatimString { format: _, text } => Ok(text.into_py(py)), Value::BigNumber(bigint) => Ok(bigint.into_py(py)), - Value::Push { kind: _, data: _ } => todo!(), + Value::Push { kind, data } => { + let dict = PyDict::new(py); + dict.set_item("kind", format!("{kind:?}"))?; + let values: &PyList = PyList::new(py, iter_to_value(py, data)?); + dict.set_item("values", values)?; + Ok(dict.into_py(py)) + } } } diff --git a/submodules/redis-rs b/submodules/redis-rs index b36c95947d..cb81fb77b0 160000 --- a/submodules/redis-rs +++ b/submodules/redis-rs @@ -1 +1 @@ -Subproject commit b36c95947d70fef1629fbc821890fdac99381d53 +Subproject commit cb81fb77b0dde6d57e3127158a17f6f81eac5a23 diff --git a/utils/cluster_manager.py b/utils/cluster_manager.py index ffa9f9af1e..6028757375 100644 --- a/utils/cluster_manager.py +++ b/utils/cluster_manager.py @@ -497,7 +497,7 @@ def wait_for_a_message_in_redis_logs( continue log_file = f"{dir}/redis.log" - if server_ports and str(dir) not in server_ports: + if server_ports and os.path.basename(os.path.normpath(dir)) not in server_ports: continue if not wait_for_message(log_file, message, 10): raise Exception(