Skip to content

Commit

Permalink
feat: add configuration for tls watch option (#3395)
Browse files Browse the repository at this point in the history
* feat: add configuration for tls watch option

* test: sleep longer to ensure async task run

* test: update config api integration test

* refactor: rename function
  • Loading branch information
sunng87 authored Mar 1, 2024
1 parent c1a3706 commit d4a54a0
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 8 deletions.
2 changes: 2 additions & 0 deletions config/frontend.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ runtime_size = 2
mode = "disable"
cert_path = ""
key_path = ""
watch = false

# PostgresSQL server options, see `standalone.example.toml`.
[postgres]
Expand All @@ -43,6 +44,7 @@ runtime_size = 2
mode = "disable"
cert_path = ""
key_path = ""
watch = false

# OpenTSDB protocol options, see `standalone.example.toml`.
[opentsdb]
Expand Down
6 changes: 5 additions & 1 deletion config/standalone.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ mode = "disable"
cert_path = ""
# Private key file path.
key_path = ""
# Watch for Certificate and key file change and auto reload
watch = false

# PostgresSQL server options.
[postgres]
Expand All @@ -62,6 +64,8 @@ mode = "disable"
cert_path = ""
# private key file path.
key_path = ""
# Watch for Certificate and key file change and auto reload
watch = false

# OpenTSDB protocol options.
[opentsdb]
Expand Down Expand Up @@ -118,7 +122,7 @@ sync_period = "1000ms"
# Number of topics to be created upon start.
# num_topics = 64
# Topic selector type.
# Available selector types:
# Available selector types:
# - "round_robin" (default)
# selector_type = "round_robin"
# The prefix of topic name.
Expand Down
7 changes: 4 additions & 3 deletions src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::ServerSqlQueryHandlerAdapter;
use servers::server::{Server, ServerHandlers};
use servers::tls::{watch_tls_config, ReloadableTlsServerConfig};
use servers::tls::{maybe_watch_tls_config, ReloadableTlsServerConfig};
use snafu::ResultExt;

use crate::error::{self, Result, StartServerSnafu};
Expand Down Expand Up @@ -199,7 +199,8 @@ where
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);

watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
// will not watch if watch is disabled in tls option
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;

let mysql_io_runtime = Arc::new(
RuntimeBuilder::default()
Expand Down Expand Up @@ -232,7 +233,7 @@ where
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);

watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;

