From dcbbc2132aaba1e6e5dfd90dab389e1fc98d4d54 Mon Sep 17 00:00:00 2001 From: barshaul Date: Wed, 30 Aug 2023 12:46:08 +0000 Subject: [PATCH] Fixed the initial nodes expander to return the socketAddr object and maintain the provided hostname, for TLS hostname verifications. --- redis/src/aio/connection.rs | 11 ++++-- redis/src/client.rs | 35 ++++++++++--------- redis/src/cluster_async/mod.rs | 52 +++++++++++++++++++---------- redis/tests/support/mock_cluster.rs | 6 ++-- redis/tests/test_cluster_async.rs | 5 +-- 5 files changed, 70 insertions(+), 39 deletions(-) diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs index f997778a0..ece019705 100644 --- a/redis/src/aio/connection.rs +++ b/redis/src/aio/connection.rs @@ -176,11 +176,14 @@ where } } -pub(crate) async fn connect(connection_info: &ConnectionInfo) -> RedisResult> +pub(crate) async fn connect( + connection_info: &ConnectionInfo, + socket_addr: Option, +) -> RedisResult> where C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send, { - let con = connect_simple::(connection_info).await?; + let con = connect_simple::(connection_info, socket_addr).await?; Connection::new(&connection_info.redis, con).await } @@ -383,6 +386,7 @@ pub(crate) async fn get_socket_addrs( pub(crate) async fn connect_simple( connection_info: &ConnectionInfo, + socket_addr: Option, ) -> RedisResult { Ok(match connection_info.addr { ConnectionAddr::Tcp(ref host, port) => { @@ -396,6 +400,9 @@ pub(crate) async fn connect_simple( port, insecure, } => { + if let Some(socket_addr) = socket_addr { + return ::connect_tcp_tls(host, socket_addr, insecure).await; + } let socket_addrs = get_socket_addrs(host, port).await?; select_ok( socket_addrs.map(|socket_addr| ::connect_tcp_tls(host, socket_addr, insecure)), diff --git a/redis/src/client.rs b/redis/src/client.rs index dd700aa0a..1b23f7697 100644 --- a/redis/src/client.rs +++ b/redis/src/client.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{net::SocketAddr, time::Duration}; #[cfg(feature = "aio")] use std::pin::Pin; @@ -75,12 +75,12 @@ impl Client { let con = match Runtime::locate() { #[cfg(feature = "tokio-comp")] Runtime::Tokio => { - self.get_simple_async_connection::() + self.get_simple_async_connection::(None) .await? } #[cfg(feature = "async-std-comp")] Runtime::AsyncStd => { - self.get_simple_async_connection::() + self.get_simple_async_connection::(None) .await? } }; @@ -94,7 +94,7 @@ impl Client { pub async fn get_tokio_connection(&self) -> RedisResult { use crate::aio::RedisRuntime; Ok( - crate::aio::connect::(&self.connection_info) + crate::aio::connect::(&self.connection_info, None) .await? .map(RedisRuntime::boxed), ) @@ -106,7 +106,7 @@ impl Client { pub async fn get_async_std_connection(&self) -> RedisResult { use crate::aio::RedisRuntime; Ok( - crate::aio::connect::(&self.connection_info) + crate::aio::connect::(&self.connection_info, None) .await? .map(RedisRuntime::boxed), ) @@ -138,7 +138,7 @@ impl Client { pub async fn get_multiplexed_tokio_connection( &self, ) -> RedisResult { - self.get_multiplexed_async_connection_inner::() + self.get_multiplexed_async_connection_inner::(None) .await } @@ -151,7 +151,7 @@ impl Client { pub async fn get_multiplexed_async_std_connection( &self, ) -> RedisResult { - self.get_multiplexed_async_connection_inner::() + self.get_multiplexed_async_connection_inner::(None) .await } @@ -168,7 +168,7 @@ impl Client { crate::aio::MultiplexedConnection, impl std::future::Future, )> { - self.create_multiplexed_async_connection_inner::() + self.create_multiplexed_async_connection_inner::(None) .await } @@ -185,7 +185,7 @@ impl Client { crate::aio::MultiplexedConnection, impl std::future::Future, )> { - self.create_multiplexed_async_connection_inner::() + self.create_multiplexed_async_connection_inner::(None) .await } @@ -246,14 +246,15 @@ impl Client { .await } - async fn get_multiplexed_async_connection_inner( + pub(crate) async fn get_multiplexed_async_connection_inner( &self, + socket_addr: Option, ) -> RedisResult where T: crate::aio::RedisRuntime, { let (connection, driver) = self - .create_multiplexed_async_connection_inner::() + .create_multiplexed_async_connection_inner::(socket_addr) .await?; T::spawn(driver); Ok(connection) @@ -261,6 +262,7 @@ impl Client { async fn create_multiplexed_async_connection_inner( &self, + socket_addr: Option, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -268,19 +270,22 @@ impl Client { where T: crate::aio::RedisRuntime, { - let con = self.get_simple_async_connection::().await?; + let con = self.get_simple_async_connection::(socket_addr).await?; crate::aio::MultiplexedConnection::new(&self.connection_info.redis, con).await } async fn get_simple_async_connection( &self, + socket_addr: Option, ) -> RedisResult>> where T: crate::aio::RedisRuntime, { - Ok(crate::aio::connect_simple::(&self.connection_info) - .await? - .boxed()) + Ok( + crate::aio::connect_simple::(&self.connection_info, socket_addr) + .await? + .boxed(), + ) } #[cfg(feature = "connection-manager")] diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 19755fc81..4907c8d59 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -27,6 +27,7 @@ use std::{ iter::Iterator, marker::Unpin, mem, + net::SocketAddr, pin::Pin, sync::{ atomic::{self, AtomicUsize}, @@ -489,9 +490,10 @@ where /// Go through each of the initial nodes and attempt to retrieve all IP entries from them. /// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list. + /// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved. pub(crate) async fn try_to_expand_initial_nodes( initial_nodes: &[ConnectionInfo], - ) -> Vec { + ) -> Vec<(String, Option)> { stream::iter(initial_nodes) .fold( Vec::with_capacity(initial_nodes.len()), @@ -505,19 +507,19 @@ where } => (host, port), crate::ConnectionAddr::Unix(_) => { // We don't support multiple addresses for a Unix address. Store the initial node address and continue - acc.push(info.addr.to_string()); + acc.push((info.addr.to_string(), None)); return acc; } }; match get_socket_addrs(host, *port).await { Ok(socket_addrs) => { for addr in socket_addrs { - acc.push(addr.to_string()); + acc.push((info.addr.to_string(), Some(addr))); } } Err(_) => { // Couldn't find socket addresses, store the initial node address and continue - acc.push(info.addr.to_string()); + acc.push((info.addr.to_string(), None)); } }; acc @@ -530,14 +532,15 @@ where initial_nodes: &[ConnectionInfo], params: &ClusterParams, ) -> RedisResult> { - let initial_nodes: Vec = Self::try_to_expand_initial_nodes(initial_nodes).await; + let initial_nodes: Vec<(String, Option)> = + Self::try_to_expand_initial_nodes(initial_nodes).await; let connections = stream::iter(initial_nodes.iter().cloned()) - .map(|addr| { + .map(|node| { let params = params.clone(); async move { - let result = connect_and_check(&addr, params).await; + let result = connect_and_check(&node.0, params, node.1).await; match result { - Ok(conn) => Some((addr, async { conn }.boxed().shared())), + Ok(conn) => Some((node.0, async { conn }.boxed().shared())), Err(e) => { trace!("Failed to connect to initial node: {:?}", e); None @@ -952,7 +955,7 @@ where let addr_conn_option = match conn { Some((addr, Some(conn))) => Some((addr, conn.await)), - Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone()) + Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone(), None) .await .ok() .map(|conn| (addr, conn)), @@ -1119,10 +1122,10 @@ where let mut conn = conn.await; match check_connection(&mut conn, params.connection_timeout.into()).await { Ok(_) => Ok(conn), - Err(_) => connect_and_check(addr, params.clone()).await, + Err(_) => connect_and_check(addr, params.clone(), None).await, } } else { - connect_and_check(addr, params.clone()).await + connect_and_check(addr, params.clone(), None).await } } } @@ -1294,13 +1297,16 @@ where /// and obtaining a connection handle. pub trait Connect: Sized { /// Connect to a node, returning handle for command execution. - fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + fn connect<'a, T>(info: T, socket_addr: Option) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a; } impl Connect for MultiplexedConnection { - fn connect<'a, T>(info: T) -> RedisFuture<'a, MultiplexedConnection> + fn connect<'a, T>( + info: T, + socket_addr: Option, + ) -> RedisFuture<'a, MultiplexedConnection> where T: IntoConnectionInfo + Send + 'a, { @@ -1309,23 +1315,35 @@ impl Connect for MultiplexedConnection { let client = crate::Client::open(connection_info)?; #[cfg(feature = "tokio-comp")] - return client.get_multiplexed_tokio_connection().await; + return client + .get_multiplexed_async_connection_inner::(socket_addr) + .await; #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] - return client.get_multiplexed_async_std_connection().await; + return client + .get_multiplexed_async_connection_inner::( + socket_addr, + ) + .await; } .boxed() } } -async fn connect_and_check(node: &str, params: ClusterParams) -> RedisResult +async fn connect_and_check( + node: &str, + params: ClusterParams, + socket_addr: Option, +) -> RedisResult where C: ConnectionLike + Connect + Send + 'static, { let read_from_replicas = params.read_from_replicas; let connection_timeout = params.connection_timeout.into(); let info = get_connection_info(node, params)?; - let mut conn: C = C::connect(info).timeout(connection_timeout).await??; + let mut conn: C = C::connect(info, socket_addr) + .timeout(connection_timeout) + .await??; check_connection(&mut conn, connection_timeout).await?; if read_from_replicas { // If READONLY is sent to primary nodes, it will have no effect diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs index f6c8e7746..c55095359 100644 --- a/redis/tests/support/mock_cluster.rs +++ b/redis/tests/support/mock_cluster.rs @@ -1,11 +1,11 @@ +use redis::cluster::{self, ClusterClient, ClusterClientBuilder}; use std::{ collections::HashMap, + net::SocketAddr, sync::{Arc, RwLock}, time::Duration, }; -use redis::cluster::{self, ClusterClient, ClusterClientBuilder}; - use { once_cell::sync::Lazy, redis::{IntoConnectionInfo, RedisResult, Value}, @@ -32,7 +32,7 @@ pub struct MockConnection { #[cfg(feature = "cluster-async")] impl cluster_async::Connect for MockConnection { - fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + fn connect<'a, T>(info: T, _socket_addr: Option) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a, { diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index ef47ff178..c017307a7 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -1,5 +1,6 @@ #![cfg(feature = "cluster-async")] mod support; +use std::net::SocketAddr; use std::sync::{ atomic::{self, AtomicI32, AtomicU16}, atomic::{AtomicBool, Ordering}, @@ -241,12 +242,12 @@ struct ErrorConnection { } impl Connect for ErrorConnection { - fn connect<'a, T>(info: T) -> RedisFuture<'a, Self> + fn connect<'a, T>(info: T, _socket_addr: Option) -> RedisFuture<'a, Self> where T: IntoConnectionInfo + Send + 'a, { Box::pin(async { - let inner = MultiplexedConnection::connect(info).await?; + let inner = MultiplexedConnection::connect(info, None).await?; Ok(ErrorConnection { inner }) }) }