From 41e6e19dfbc87c13f50b16ef6de9353bb3d19b73 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Sun, 10 Dec 2023 13:52:15 +0200 Subject: [PATCH] Add optional client_name property to RedisConnectionInfo, which will be used with 'CLIENT SETNAME' command during connection setup. --- redis/Cargo.toml | 1 + redis/src/aio/mod.rs | 15 +++++ redis/src/cluster.rs | 1 + redis/src/cluster_client.rs | 18 ++++++ redis/src/connection.rs | 23 ++++++++ redis/src/sentinel.rs | 2 + redis/tests/support/cluster.rs | 4 ++ redis/tests/support/mod.rs | 97 +++++++++++++++++-------------- redis/tests/support/util.rs | 13 +++++ redis/tests/test_async.rs | 33 +++++++++++ redis/tests/test_basic.rs | 23 ++++++++ redis/tests/test_cluster.rs | 27 +++++++++ redis/tests/test_cluster_async.rs | 36 ++++++++++++ 13 files changed, 249 insertions(+), 44 deletions(-) diff --git a/redis/Cargo.toml b/redis/Cargo.toml index 11da8f562..06fef3a1e 100644 --- a/redis/Cargo.toml +++ b/redis/Cargo.toml @@ -137,6 +137,7 @@ tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] tempfile = "=3.6.0" once_cell = "1" anyhow = "1" +sscanf = "0.4.1" [[test]] name = "test_async" diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index b64f06d3b..24347c2fb 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -140,6 +140,21 @@ where } } + if let Some(client_name) = &connection_info.client_name { + match cmd("CLIENT") + .arg("SETNAME") + .arg(client_name) + .query_async(con) + .await + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + // result is ignored, as per the command's instructions. // https://redis.io/commands/client-setinfo/ let _: RedisResult<()> = crate::connection::client_set_info_pipeline() diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 9c269f83b..4dfc56294 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -910,6 +910,7 @@ pub(crate) fn get_connection_info( redis: RedisConnectionInfo { password: cluster_params.password, username: cluster_params.username, + client_name: cluster_params.client_name, use_resp3: cluster_params.use_resp3, ..Default::default() }, diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index 998c63476..a879811c3 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -33,6 +33,7 @@ struct BuilderParams { retries_configuration: RetryParams, connection_timeout: Option, topology_checks_interval: Option, + client_name: Option, use_resp3: bool, } @@ -87,6 +88,7 @@ pub(crate) struct ClusterParams { pub(crate) connection_timeout: Duration, pub(crate) topology_checks_interval: Option, pub(crate) tls_params: Option, + pub(crate) client_name: Option, pub(crate) use_resp3: bool, } @@ -111,6 +113,7 @@ impl ClusterParams { connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX), topology_checks_interval: value.topology_checks_interval, tls_params, + client_name: value.client_name, use_resp3: value.use_resp3, }) } @@ -212,6 +215,15 @@ impl ClusterClientBuilder { ))); } + if node.redis.client_name.is_some() + && node.redis.client_name != cluster_params.client_name + { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different client_name among initial nodes.", + ))); + } + nodes.push(node); } @@ -221,6 +233,12 @@ impl ClusterClientBuilder { }) } + /// Sets client name for the new ClusterClient. + pub fn client_name(mut self, client_name: String) -> ClusterClientBuilder { + self.builder_params.client_name = Some(client_name); + self + } + /// Sets password for the new ClusterClient. pub fn password(mut self, password: String) -> ClusterClientBuilder { self.builder_params.password = Some(password); diff --git a/redis/src/connection.rs b/redis/src/connection.rs index 92e249dfa..f3ac85663 100644 --- a/redis/src/connection.rs +++ b/redis/src/connection.rs @@ -227,6 +227,8 @@ pub struct RedisConnectionInfo { pub password: Option, /// Use RESP 3 mode, Redis 6 or newer is required. pub use_resp3: bool, + /// Optionally a pass a client name that should be used for connection + pub client_name: Option, } impl FromStr for ConnectionInfo { @@ -387,6 +389,7 @@ fn url_to_tcp_connection_info(url: url::Url) -> RedisResult { Some(v) => v == "true", _ => false, }, + client_name: None, }, }) } @@ -413,6 +416,7 @@ fn url_to_unix_connection_info(url: url::Url) -> RedisResult { Some(v) => v == "true", _ => false, }, + client_name: None, }, }) } @@ -979,6 +983,20 @@ fn setup_connection( } } + if connection_info.client_name.is_some() { + match cmd("CLIENT") + .arg("SETNAME") + .arg(connection_info.client_name.as_ref().unwrap()) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + // result is ignored, as per the command's instructions. // https://redis.io/commands/client-setinfo/ let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv); @@ -1708,6 +1726,7 @@ mod tests { username: Some("%johndoe%".to_string()), password: Some("#@<>$".to_string()), use_resp3: false, + client_name: None, }, }, ), @@ -1775,6 +1794,7 @@ mod tests { username: None, password: None, use_resp3: false, + client_name: None, }, }, ), @@ -1787,6 +1807,7 @@ mod tests { username: None, password: None, use_resp3: false, + client_name: None, }, }, ), @@ -1802,6 +1823,7 @@ mod tests { username: Some("%johndoe%".to_string()), password: Some("#@<>$".to_string()), use_resp3: false, + client_name: None, }, }, ), @@ -1817,6 +1839,7 @@ mod tests { username: Some("%johndoe%".to_string()), password: Some("&?= *+".to_string()), use_resp3: false, + client_name: None, }, }, ), diff --git a/redis/src/sentinel.rs b/redis/src/sentinel.rs index 9a32ac889..2045894fb 100644 --- a/redis/src/sentinel.rs +++ b/redis/src/sentinel.rs @@ -59,6 +59,7 @@ //! username: Some(String::from("foo")), //! password: Some(String::from("bar")), //! use_resp3: false, +//! client_name: None //! }), //! }), //! ) @@ -95,6 +96,7 @@ //! username: Some(String::from("user")), //! password: Some(String::from("pass")), //! use_resp3: false, +//! client_name: None //! }), //! }), //! redis::sentinel::SentinelServerType::Master, diff --git a/redis/tests/support/cluster.rs b/redis/tests/support/cluster.rs index ab76a8afd..4156b802a 100644 --- a/redis/tests/support/cluster.rs +++ b/redis/tests/support/cluster.rs @@ -74,6 +74,10 @@ impl RedisCluster { "world" } + pub fn client_name() -> &'static str { + "test_cluster_client" + } + pub fn new(nodes: u16, replicas: u16) -> RedisCluster { RedisCluster::with_modules(nodes, replicas, &[], false) } diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 446320721..a5faf7122 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -56,6 +56,7 @@ mod cluster; mod mock_cluster; mod util; +pub use self::util::*; #[cfg(any(feature = "cluster", feature = "cluster-async"))] pub use self::cluster::*; @@ -349,28 +350,7 @@ impl TestContext { Self::with_modules(&[], true) } - pub fn with_tls(tls_files: TlsFilePaths, mtls_enabled: bool) -> TestContext { - let redis_port = get_random_available_port(); - let addr = RedisServer::get_addr(redis_port); - - let server = RedisServer::new_with_addr_tls_modules_and_spawner( - addr, - None, - Some(tls_files), - mtls_enabled, - &[], - |cmd| { - cmd.spawn() - .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) - }, - ); - - #[cfg(feature = "tls-rustls")] - let client = - build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); - #[cfg(not(feature = "tls-rustls"))] - let client = redis::Client::open(server.connection_info()).unwrap(); - + fn connect_with_retries(client: &redis::Client) { let mut con; let millisecond = Duration::from_millis(1); @@ -395,6 +375,31 @@ impl TestContext { } } redis::cmd("FLUSHDB").execute(&mut con); + } + + pub fn with_tls(tls_files: TlsFilePaths, mtls_enabled: bool) -> TestContext { + let redis_port = get_random_available_port(); + let addr: ConnectionAddr = RedisServer::get_addr(redis_port); + + let server = RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + Some(tls_files), + mtls_enabled, + &[], + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + Self::connect_with_retries(&client); TestContext { server, @@ -412,30 +417,34 @@ impl TestContext { #[cfg(not(feature = "tls-rustls"))] let client = redis::Client::open(server.connection_info()).unwrap(); - let mut con; + Self::connect_with_retries(&client); - let millisecond = Duration::from_millis(1); - let mut retries = 0; - loop { - match client.get_connection() { - Err(err) => { - if err.is_connection_refusal() { - sleep(millisecond); - retries += 1; - if retries > 100000 { - panic!("Tried to connect too many times, last error: {err}"); - } - } else { - panic!("Could not connect: {err}"); - } - } - Ok(x) => { - con = x; - break; - } - } + TestContext { + server, + client, + use_resp3: use_resp3(), } - redis::cmd("FLUSHDB").execute(&mut con); + } + + pub fn with_client_name(clientname: &str) -> TestContext { + let server = RedisServer::with_modules(&[], false); + let con_info = redis::ConnectionInfo { + addr: server.client_addr().clone(), + redis: redis::RedisConnectionInfo { + db: Default::default(), + username: None, + password: None, + use_resp3: Default::default(), + client_name: Some(clientname.to_string()), + }, + }; + + #[cfg(feature = "tls-rustls")] + let client = build_single_client(con_info, &server.tls_paths, false).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(con_info).unwrap(); + + Self::connect_with_retries(&client); TestContext { server, diff --git a/redis/tests/support/util.rs b/redis/tests/support/util.rs index fb0d020e6..8026b83fb 100644 --- a/redis/tests/support/util.rs +++ b/redis/tests/support/util.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + #[macro_export] macro_rules! assert_args { ($value:expr, $($args:expr),+) => { @@ -8,3 +10,14 @@ macro_rules! assert_args { assert_eq!(strings, vec![$($args),+]); } } + +pub fn parse_client_info(client_info: &str) -> HashMap { + let mut res = HashMap::new(); + + for line in client_info.split(' ') { + let this_attr: Vec<&str> = line.split('=').collect(); + res.insert(this_attr[0].to_string(), this_attr[1].to_string()); + } + + res +} diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index d9ca2b764..ff80d92fb 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -472,6 +472,7 @@ async fn invalid_password_issue_343() { username: None, password: Some("asdcasc".to_string()), use_resp3: false, + client_name: None, }, }; let client = redis::Client::open(coninfo).unwrap(); @@ -781,3 +782,35 @@ mod mtls_test { } } } + +#[test] +fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + use redis::RedisError; + let ctx = TestContext::with_client_name(CLIENT_NAME); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + + let client_info: String = redis::cmd("CLIENT") + .arg("INFO") + .query_async(&mut con) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} diff --git a/redis/tests/test_basic.rs b/redis/tests/test_basic.rs index 48da01698..78e50c14a 100644 --- a/redis/tests/test_basic.rs +++ b/redis/tests/test_basic.rs @@ -1448,3 +1448,26 @@ fn test_blocking_sorted_set_api() { ); } } + +#[test] +fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + + let ctx = TestContext::with_client_name(CLIENT_NAME); + let mut con = ctx.connection(); + + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); +} diff --git a/redis/tests/test_cluster.rs b/redis/tests/test_cluster.rs index f221b183a..43fb2519c 100644 --- a/redis/tests/test_cluster.rs +++ b/redis/tests/test_cluster.rs @@ -887,6 +887,33 @@ fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2 assert_eq!(result, expected_result); } +#[test] +fn test_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + let mut con = cluster.connection(); + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); +} + #[cfg(feature = "tls-rustls")] mod mtls_test { use super::*; diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 982ed2220..163987aa7 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -1978,6 +1978,42 @@ fn test_async_cluster_periodic_checks_update_topology_after_failover() { .unwrap(); } +#[test] +fn test_async_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection().await; + let client_info: String = cmd("CLIENT") + .arg("INFO") + .query_async(&mut connection) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + #[cfg(feature = "tls-rustls")] mod mtls_test { use crate::support::mtls_test::create_cluster_client_from_cluster;