Skip to content

Commit

Permalink
Fixed the initial nodes expander to return the socketAddr object and …
Browse files Browse the repository at this point in the history
…maintain the provided hostname, for TLS hostname verifications.
  • Loading branch information
barshaul committed Aug 31, 2023
1 parent 6fee44c commit dcbbc21
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 39 deletions.
11 changes: 9 additions & 2 deletions redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ where
}
}

pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
pub(crate) async fn connect<C>(
connection_info: &ConnectionInfo,
socket_addr: Option<SocketAddr>,
) -> RedisResult<Connection<C>>
where
C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
{
let con = connect_simple::<C>(connection_info).await?;
let con = connect_simple::<C>(connection_info, socket_addr).await?;
Connection::new(&connection_info.redis, con).await
}

Expand Down Expand Up @@ -383,6 +386,7 @@ pub(crate) async fn get_socket_addrs(

pub(crate) async fn connect_simple<T: RedisRuntime>(
connection_info: &ConnectionInfo,
socket_addr: Option<SocketAddr>,
) -> RedisResult<T> {
Ok(match connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
Expand All @@ -396,6 +400,9 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
port,
insecure,
} => {
if let Some(socket_addr) = socket_addr {
return <T>::connect_tcp_tls(host, socket_addr, insecure).await;
}
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(
socket_addrs.map(|socket_addr| <T>::connect_tcp_tls(host, socket_addr, insecure)),
Expand Down
35 changes: 20 additions & 15 deletions redis/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Duration;
use std::{net::SocketAddr, time::Duration};

#[cfg(feature = "aio")]
use std::pin::Pin;
Expand Down Expand Up @@ -75,12 +75,12 @@ impl Client {
let con = match Runtime::locate() {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => {
self.get_simple_async_connection::<crate::aio::tokio::Tokio>()
self.get_simple_async_connection::<crate::aio::tokio::Tokio>(None)
.await?
}
#[cfg(feature = "async-std-comp")]
Runtime::AsyncStd => {
self.get_simple_async_connection::<crate::aio::async_std::AsyncStd>()
self.get_simple_async_connection::<crate::aio::async_std::AsyncStd>(None)
.await?
}
};
Expand All @@ -94,7 +94,7 @@ impl Client {
pub async fn get_tokio_connection(&self) -> RedisResult<crate::aio::Connection> {
use crate::aio::RedisRuntime;
Ok(
crate::aio::connect::<crate::aio::tokio::Tokio>(&self.connection_info)
crate::aio::connect::<crate::aio::tokio::Tokio>(&self.connection_info, None)
.await?
.map(RedisRuntime::boxed),
)
Expand All @@ -106,7 +106,7 @@ impl Client {
pub async fn get_async_std_connection(&self) -> RedisResult<crate::aio::Connection> {
use crate::aio::RedisRuntime;
Ok(
crate::aio::connect::<crate::aio::async_std::AsyncStd>(&self.connection_info)
crate::aio::connect::<crate::aio::async_std::AsyncStd>(&self.connection_info, None)
.await?
.map(RedisRuntime::boxed),
)
Expand Down Expand Up @@ -138,7 +138,7 @@ impl Client {
pub async fn get_multiplexed_tokio_connection(
&self,
) -> RedisResult<crate::aio::MultiplexedConnection> {
self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>()
self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(None)
.await
}

Expand All @@ -151,7 +151,7 @@ impl Client {
pub async fn get_multiplexed_async_std_connection(
&self,
) -> RedisResult<crate::aio::MultiplexedConnection> {
self.get_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>()
self.get_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(None)
.await
}

Expand All @@ -168,7 +168,7 @@ impl Client {
crate::aio::MultiplexedConnection,
impl std::future::Future<Output = ()>,
)> {
self.create_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>()
self.create_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(None)
.await
}

Expand All @@ -185,7 +185,7 @@ impl Client {
crate::aio::MultiplexedConnection,
impl std::future::Future<Output = ()>,
)> {
self.create_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>()
self.create_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(None)
.await
}

Expand Down Expand Up @@ -246,41 +246,46 @@ impl Client {
.await
}

async fn get_multiplexed_async_connection_inner<T>(
pub(crate) async fn get_multiplexed_async_connection_inner<T>(
&self,
socket_addr: Option<SocketAddr>,
) -> RedisResult<crate::aio::MultiplexedConnection>
where
T: crate::aio::RedisRuntime,
{
let (connection, driver) = self
.create_multiplexed_async_connection_inner::<T>()
.create_multiplexed_async_connection_inner::<T>(socket_addr)
.await?;
T::spawn(driver);
Ok(connection)
}

async fn create_multiplexed_async_connection_inner<T>(
&self,
socket_addr: Option<SocketAddr>,
) -> RedisResult<(
crate::aio::MultiplexedConnection,
impl std::future::Future<Output = ()>,
)>
where
T: crate::aio::RedisRuntime,
{
let con = self.get_simple_async_connection::<T>().await?;
let con = self.get_simple_async_connection::<T>(socket_addr).await?;
crate::aio::MultiplexedConnection::new(&self.connection_info.redis, con).await
}

async fn get_simple_async_connection<T>(
&self,
socket_addr: Option<SocketAddr>,
) -> RedisResult<Pin<Box<dyn crate::aio::AsyncStream + Send + Sync>>>
where
T: crate::aio::RedisRuntime,
{
Ok(crate::aio::connect_simple::<T>(&self.connection_info)
.await?
.boxed())
Ok(
crate::aio::connect_simple::<T>(&self.connection_info, socket_addr)
.await?
.boxed(),
)
}

#[cfg(feature = "connection-manager")]
Expand Down
52 changes: 35 additions & 17 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use std::{
iter::Iterator,
marker::Unpin,
mem,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{self, AtomicUsize},
Expand Down Expand Up @@ -489,9 +490,10 @@ where

/// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
/// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
/// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved.
pub(crate) async fn try_to_expand_initial_nodes(
initial_nodes: &[ConnectionInfo],
) -> Vec<String> {
) -> Vec<(String, Option<SocketAddr>)> {
stream::iter(initial_nodes)
.fold(
Vec::with_capacity(initial_nodes.len()),
Expand All @@ -505,19 +507,19 @@ where
} => (host, port),
crate::ConnectionAddr::Unix(_) => {
// We don't support multiple addresses for a Unix address. Store the initial node address and continue
acc.push(info.addr.to_string());
acc.push((info.addr.to_string(), None));
return acc;
}
};
match get_socket_addrs(host, *port).await {
Ok(socket_addrs) => {
for addr in socket_addrs {
acc.push(addr.to_string());
acc.push((info.addr.to_string(), Some(addr)));
}
}
Err(_) => {
// Couldn't find socket addresses, store the initial node address and continue
acc.push(info.addr.to_string());
acc.push((info.addr.to_string(), None));
}
};
acc
Expand All @@ -530,14 +532,15 @@ where
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
) -> RedisResult<ConnectionMap<C>> {
let initial_nodes: Vec<String> = Self::try_to_expand_initial_nodes(initial_nodes).await;
let initial_nodes: Vec<(String, Option<SocketAddr>)> =
Self::try_to_expand_initial_nodes(initial_nodes).await;
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|addr| {
.map(|node| {
let params = params.clone();
async move {
let result = connect_and_check(&addr, params).await;
let result = connect_and_check(&node.0, params, node.1).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Ok(conn) => Some((node.0, async { conn }.boxed().shared())),
Err(e) => {
trace!("Failed to connect to initial node: {:?}", e);
None
Expand Down Expand Up @@ -952,7 +955,7 @@ where

let addr_conn_option = match conn {
Some((addr, Some(conn))) => Some((addr, conn.await)),
Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone())
Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone(), None)
.await
.ok()
.map(|conn| (addr, conn)),
Expand Down Expand Up @@ -1119,10 +1122,10 @@ where
let mut conn = conn.await;
match check_connection(&mut conn, params.connection_timeout.into()).await {
Ok(_) => Ok(conn),
Err(_) => connect_and_check(addr, params.clone()).await,
Err(_) => connect_and_check(addr, params.clone(), None).await,
}
} else {
connect_and_check(addr, params.clone()).await
connect_and_check(addr, params.clone(), None).await
}
}
}
Expand Down Expand Up @@ -1294,13 +1297,16 @@ where
/// and obtaining a connection handle.
pub trait Connect: Sized {
/// Connect to a node, returning handle for command execution.
fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
fn connect<'a, T>(info: T, socket_addr: Option<SocketAddr>) -> RedisFuture<'a, Self>
where
T: IntoConnectionInfo + Send + 'a;
}

impl Connect for MultiplexedConnection {
fn connect<'a, T>(info: T) -> RedisFuture<'a, MultiplexedConnection>
fn connect<'a, T>(
info: T,
socket_addr: Option<SocketAddr>,
) -> RedisFuture<'a, MultiplexedConnection>
where
T: IntoConnectionInfo + Send + 'a,
{
Expand All @@ -1309,23 +1315,35 @@ impl Connect for MultiplexedConnection {
let client = crate::Client::open(connection_info)?;

#[cfg(feature = "tokio-comp")]
return client.get_multiplexed_tokio_connection().await;
return client
.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(socket_addr)
.await;

#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
return client.get_multiplexed_async_std_connection().await;
return client
.get_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(
socket_addr,
)
.await;
}
.boxed()
}
}

async fn connect_and_check<C>(node: &str, params: ClusterParams) -> RedisResult<C>
async fn connect_and_check<C>(
node: &str,
params: ClusterParams,
socket_addr: Option<SocketAddr>,
) -> RedisResult<C>
where
C: ConnectionLike + Connect + Send + 'static,
{
let read_from_replicas = params.read_from_replicas;
let connection_timeout = params.connection_timeout.into();
let info = get_connection_info(node, params)?;
let mut conn: C = C::connect(info).timeout(connection_timeout).await??;
let mut conn: C = C::connect(info, socket_addr)
.timeout(connection_timeout)
.await??;
check_connection(&mut conn, connection_timeout).await?;
if read_from_replicas {
// If READONLY is sent to primary nodes, it will have no effect
Expand Down
6 changes: 3 additions & 3 deletions redis/tests/support/mock_cluster.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use redis::cluster::{self, ClusterClient, ClusterClientBuilder};
use std::{
collections::HashMap,
net::SocketAddr,
sync::{Arc, RwLock},
time::Duration,
};

use redis::cluster::{self, ClusterClient, ClusterClientBuilder};

use {
once_cell::sync::Lazy,
redis::{IntoConnectionInfo, RedisResult, Value},
Expand All @@ -32,7 +32,7 @@ pub struct MockConnection {

#[cfg(feature = "cluster-async")]
impl cluster_async::Connect for MockConnection {
fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
fn connect<'a, T>(info: T, _socket_addr: Option<SocketAddr>) -> RedisFuture<'a, Self>
where
T: IntoConnectionInfo + Send + 'a,
{
Expand Down
5 changes: 3 additions & 2 deletions redis/tests/test_cluster_async.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![cfg(feature = "cluster-async")]
mod support;
use std::net::SocketAddr;
use std::sync::{
atomic::{self, AtomicI32, AtomicU16},
atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -241,12 +242,12 @@ struct ErrorConnection {
}

impl Connect for ErrorConnection {
fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
fn connect<'a, T>(info: T, _socket_addr: Option<SocketAddr>) -> RedisFuture<'a, Self>
where
T: IntoConnectionInfo + Send + 'a,
{
Box::pin(async {
let inner = MultiplexedConnection::connect(info).await?;
let inner = MultiplexedConnection::connect(info, None).await?;
Ok(ErrorConnection { inner })
})
}
Expand Down

0 comments on commit dcbbc21

Please sign in to comment.