diff --git a/spanner/src/session.rs b/spanner/src/session.rs index 79e674af..abce6e4e 100644 --- a/spanner/src/session.rs +++ b/spanner/src/session.rs @@ -9,7 +9,7 @@ use thiserror; use tokio::select; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, oneshot}; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tokio::time::{sleep, timeout}; use tokio_util::sync::CancellationToken; @@ -234,12 +234,23 @@ impl SessionPool { ) -> Result, Status> { let channel_num = conn_pool.num(); let creation_count_per_channel = min_opened / channel_num; + let remainder = min_opened % channel_num; let mut sessions = Vec::::new(); + let mut tasks = JoinSet::new(); for _ in 0..channel_num { + // Ensure that we create the exact number of requested sessions by adding the remainder to the first channel. + let creation_count = if channel_num == 0 { + creation_count_per_channel + remainder + } else { + creation_count_per_channel + }; let next_client = conn_pool.conn().with_metadata(client_metadata(&database)); - let new_sessions = - batch_create_sessions(next_client, database.as_str(), creation_count_per_channel).await?; + let database = database.clone(); + tasks.spawn(async move { batch_create_sessions(next_client, &database, creation_count).await }); + } + while let Some(r) = tasks.join_next().await { + let new_sessions = r.map_err(|e| Status::from_error(e.into()))??; sessions.extend(new_sessions); } tracing::debug!("initial session created count = {}", sessions.len()); @@ -493,17 +504,23 @@ impl SessionManager { cancel: CancellationToken, ) -> JoinHandle<()> { tokio::spawn(async move { + let mut tasks = JoinSet::default(); loop { - let session_count: usize = select! { + select! { + biased; + _ = cancel.cancelled() => break, + Some(Ok((session_count, result))) = tasks.join_next(), if !tasks.is_empty() => { + session_pool.inner.write().replenish(session_count, result); + } session_count = rx.recv() => match session_count { - Some(session_count) => session_count, + Some(session_count) => { + let client = conn_pool.conn().with_metadata(client_metadata(&database)); + let database = database.clone(); + tasks.spawn(async move { (session_count, batch_create_sessions(client, &database, session_count).await) }); + }, None => continue }, - _ = cancel.cancelled() => break - }; - let client = conn_pool.conn().with_metadata(client_metadata(&database)); - let result = batch_create_sessions(client, database.as_str(), session_count).await; - session_pool.inner.write().replenish(session_count, result); + } } tracing::trace!("shutdown session creation task."); })