diff --git a/nexus/catalog/src/lib.rs b/nexus/catalog/src/lib.rs index 94119ff6c1..8cfc5da413 100644 --- a/nexus/catalog/src/lib.rs +++ b/nexus/catalog/src/lib.rs @@ -39,12 +39,12 @@ async fn run_migrations(client: &mut Client) -> anyhow::Result<()> { } #[derive(Debug, Clone)] -pub struct CatalogConfig { - pub host: String, +pub struct CatalogConfig<'a> { + pub host: &'a str, pub port: u16, - pub user: String, - pub password: String, - pub database: String, + pub user: &'a str, + pub password: &'a str, + pub database: &'a str, } #[derive(Debug, Clone)] @@ -54,25 +54,15 @@ pub struct WorkflowDetails { pub destination_peer: pt::peerdb_peers::Peer, } -impl CatalogConfig { - pub fn new(host: String, port: u16, user: String, password: String, database: String) -> Self { - Self { - host, - port, - user, - password, - database, - } - } - +impl<'a> CatalogConfig<'a> { // convert catalog config to PostgresConfig pub fn to_postgres_config(&self) -> pt::peerdb_peers::PostgresConfig { PostgresConfig { - host: self.host.clone(), + host: self.host.to_string(), port: self.port as u32, - user: self.user.clone(), - password: self.password.clone(), - database: self.database.clone(), + user: self.user.to_string(), + password: self.password.to_string(), + database: self.database.to_string(), transaction_snapshot: "".to_string(), metadata_schema: Some("".to_string()), ssh_config: None, @@ -85,7 +75,7 @@ impl CatalogConfig { } impl Catalog { - pub async fn new(catalog_config: &CatalogConfig) -> anyhow::Result { + pub async fn new<'a>(catalog_config: &CatalogConfig<'a>) -> anyhow::Result { let pt_config = catalog_config.to_postgres_config(); let client = connect_postgres(&pt_config).await?; let executor = PostgresQueryExecutor::new(None, &pt_config).await?; @@ -100,8 +90,8 @@ impl Catalog { run_migrations(&mut self.pg).await } - pub fn get_executor(&self) -> Arc { - self.executor.clone() + pub fn get_executor(&self) -> &Arc { + &self.executor } pub async fn create_peer(&self, peer: &Peer) -> anyhow::Result { diff --git a/nexus/parser/src/lib.rs b/nexus/parser/src/lib.rs index 02c8cb5a27..f99dbe8751 100644 --- a/nexus/parser/src/lib.rs +++ b/nexus/parser/src/lib.rs @@ -15,6 +15,7 @@ use tokio::sync::Mutex; const DIALECT: PostgreSqlDialect = PostgreSqlDialect {}; +#[derive(Clone)] pub struct NexusQueryParser { catalog: Arc>, } diff --git a/nexus/peer-bigquery/src/lib.rs b/nexus/peer-bigquery/src/lib.rs index d90f2f6c1d..c338518f45 100644 --- a/nexus/peer-bigquery/src/lib.rs +++ b/nexus/peer-bigquery/src/lib.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::Duration}; +use std::time::Duration; use anyhow::Context; use cursor::BigQueryCursorManager; @@ -22,7 +22,7 @@ pub struct BigQueryQueryExecutor { peer_name: String, project_id: String, dataset_id: String, - peer_connections: Arc, + peer_connections: PeerConnectionTracker, client: Box, cursor_manager: BigQueryCursorManager, } @@ -51,7 +51,7 @@ impl BigQueryQueryExecutor { pub async fn new( peer_name: String, config: &BigqueryConfig, - peer_connections: Arc, + peer_connections: PeerConnectionTracker, ) -> anyhow::Result { let client = bq_client_from_config(config).await?; let client = Box::new(client); diff --git a/nexus/peer-connections/src/lib.rs b/nexus/peer-connections/src/lib.rs index 6637c0c3fa..e006f7041b 100644 --- a/nexus/peer-connections/src/lib.rs +++ b/nexus/peer-connections/src/lib.rs @@ -21,6 +21,7 @@ impl PeerConnections { } } +#[derive(Clone)] pub struct PeerConnectionTracker { conn_uuid: uuid::Uuid, peer_connections: Arc, diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index a6d9ef4f06..2720ca5c4e 100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -10,7 +10,7 @@ use bytes::{BufMut, BytesMut}; use catalog::{Catalog, CatalogConfig, WorkflowDetails}; use clap::Parser; use cursor::PeerCursors; -use dashmap::DashMap; +use dashmap::{mapref::entry::Entry as DashEntry, DashMap}; use flow_rs::grpc::{FlowGrpcClient, PeerValidationResult}; use peer_bigquery::BigQueryQueryExecutor; use peer_connections::{PeerConnectionTracker, PeerConnections}; @@ -64,24 +64,27 @@ impl AuthSource for FixedPasswordAuthSource { tracing::info!("login info: {:?}", login_info); // randomly generate a 4 byte salt - let salt = rand::thread_rng().gen::<[u8; 4]>().to_vec(); + let salt = rand::thread_rng().gen::<[u8; 4]>(); let password = &self.password; let hash_password = hash_md5_password( login_info.user().map(|s| s.as_str()).unwrap_or(""), password, - salt.as_ref(), + &salt, ); - Ok(Password::new(Some(salt), hash_password.as_bytes().to_vec())) + Ok(Password::new( + Some(salt.to_vec()), + hash_password.as_bytes().to_vec(), + )) } } pub struct NexusBackend { catalog: Arc>, - peer_connections: Arc, + peer_connections: PeerConnectionTracker, portal_store: Arc>, - query_parser: Arc, - peer_cursors: Arc>, - executors: Arc>>, + query_parser: NexusQueryParser, + peer_cursors: Mutex, + executors: DashMap>, flow_handler: Option>>, peerdb_fdw_mode: bool, } @@ -89,7 +92,7 @@ pub struct NexusBackend { impl NexusBackend { pub fn new( catalog: Arc>, - peer_connections: Arc, + peer_connections: PeerConnectionTracker, flow_handler: Option>>, peerdb_fdw_mode: bool, ) -> Self { @@ -98,9 +101,9 @@ impl NexusBackend { catalog, peer_connections, portal_store: Arc::new(MemPortalStore::new()), - query_parser: Arc::new(query_parser), - peer_cursors: Arc::new(Mutex::new(PeerCursors::new())), - executors: Arc::new(DashMap::new()), + query_parser, + peer_cursors: Mutex::new(PeerCursors::new()), + executors: DashMap::new(), flow_handler, peerdb_fdw_mode, } @@ -109,7 +112,7 @@ impl NexusBackend { // execute a statement on a peer async fn execute_statement<'a>( &self, - executor: Arc, + executor: &dyn QueryExecutor, stmt: &sqlparser::ast::Statement, peer_holder: Option>, ) -> PgWireResult>> { @@ -120,8 +123,6 @@ impl NexusBackend { )]), QueryOutput::Stream(rows) => { let schema = rows.schema(); - // todo: why is this a vector of response rather than a single response? - // can this be because of multiple statements? let res = sendable_stream_to_query_response(schema, rows)?; Ok(vec![res]) } @@ -822,11 +823,13 @@ impl NexusBackend { QueryAssociation::Catalog => { tracing::info!("handling catalog query: {}", stmt); let catalog = self.catalog.lock().await; - catalog.get_executor() + Arc::clone(catalog.get_executor()) } }; - let res = self.execute_statement(executor, &stmt, peer_holder).await; + let res = self + .execute_statement(executor.as_ref(), &stmt, peer_holder) + .await; // log the error if execution failed if let Err(err) = &res { tracing::error!("query execution failed: {:?}", err); @@ -845,7 +848,7 @@ impl NexusBackend { match peer { None => { let catalog = self.catalog.lock().await; - catalog.get_executor() + Arc::clone(catalog.get_executor()) } Some(peer) => self.get_peer_executor(peer).await.map_err(|err| { PgWireError::ApiError(Box::new(PgError::Internal { @@ -855,7 +858,8 @@ impl NexusBackend { } }; - self.execute_statement(executor, &stmt, peer_holder).await + self.execute_statement(executor.as_ref(), &stmt, peer_holder) + .await } NexusStatement::Empty => Ok(vec![Response::EmptyQuery]), @@ -908,40 +912,48 @@ impl NexusBackend { } async fn get_peer_executor(&self, peer: &Peer) -> anyhow::Result> { - if let Some(executor) = self.executors.get(&peer.name) { - return Ok(Arc::clone(executor.value())); - } - - let executor: Arc = match &peer.config { - Some(Config::BigqueryConfig(ref c)) => { - let executor = - BigQueryQueryExecutor::new(peer.name.clone(), c, self.peer_connections.clone()) + Ok(match self.executors.entry(peer.name.clone()) { + DashEntry::Occupied(entry) => Arc::clone(entry.get()), + DashEntry::Vacant(entry) => { + let executor: Arc = match &peer.config { + Some(Config::BigqueryConfig(ref c)) => { + let executor = BigQueryQueryExecutor::new( + peer.name.clone(), + c, + self.peer_connections.clone(), + ) .await?; - Arc::new(executor) - } - Some(Config::PostgresConfig(ref c)) => { - let peername = Some(peer.name.clone()); - let executor = peer_postgres::PostgresQueryExecutor::new(peername, c).await?; - Arc::new(executor) - } - Some(Config::SnowflakeConfig(ref c)) => { - let executor = peer_snowflake::SnowflakeQueryExecutor::new(c).await?; - Arc::new(executor) - } - _ => { - panic!("peer type not supported: {:?}", peer) - } - }; + Arc::new(executor) + } + Some(Config::PostgresConfig(ref c)) => { + let peername = Some(peer.name.clone()); + let executor = + peer_postgres::PostgresQueryExecutor::new(peername, c).await?; + Arc::new(executor) + } + Some(Config::SnowflakeConfig(ref c)) => { + let executor = peer_snowflake::SnowflakeQueryExecutor::new(c).await?; + Arc::new(executor) + } + _ => { + panic!("peer type not supported: {:?}", peer) + } + }; - self.executors - .insert(peer.name.clone(), Arc::clone(&executor)); - Ok(executor) + entry.insert(Arc::clone(&executor)); + executor + } + }) } } #[async_trait] impl SimpleQueryHandler for NexusBackend { - async fn do_query<'a, C>(&self, _client: &mut C, sql: &'a str) -> PgWireResult>> + async fn do_query<'a, C>( + &self, + _client: &mut C, + sql: &'a str, + ) -> PgWireResult>> where C: ClientInfo + Unpin + Send + Sync, { @@ -960,6 +972,7 @@ fn parameter_to_string(portal: &Portal, idx: usize) -> PgW "'{}'", portal .parameter::(idx, param_type)? + .map(|s| s.replace('\'', "''")) .as_deref() .unwrap_or("") )), @@ -1002,7 +1015,7 @@ impl ExtendedQueryHandler for NexusBackend { } fn query_parser(&self) -> Arc { - self.query_parser.clone() + Arc::new(self.query_parser.clone()) } async fn do_query<'a, C>( @@ -1131,42 +1144,6 @@ impl ExtendedQueryHandler for NexusBackend { } } -struct MakeNexusBackend { - catalog: Arc>, - peer_connections: Arc, - flow_handler: Option>>, - peerdb_fdw_mode: bool, -} - -impl MakeNexusBackend { - fn new( - catalog: Catalog, - peer_connections: Arc, - flow_handler: Option>>, - peerdb_fdw_mode: bool, - ) -> Self { - Self { - catalog: Arc::new(Mutex::new(catalog)), - peer_connections, - flow_handler, - peerdb_fdw_mode, - } - } -} - -impl MakeHandler for MakeNexusBackend { - type Handler = Arc; - - fn make(&self) -> Self::Handler { - Arc::new(NexusBackend::new( - self.catalog.clone(), - self.peer_connections.clone(), - self.flow_handler.clone(), - self.peerdb_fdw_mode, - )) - } -} - /// Arguments for the nexus server. #[derive(Parser, Debug)] struct Args { @@ -1238,11 +1215,11 @@ struct Args { // Get catalog config from args fn get_catalog_config(args: &Args) -> CatalogConfig { CatalogConfig { - host: args.catalog_host.clone(), + host: &args.catalog_host, port: args.catalog_port, - user: args.catalog_user.clone(), - password: args.catalog_password.clone(), - database: args.catalog_database.clone(), + user: &args.catalog_user, + password: &args.catalog_password, + database: &args.catalog_database, } } @@ -1253,7 +1230,7 @@ impl ServerParameterProvider for NexusServerParameterProvider { where C: ClientInfo, { - let mut params = HashMap::with_capacity(4); + let mut params = HashMap::with_capacity(5); params.insert("server_version".to_owned(), "14".to_owned()); params.insert("server_encoding".to_owned(), "UTF8".to_owned()); params.insert("client_encoding".to_owned(), "UTF8".to_owned()); @@ -1298,7 +1275,7 @@ fn setup_tracing(log_dir: &str) -> TracerGuards { } } -async fn run_migrations(config: &CatalogConfig) -> anyhow::Result<()> { +async fn run_migrations<'a>(config: &CatalogConfig<'a>) -> anyhow::Result<()> { // retry connecting to the catalog 3 times with 30 seconds delay // if it fails, return an error for _ in 0..3 { @@ -1328,10 +1305,10 @@ pub async fn main() -> anyhow::Result<()> { let args = Args::parse(); let _guard = setup_tracing(&args.log_dir); - let authenticator = Arc::new(MakeMd5PasswordAuthStartupHandler::new( + let authenticator = MakeMd5PasswordAuthStartupHandler::new( Arc::new(FixedPasswordAuthSource::new(args.peerdb_password.clone())), Arc::new(NexusServerParameterProvider), - )); + ); let catalog_config = get_catalog_config(&args); run_migrations(&catalog_config).await?; @@ -1388,14 +1365,13 @@ pub async fn main() -> anyhow::Result<()> { let authenticator_ref = authenticator.make(); - let peerdb_fdw_mode = matches!(args.peerdb_fwd_mode.as_str(), "true"); - let processor = Arc::new(MakeNexusBackend::new( - catalog, - Arc::new(tracker), + let peerdb_fdw_mode = args.peerdb_fwd_mode == "true"; + let processor = Arc::new(NexusBackend::new( + Arc::new(Mutex::new(catalog)), + tracker, flow_handler.clone(), peerdb_fdw_mode, )); - let processor_ref = processor.make(); tokio::task::Builder::new() .name("tcp connection handler") .spawn(async move { @@ -1403,8 +1379,8 @@ pub async fn main() -> anyhow::Result<()> { socket, None, authenticator_ref, - processor_ref.clone(), - processor_ref, + processor.clone(), + processor, ) .await })?;