let pg_io_runtime = Arc::new(
RuntimeBuilder::default()
Expand Down
42 changes: 38 additions & 4 deletions src/servers/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ pub struct TlsOption {
pub cert_path: String,
#[serde(default)]
pub key_path: String,
#[serde(default)]
pub watch: bool,
}

impl TlsOption {
Expand Down Expand Up @@ -138,6 +140,10 @@ impl TlsOption {
pub fn key_path(&self) -> &Path {
Path::new(&self.key_path)
}

pub fn watch_enabled(&self) -> bool {
self.mode != TlsMode::Disable && self.watch
}
}

/// A mutable container for TLS server config
Expand Down Expand Up @@ -186,8 +192,8 @@ impl ReloadableTlsServerConfig {
}
}

pub fn watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
if tls_server_config.get_tls_option().mode == TlsMode::Disable {
pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
if !tls_server_config.get_tls_option().watch_enabled() {
return Ok(());
}

Expand Down Expand Up @@ -250,6 +256,7 @@ mod tests {
mode: Disable,
cert_path: "/path/to/cert_path".to_string(),
key_path: "/path/to/key_path".to_string(),
watch: false
},
TlsOption::new(
Some(Disable),
Expand All @@ -274,6 +281,7 @@ mod tests {
assert!(matches!(t.mode, TlsMode::Disable));
assert!(t.key_path.is_empty());
assert!(t.cert_path.is_empty());
assert!(!t.watch_enabled());

let setup = t.setup();
let setup = setup.unwrap();
Expand All @@ -297,6 +305,7 @@ mod tests {
assert!(matches!(t.mode, TlsMode::Prefer));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
assert!(!t.watch_enabled());
}

#[test]
Expand All @@ -316,6 +325,7 @@ mod tests {
assert!(matches!(t.mode, TlsMode::Require));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
assert!(!t.watch_enabled());
}

#[test]
Expand All @@ -335,6 +345,7 @@ mod tests {
assert!(matches!(t.mode, TlsMode::VerifyCa));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
assert!(!t.watch_enabled());
}

#[test]
Expand All @@ -354,6 +365,28 @@ mod tests {
assert!(matches!(t.mode, TlsMode::VerifyFull));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
assert!(!t.watch_enabled());
}

#[test]
fn test_tls_option_watch_enabled() {
let s = r#"
{
"mode": "verify_full",
"cert_path": "/some_dir/some.crt",
"key_path": "/some_dir/some.key",
"watch": true
}
"#;

let t: TlsOption = serde_json::from_str(s).unwrap();

assert!(t.should_force_tls());

assert!(matches!(t.mode, TlsMode::VerifyFull));
assert!(!t.key_path.is_empty());
assert!(!t.cert_path.is_empty());
assert!(t.watch_enabled());
}

#[test]
Expand All @@ -377,12 +410,13 @@ mod tests {
.into_os_string()
.into_string()
.expect("failed to convert path to string"),
watch: true,
};

let server_config = Arc::new(
ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
);
watch_tls_config(server_config.clone()).expect("failed to watch server config");
maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");

assert_eq!(0, server_config.get_version());
assert!(server_config.get_server_config().is_some());
Expand All @@ -391,7 +425,7 @@ mod tests {
.expect("failed to copy key to tmpdir");

// waiting for async load
std::thread::sleep(std::time::Duration::from_millis(100));
std::thread::sleep(std::time::Duration::from_millis(300));
assert!(server_config.get_version() > 1);
assert!(server_config.get_server_config().is_some());
}
Expand Down
3 changes: 3 additions & 0 deletions src/servers/tests/mysql/mysql_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server-rsa.key".to_owned(),
watch: false,
};

let client_tls = false;
Expand Down Expand Up @@ -292,6 +293,7 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server-pkcs8.key".to_owned(),
watch: false,
};

let client_tls = false;
Expand Down Expand Up @@ -592,6 +594,7 @@ async fn do_test_query_all_datatypes_with_secure_server(
"tests/ssl/server-rsa.key".to_owned()
}
},
watch: false,
};

do_test_query_all_datatypes(server_tls, client_tls).await
Expand Down
3 changes: 3 additions & 0 deletions src/servers/tests/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ async fn test_server_secure_require_client_plain() -> Result<()> {
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server-rsa.key".to_owned(),
watch: false,
};
let server_port = start_test_server(server_tls).await?;
let r = create_plain_connection(server_port, false).await;
Expand All @@ -288,6 +289,7 @@ async fn test_server_secure_require_client_plain_with_pkcs8_priv_key() -> Result
mode: servers::tls::TlsMode::Require,
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server-pkcs8.key".to_owned(),
watch: false,
};
let server_port = start_test_server(server_tls).await?;
let r = create_plain_connection(server_port, false).await;
Expand Down Expand Up @@ -520,6 +522,7 @@ async fn do_simple_query_with_secure_server(
"tests/ssl/server-rsa.key".to_owned()
}
},
watch: false,
};

do_simple_query(server_tls, client_tls).await
Expand Down
2 changes: 2 additions & 0 deletions tests-integration/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ runtime_size = 2
mode = "disable"
cert_path = ""
key_path = ""
watch = false
[frontend.postgres]
enable = true
Expand All @@ -695,6 +696,7 @@ runtime_size = 2
mode = "disable"
cert_path = ""
key_path = ""
watch = false
[frontend.opentsdb]
enable = true
Expand Down

0 comments on commit d4a54a0

Please sign in to comment.