diff --git a/refinery/Cargo.toml b/refinery/Cargo.toml index eba439d9..6f32d6aa 100644 --- a/refinery/Cargo.toml +++ b/refinery/Cargo.toml @@ -16,7 +16,7 @@ edition = "2018" default = [] rusqlite-bundled = ["refinery-core/rusqlite-bundled"] rusqlite = ["refinery-core/rusqlite"] -postgres = ["refinery-core/postgres"] +postgres = ["refinery-core/postgres", "refinery-core/postgres-openssl", "refinery-core/openssl"] mysql = ["refinery-core/mysql", "refinery-core/flate2"] tokio-postgres = ["refinery-core/tokio-postgres"] mysql_async = ["refinery-core/mysql_async"] diff --git a/refinery_cli/Cargo.toml b/refinery_cli/Cargo.toml index ae58a2ee..8712cd79 100644 --- a/refinery_cli/Cargo.toml +++ b/refinery_cli/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [features] default = ["mysql", "postgresql", "sqlite-bundled", "mssql"] -postgresql = ["refinery-core/postgres"] +postgresql = ["refinery-core/postgres", "refinery-core/postgres-openssl", "refinery-core/openssl"] mysql = ["refinery-core/mysql", "refinery-core/flate2"] sqlite = ["refinery-core/rusqlite"] sqlite-bundled = ["sqlite", "refinery-core/rusqlite-bundled"] diff --git a/refinery_core/Cargo.toml b/refinery_core/Cargo.toml index 11b46917..b3fe9cc1 100644 --- a/refinery_core/Cargo.toml +++ b/refinery_core/Cargo.toml @@ -32,6 +32,8 @@ walkdir = "2.3.1" # allow multiple versions of the same dependency if API is similar rusqlite = { version = ">= 0.23, <= 0.28", optional = true } postgres = { version = "0.19", optional = true } +postgres-openssl = { version = "0.5", optional = true } +openssl = { version = "0.10", optional = true } tokio-postgres = { version = "0.7", optional = true } mysql = { version = ">= 21.0.0, <= 23", optional = true, default-features = false} mysql_async = { version = ">= 0.28, <= 0.30", optional = true } diff --git a/refinery_core/src/config.rs b/refinery_core/src/config.rs index c1fac221..a5f4ccac 100644 --- a/refinery_core/src/config.rs +++ b/refinery_core/src/config.rs @@ -5,6 +5,7 @@ use std::convert::TryFrom; use std::fs; use std::path::{Path, PathBuf}; use std::str::FromStr; +use std::{borrow::Cow, collections::HashMap}; use url::Url; // refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros @@ -34,6 +35,7 @@ impl Config { db_user: None, db_pass: None, db_name: None, + use_tls: None, #[cfg(feature = "tiberius-config")] trust_cert: false, }, @@ -138,6 +140,10 @@ impl Config { self.main.db_port.as_deref() } + pub fn use_tls(&self) -> Option { + self.main.use_tls + } + pub fn set_db_user(self, db_user: &str) -> Config { Config { main: Main { @@ -202,13 +208,12 @@ impl TryFrom for Config { } }; + let query_params = url + .query_pairs() + .collect::, Cow<'_, str>>>(); + cfg_if::cfg_if! { if #[cfg(feature = "tiberius-config")] { - use std::{borrow::Cow, collections::HashMap}; - let query_params = url - .query_pairs() - .collect::, Cow<'_, str>>>(); - let trust_cert = query_params. get("trust_cert") .unwrap_or(&Cow::Borrowed("false")) @@ -222,6 +227,21 @@ impl TryFrom for Config { } } + let use_tls = match query_params + .get("sslmode") + .unwrap_or(&Cow::Borrowed("disable")) + { + &Cow::Borrowed("disable") => Ok(false), + &Cow::Borrowed("require") => Ok(true), + _ => Err(()), + } + .map_err(|_| { + Error::new( + Kind::ConfigError("Invalid sslmode value, please use disable/require".into()), + None, + ) + })?; + Ok(Self { main: Main { db_type, @@ -237,6 +257,7 @@ impl TryFrom for Config { db_user: Some(url.username().to_string()), db_pass: url.password().map(|r| r.to_string()), db_name: Some(url.path().trim_start_matches('/').to_string()), + use_tls: Some(use_tls), #[cfg(feature = "tiberius-config")] trust_cert, }, @@ -268,6 +289,7 @@ struct Main { db_user: Option, db_pass: Option, db_name: Option, + use_tls: Option, #[cfg(feature = "tiberius-config")] #[serde(default)] trust_cert: bool, @@ -451,6 +473,24 @@ mod tests { ); } + #[test] + fn build_no_tls_conn_from_str() { + let config = + Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable") + .unwrap(); + assert!(config.use_tls().is_some()); + assert!(!config.use_tls().unwrap()); + } + + #[test] + fn build_tls_conn_from_str() { + let config = + Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require") + .unwrap(); + assert!(config.use_tls().is_some()); + assert!(config.use_tls().unwrap()); + } + #[test] fn builds_db_env_var_failure() { std::env::set_var("DATABASE_URL", "this_is_not_a_url"); diff --git a/refinery_core/src/drivers/config.rs b/refinery_core/src/drivers/config.rs index 92a00582..86675489 100644 --- a/refinery_core/src/drivers/config.rs +++ b/refinery_core/src/drivers/config.rs @@ -80,7 +80,16 @@ macro_rules! with_connection { cfg_if::cfg_if! { if #[cfg(feature = "postgres")] { let path = build_db_url("postgresql", &$config); - let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?; + + let conn; + if $config.use_tls().is_some() && $config.use_tls().unwrap() { + let builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap(); + let connector = postgres_openssl::MakeTlsConnector::new(builder.build()); + conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?; + } else { + conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?; + } + $op(conn) } else { panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");