Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the initial nodes expander #30

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,11 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
ref host,
port,
insecure,
socket_addr,
} => {
if let Some(socket_addr) = socket_addr {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor:

let socket_addrs = if let Some(socket_addr) = socket_addr {
  socket_addr 
 } else {
   get_socket_addrs(host, port).await? 
};
select_ok(
    socket_addrs.map(|socket_addr| <T>::connect_tcp_tls(host, socket_addr, insecure)),
)
.await?
.0

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_socket_addrs return an iterator

return <T>::connect_tcp_tls(host, socket_addr, insecure).await;
}
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(
socket_addrs.map(|socket_addr| <T>::connect_tcp_tls(host, socket_addr, insecure)),
Expand Down
19 changes: 14 additions & 5 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
//! ```
use std::cell::RefCell;
use std::iter::Iterator;
use std::net::SocketAddr;
use std::str::FromStr;
use std::thread;
use std::time::Duration;
Expand Down Expand Up @@ -309,7 +310,7 @@ where
}

fn connect(&self, node: &str) -> RedisResult<C> {
let info = get_connection_info(node, self.cluster_params.clone())?;
let info = get_connection_info(node, self.cluster_params.clone(), None)?;

let mut conn = C::connect(info, Some(self.cluster_params.connection_timeout))?;
if self.read_from_replicas {
Expand Down Expand Up @@ -708,6 +709,7 @@ fn get_random_connection<C: ConnectionLike + Connect + Sized>(
pub(crate) fn get_connection_info(
node: &str,
cluster_params: ClusterParams,
socket_addr: Option<SocketAddr>,
) -> RedisResult<ConnectionInfo> {
let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string");

Expand All @@ -721,7 +723,7 @@ pub(crate) fn get_connection_info(
.ok_or_else(invalid_error)?;

Ok(ConnectionInfo {
addr: get_connection_addr(host.to_string(), port, cluster_params.tls),
addr: get_connection_addr(host.to_string(), port, cluster_params.tls, socket_addr),
redis: RedisConnectionInfo {
password: cluster_params.password,
username: cluster_params.username,
Expand All @@ -730,17 +732,24 @@ pub(crate) fn get_connection_info(
})
}

pub(crate) fn get_connection_addr(host: String, port: u16, tls: Option<TlsMode>) -> ConnectionAddr {
pub(crate) fn get_connection_addr(
host: String,
port: u16,
tls: Option<TlsMode>,
socket_addr: Option<SocketAddr>,
) -> ConnectionAddr {
match tls {
Some(TlsMode::Secure) => ConnectionAddr::TcpTls {
host,
port,
insecure: false,
socket_addr,
},
Some(TlsMode::Insecure) => ConnectionAddr::TcpTls {
host,
port,
insecure: true,
socket_addr,
},
_ => ConnectionAddr::Tcp(host, port),
}
Expand Down Expand Up @@ -778,13 +787,13 @@ mod tests {
];

for (input, expected) in cases {
let res = get_connection_info(input, ClusterParams::default());
let res = get_connection_info(input, ClusterParams::default(), None);
assert_eq!(res.unwrap().addr, expected);
}

let cases = vec![":0", "[]:6379"];
for input in cases {
let res = get_connection_info(input, ClusterParams::default());
let res = get_connection_info(input, ClusterParams::default(), None);
assert_eq!(
res.err(),
Some(RedisError::from((
Expand Down
34 changes: 21 additions & 13 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use std::{
iter::Iterator,
marker::Unpin,
mem,
net::SocketAddr,
pin::Pin,
sync::{
atomic::{self, AtomicUsize},
Expand Down Expand Up @@ -489,9 +490,10 @@ where

/// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
/// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
/// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved.
pub(crate) async fn try_to_expand_initial_nodes(
initial_nodes: &[ConnectionInfo],
) -> Vec<String> {
) -> Vec<(String, Option<SocketAddr>)> {
stream::iter(initial_nodes)
.fold(
Vec::with_capacity(initial_nodes.len()),
Expand All @@ -502,22 +504,23 @@ where
host,
port,
insecure: _,
socket_addr: _,
} => (host, port),
crate::ConnectionAddr::Unix(_) => {
// We don't support multiple addresses for a Unix address. Store the initial node address and continue
acc.push(info.addr.to_string());
acc.push((info.addr.to_string(), None));
return acc;
}
};
match get_socket_addrs(host, *port).await {
Ok(socket_addrs) => {
for addr in socket_addrs {
acc.push(addr.to_string());
acc.push((info.addr.to_string(), Some(addr)));
}
}
Err(_) => {
// Couldn't find socket addresses, store the initial node address and continue
acc.push(info.addr.to_string());
acc.push((info.addr.to_string(), None));
}
};
acc
Expand All @@ -530,14 +533,15 @@ where
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
) -> RedisResult<ConnectionMap<C>> {
let initial_nodes: Vec<String> = Self::try_to_expand_initial_nodes(initial_nodes).await;
let initial_nodes: Vec<(String, Option<SocketAddr>)> =
Self::try_to_expand_initial_nodes(initial_nodes).await;
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|addr| {
.map(|node| {
let params = params.clone();
async move {
let result = connect_and_check(&addr, params).await;
let result = connect_and_check(&node.0, params, node.1).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Ok(conn) => Some((node.0, async { conn }.boxed().shared())),
Err(e) => {
trace!("Failed to connect to initial node: {:?}", e);
None
Expand Down Expand Up @@ -952,7 +956,7 @@ where

let addr_conn_option = match conn {
Some((addr, Some(conn))) => Some((addr, conn.await)),
Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone())
Some((addr, None)) => connect_and_check(&addr, core.cluster_params.clone(), None)
.await
.ok()
.map(|conn| (addr, conn)),
Expand Down Expand Up @@ -1119,10 +1123,10 @@ where
let mut conn = conn.await;
match check_connection(&mut conn, params.connection_timeout.into()).await {
Ok(_) => Ok(conn),
Err(_) => connect_and_check(addr, params.clone()).await,
Err(_) => connect_and_check(addr, params.clone(), None).await,
}
} else {
connect_and_check(addr, params.clone()).await
connect_and_check(addr, params.clone(), None).await
}
}
}
Expand Down Expand Up @@ -1318,13 +1322,17 @@ impl Connect for MultiplexedConnection {
}
}

async fn connect_and_check<C>(node: &str, params: ClusterParams) -> RedisResult<C>
async fn connect_and_check<C>(
node: &str,
params: ClusterParams,
socket_addr: Option<SocketAddr>,
) -> RedisResult<C>
where
C: ConnectionLike + Connect + Send + 'static,
{
let read_from_replicas = params.read_from_replicas;
let connection_timeout = params.connection_timeout.into();
let info = get_connection_info(node, params)?;
let info = get_connection_info(node, params, socket_addr)?;
let mut conn: C = C::connect(info).timeout(connection_timeout).await??;
check_connection(&mut conn, connection_timeout).await?;
if read_from_replicas {
Expand Down
1 change: 1 addition & 0 deletions redis/src/cluster_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl ClusterClientBuilder {
host: _,
port: _,
insecure,
socket_addr: _,
} => Some(match insecure {
false => TlsMode::Secure,
true => TlsMode::Insecure,
Expand Down
2 changes: 1 addition & 1 deletion redis/src/cluster_topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ pub(crate) fn parse_slots(raw_slot_resp: &Value, tls: Option<TlsMode>) -> RedisR
} else {
return None;
};
Some(get_connection_addr(ip.into_owned(), port, tls).to_string())
Some(get_connection_addr(ip.into_owned(), port, tls, None).to_string())
} else {
None
}
Expand Down
10 changes: 9 additions & 1 deletion redis/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, Write};
use std::net::{self, TcpStream, ToSocketAddrs};
use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
use std::ops::DerefMut;
use std::path::PathBuf;
use std::str::{from_utf8, FromStr};
Expand Down Expand Up @@ -64,6 +64,10 @@ pub enum ConnectionAddr {
host: String,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to doc "This name should remain the DNS name, for the purpose of TLS hostname verification. IP addresses can be saved in socket_addr."

/// Port
port: u16,
/// Optional - An internet socket address for this node.
barshaul marked this conversation as resolved.
Show resolved Hide resolved
/// If the hostname is a DNS endpoint, the socket address will encompass the actual IP.
/// The hostname should be preserved in the connection information as a DNS name for the purpose of TLS hostname verification.
socket_addr: Option<SocketAddr>,
/// Disable hostname verification when connecting.
///
/// # Warning
Expand Down Expand Up @@ -212,6 +216,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
host,
port,
insecure: true,
socket_addr: None,
},
Some(_) => fail!((
ErrorKind::InvalidClientConfig,
Expand All @@ -221,6 +226,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
host,
port,
insecure: false,
socket_addr: None,
},
}
}
Expand Down Expand Up @@ -433,6 +439,7 @@ impl ActualConnection {
ref host,
port,
insecure,
..
} => {
let tls_connector = if insecure {
TlsConnector::builder()
Expand Down Expand Up @@ -492,6 +499,7 @@ impl ActualConnection {
ref host,
port,
insecure,
socket_addr: _,
} => {
let host: &str = host;
let config = create_rustls_config(insecure)?;
Expand Down
1 change: 1 addition & 0 deletions redis/tests/support/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl ClusterType {
host: "127.0.0.1".into(),
port,
insecure: true,
socket_addr: None,
},
}
}
Expand Down
2 changes: 2 additions & 0 deletions redis/tests/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl RedisServer {
host: "127.0.0.1".to_string(),
port: redis_port,
insecure: true,
socket_addr: None,
}
} else {
redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), redis_port)
Expand Down Expand Up @@ -199,6 +200,7 @@ impl RedisServer {
host: host.clone(),
port,
insecure: true,
socket_addr: None,
};

RedisServer {
Expand Down