diff --git a/nexus/Cargo.lock b/nexus/Cargo.lock index 5f640f1dd4..aa945c245d 100644 --- a/nexus/Cargo.lock +++ b/nexus/Cargo.lock @@ -2249,9 +2249,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.22.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3770f56e1e8a608c6de40011b9a00c6b669c14d121024411701b4bc3b2a5be99" +checksum = "22fb7a8b4570b74080587c5f3e187553375d18e72a38c72ca7f70a065972c65d" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/nexus/Cargo.toml b/nexus/Cargo.toml index 7ee1674a4f..6d43dbc156 100644 --- a/nexus/Cargo.toml +++ b/nexus/Cargo.toml @@ -30,7 +30,7 @@ rust_decimal = { version = "1", default-features = false, features = [ ] } sqlparser = { git = "https://github.com/peerdb-io/sqlparser-rs.git", branch = "main" } tracing = "0.1" -pgwire = { version = "0.22", default-features = false, features = [ +pgwire = { version = "0.23", default-features = false, features = [ "scram", "server-api-ring", ] } diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index d4c7a9f7bf..f9943795e5 100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -22,16 +22,17 @@ use peerdb_parser::{NexusParsedStatement, NexusQueryParser, NexusStatement}; use pgwire::{ api::{ auth::{ - scram::{gen_salted_password, MakeSASLScramAuthStartupHandler}, + scram::{gen_salted_password, SASLScramAuthStartupHandler}, AuthSource, LoginInfo, Password, ServerParameterProvider, }, + copy::NoopCopyHandler, portal::Portal, query::{ExtendedQueryHandler, SimpleQueryHandler}, results::{ DescribePortalResponse, DescribeResponse, DescribeStatementResponse, Response, Tag, }, stmt::StoredStatement, - ClientInfo, MakeHandler, Type, + ClientInfo, PgWireHandlerFactory, Type, }, error::{ErrorInfo, PgWireError, PgWireResult}, tokio::process_socket, @@ -49,7 +50,7 @@ use tracing_subscriber::{fmt, prelude::*, EnvFilter}; mod cursor; -struct FixedPasswordAuthSource { +pub struct FixedPasswordAuthSource { password: String, } @@ -1141,6 +1142,34 @@ async fn run_migrations<'a>(config: &CatalogConfig<'a>) -> anyhow::Result<()> { Err(anyhow::anyhow!("Failed to connect to catalog")) } +pub struct Handlers { + authenticator: Arc>, + nexus: Arc, +} + +impl PgWireHandlerFactory for Handlers { + type StartupHandler = SASLScramAuthStartupHandler; + type SimpleQueryHandler = NexusBackend; + type ExtendedQueryHandler = NexusBackend; + type CopyHandler = NoopCopyHandler; + + fn simple_query_handler(&self) -> Arc { + self.nexus.clone() + } + + fn extended_query_handler(&self) -> Arc { + self.nexus.clone() + } + + fn startup_handler(&self) -> Arc { + self.authenticator.clone() + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } +} + #[tokio::main] pub async fn main() -> anyhow::Result<()> { dotenvy::dotenv().ok(); @@ -1148,10 +1177,10 @@ pub async fn main() -> anyhow::Result<()> { let args = Args::parse(); let _guard = setup_tracing(args.log_dir.as_ref().map(|s| &s[..])); - let authenticator = MakeSASLScramAuthStartupHandler::new( + let authenticator = Arc::new(SASLScramAuthStartupHandler::new( Arc::new(FixedPasswordAuthSource::new(args.peerdb_password.clone())), Arc::new(NexusServerParameterProvider), - ); + )); let catalog_config = get_catalog_config(&args); run_migrations(&catalog_config).await?; @@ -1184,7 +1213,7 @@ pub async fn main() -> anyhow::Result<()> { let conn_flow_handler = flow_handler.clone(); let conn_peer_conns = peer_conns.clone(); let peerdb_fdw_mode = args.peerdb_fwd_mode == "true"; - let authenticator_ref = authenticator.make(); + let authenticator = authenticator.clone(); let pg_config = catalog_config.to_postgres_config(); tokio::task::spawn(async move { @@ -1193,7 +1222,7 @@ pub async fn main() -> anyhow::Result<()> { let conn_uuid = uuid::Uuid::new_v4(); let tracker = PeerConnectionTracker::new(conn_uuid, conn_peer_conns); - let processor = Arc::new(NexusBackend::new( + let nexus = Arc::new(NexusBackend::new( Arc::new(catalog), tracker, conn_flow_handler, @@ -1202,9 +1231,7 @@ pub async fn main() -> anyhow::Result<()> { process_socket( socket, None, - authenticator_ref, - processor.clone(), - processor, + Arc::new(Handlers { nexus, authenticator }), ) .await }