Skip to content

Commit

Permalink
Merge pull request #27 from barshaul/dns_check
Browse files Browse the repository at this point in the history
Added a check for DNS updates
  • Loading branch information
nihohit authored Sep 14, 2023
2 parents 37a627c + 880ea1a commit 371e52c
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 99 deletions.
35 changes: 26 additions & 9 deletions redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use futures_util::{
future::FutureExt,
stream::{Stream, StreamExt},
};
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use tokio_util::codec::Decoder;
Expand Down Expand Up @@ -183,7 +183,7 @@ pub(crate) async fn connect<C>(
where
C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
{
let con = connect_simple::<C>(connection_info, socket_addr).await?;
let (con, _ip) = connect_simple::<C>(connection_info, socket_addr).await?;
Connection::new(&connection_info.redis, con).await
}

Expand Down Expand Up @@ -387,11 +387,20 @@ pub(crate) async fn get_socket_addrs(
pub(crate) async fn connect_simple<T: RedisRuntime>(
connection_info: &ConnectionInfo,
socket_addr: Option<SocketAddr>,
) -> RedisResult<T> {
) -> RedisResult<(T, Option<IpAddr>)> {
Ok(match connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(socket_addrs.map(<T>::connect_tcp)).await?.0
select_ok(socket_addrs.map(|socket_addr| {
Box::pin(async move {
Ok::<_, RedisError>((
<T>::connect_tcp(socket_addr).await?,
Some(socket_addr.ip()),
))
})
}))
.await?
.0
}

#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
Expand All @@ -401,12 +410,20 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
insecure,
} => {
if let Some(socket_addr) = socket_addr {
return <T>::connect_tcp_tls(host, socket_addr, insecure).await;
return Ok::<_, RedisError>((
<T>::connect_tcp_tls(host, socket_addr, insecure).await?,
Some(socket_addr.ip()),
));
}
let socket_addrs = get_socket_addrs(host, port).await?;
select_ok(
socket_addrs.map(|socket_addr| <T>::connect_tcp_tls(host, socket_addr, insecure)),
)
select_ok(socket_addrs.map(|socket_addr| {
Box::pin(async move {
Ok::<_, RedisError>((
<T>::connect_tcp_tls(host, socket_addr, insecure).await?,
Some(socket_addr.ip()),
))
})
}))
.await?
.0
}
Expand All @@ -420,7 +437,7 @@ pub(crate) async fn connect_simple<T: RedisRuntime>(
}

#[cfg(unix)]
ConnectionAddr::Unix(ref path) => <T>::connect_unix(path).await?,
ConnectionAddr::Unix(ref path) => (<T>::connect_unix(path).await?, None),

#[cfg(not(unix))]
ConnectionAddr::Unix(_) => {
Expand Down
57 changes: 45 additions & 12 deletions redis/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::time::Duration;

use std::net::IpAddr;
#[cfg(feature = "aio")]
use std::net::SocketAddr;
#[cfg(feature = "aio")]
Expand Down Expand Up @@ -74,7 +75,7 @@ impl Client {
/// Returns an async connection from the client.
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
pub async fn get_async_connection(&self) -> RedisResult<crate::aio::Connection> {
let con = match Runtime::locate() {
let (con, _ip) = match Runtime::locate() {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => {
self.get_simple_async_connection::<crate::aio::tokio::Tokio>(None)
Expand Down Expand Up @@ -131,6 +132,30 @@ impl Client {
}
}

/// For TCP connections: returns (async connection, Some(the direct IP address))
/// For Unix connections, returns (async connection, None)
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp")))
)]
pub async fn get_multiplexed_async_connection_and_ip(
&self,
) -> RedisResult<(crate::aio::MultiplexedConnection, Option<IpAddr>)> {
match Runtime::locate() {
#[cfg(feature = "tokio-comp")]
Runtime::Tokio => {
self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(None)
.await
}
#[cfg(feature = "async-std-comp")]
Runtime::AsyncStd => {
self.get_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(None)
.await
}
}
}

/// Returns an async multiplexed connection from the client.
///
/// A multiplexed connection can be cloned, allowing requests to be be sent concurrently
Expand All @@ -142,6 +167,7 @@ impl Client {
) -> RedisResult<crate::aio::MultiplexedConnection> {
self.get_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(None)
.await
.map(|conn_and_ip| conn_and_ip.0)
}

/// Returns an async multiplexed connection from the client.
Expand All @@ -155,6 +181,7 @@ impl Client {
) -> RedisResult<crate::aio::MultiplexedConnection> {
self.get_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(None)
.await
.map(|conn_and_ip| conn_and_ip.0)
}

/// Returns an async multiplexed connection from the client and a future which must be polled
Expand All @@ -172,6 +199,7 @@ impl Client {
)> {
self.create_multiplexed_async_connection_inner::<crate::aio::tokio::Tokio>(None)
.await
.map(|conn_res| (conn_res.0, conn_res.1))
}

/// Returns an async multiplexed connection from the client and a future which must be polled
Expand All @@ -189,6 +217,7 @@ impl Client {
)> {
self.create_multiplexed_async_connection_inner::<crate::aio::async_std::AsyncStd>(None)
.await
.map(|(conn, conn_future, _ip)| (conn, conn_future))
}

/// Returns an async [`ConnectionManager`][connection-manager] from the client.
Expand Down Expand Up @@ -251,15 +280,15 @@ impl Client {
pub(crate) async fn get_multiplexed_async_connection_inner<T>(
&self,
socket_addr: Option<SocketAddr>,
) -> RedisResult<crate::aio::MultiplexedConnection>
) -> RedisResult<(crate::aio::MultiplexedConnection, Option<IpAddr>)>
where
T: crate::aio::RedisRuntime,
{
let (connection, driver) = self
let (connection, driver, ip) = self
.create_multiplexed_async_connection_inner::<T>(socket_addr)
.await?;
T::spawn(driver);
Ok(connection)
Ok((connection, ip))
}

async fn create_multiplexed_async_connection_inner<T>(
Expand All @@ -268,26 +297,30 @@ impl Client {
) -> RedisResult<(
crate::aio::MultiplexedConnection,
impl std::future::Future<Output = ()>,
Option<IpAddr>,
)>
where
T: crate::aio::RedisRuntime,
{
let con = self.get_simple_async_connection::<T>(socket_addr).await?;
crate::aio::MultiplexedConnection::new(&self.connection_info.redis, con).await
let (con, ip) = self.get_simple_async_connection::<T>(socket_addr).await?;
crate::aio::MultiplexedConnection::new(&self.connection_info.redis, con)
.await
.map(|res| (res.0, res.1, ip))
}

async fn get_simple_async_connection<T>(
&self,
socket_addr: Option<SocketAddr>,
) -> RedisResult<Pin<Box<dyn crate::aio::AsyncStream + Send + Sync>>>
) -> RedisResult<(
Pin<Box<dyn crate::aio::AsyncStream + Send + Sync>>,
Option<IpAddr>,
)>
where
T: crate::aio::RedisRuntime,
{
Ok(
crate::aio::connect_simple::<T>(&self.connection_info, socket_addr)
.await?
.boxed(),
)
let (conn, ip) =
crate::aio::connect_simple::<T>(&self.connection_info, socket_addr).await?;
Ok((conn.boxed(), ip))
}

#[cfg(feature = "connection-manager")]
Expand Down
Loading

0 comments on commit 371e52c

Please sign in to comment.