diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index d57612dec3..8429e5aec1 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -26,7 +26,7 @@ pub(crate) use ssl_config::SslConfig; use crate::authentication::AuthenticatorProvider; use scylla_cql::frame::response::authenticate::Authenticate; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::convert::TryFrom; use std::io::ErrorKind; use std::net::{IpAddr, SocketAddr}; @@ -52,7 +52,7 @@ use crate::frame::{ request::{self, batch, execute, query, register, SerializableRequest}, response::{event::Event, result, NonErrorResponse, Response, ResponseOpcode}, server_event_type::EventType, - value::{BatchValues, ValueList}, + value::{BatchValues, BatchValuesIterator, ValueList}, FrameParams, SerializedRequest, }; use crate::query::Query; @@ -772,11 +772,13 @@ impl Connection { pub(crate) async fn batch_with_consistency( &self, - batch: &Batch, + init_batch: &Batch, values: impl BatchValues, consistency: Consistency, serial_consistency: Option, ) -> Result { + let batch = self.prepare_batch(init_batch, &values).await?; + let batch_frame = batch::Batch { statements: Cow::Borrowed(&batch.statements), values, @@ -820,6 +822,58 @@ impl Connection { } } + async fn prepare_batch<'b>( + &self, + init_batch: &'b Batch, + values: impl BatchValues, + ) -> Result, QueryError> { + let mut to_prepare = HashSet::<&str>::new(); + + { + let mut values_iter = values.batch_values_iter(); + for stmt in &init_batch.statements { + if let BatchStatement::Query(query) = stmt { + let value = values_iter.next_serialized().transpose()?; + if let Some(v) = value { + if v.len() > 0 { + to_prepare.insert(&query.contents); + } + } + } else { + values_iter.skip_next(); + } + } + } + + if to_prepare.is_empty() { + return Ok(Cow::Borrowed(init_batch)); + } + + let mut prepared_queries = HashMap::<&str, PreparedStatement>::new(); + + for query in &to_prepare { + let prepared = self.prepare(&Query::new(query.to_string())).await?; + prepared_queries.insert(query, prepared); + } + + let mut batch: Cow = Cow::Owned(Default::default()); + batch.to_mut().config = init_batch.config.clone(); + for stmt in &init_batch.statements { + match stmt { + BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str()) + { + Some(prepared) => batch.to_mut().append_statement(prepared.clone()), + None => batch.to_mut().append_statement(query.clone()), + }, + BatchStatement::PreparedStatement(prepared) => { + batch.to_mut().append_statement(prepared.clone()); + } + } + } + + Ok(batch) + } + pub(crate) async fn use_keyspace( &self, keyspace_name: &VerifiedKeyspaceName, diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 6025bf639f..939983cfc4 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -28,6 +28,8 @@ mod authenticate_test; mod cql_collections_test; #[cfg(test)] mod session_test; +#[cfg(test)] +mod silent_prepare_batch_test; #[cfg(test)] mod cql_types_test; diff --git a/scylla/src/transport/silent_prepare_batch_test.rs b/scylla/src/transport/silent_prepare_batch_test.rs new file mode 100644 index 0000000000..3a2ed83baa --- /dev/null +++ b/scylla/src/transport/silent_prepare_batch_test.rs @@ -0,0 +1,110 @@ +use crate::{ + batch::Batch, + prepared_statement::PreparedStatement, + test_utils::{create_new_session_builder, unique_keyspace_name}, + Session, +}; +use std::collections::BTreeSet; + +#[tokio::test] +async fn test_quietly_prepare_batch() { + let session = create_new_session_builder().build().await.unwrap(); + + let ks = unique_keyspace_name(); + session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session.use_keyspace(ks.clone(), false).await.unwrap(); + + session + .query( + "CREATE TABLE test_batch_table (a int, b int, primary key (a, b))", + (), + ) + .await + .unwrap(); + + let unprepared_insert_a_b: &str = "insert into test_batch_table (a, b) values (?, ?)"; + let unprepared_insert_a_7: &str = "insert into test_batch_table (a, b) values (?, 7)"; + let unprepared_insert_8_b: &str = "insert into test_batch_table (a, b) values (8, ?)"; + let unprepared_insert_1_2: &str = "insert into test_batch_table (a, b) values (1, 2)"; + let unprepared_insert_2_3: &str = "insert into test_batch_table (a, b) values (2, 3)"; + let unprepared_insert_3_4: &str = "insert into test_batch_table (a, b) values (3, 4)"; + let unprepared_insert_4_5: &str = "insert into test_batch_table (a, b) values (4, 5)"; + let prepared_insert_a_b: PreparedStatement = + session.prepare(unprepared_insert_a_b).await.unwrap(); + let prepared_insert_a_7: PreparedStatement = + session.prepare(unprepared_insert_a_7).await.unwrap(); + let prepared_insert_8_b: PreparedStatement = + session.prepare(unprepared_insert_8_b).await.unwrap(); + + { + let mut fully_prepared_batch: Batch = Default::default(); + fully_prepared_batch.append_statement(prepared_insert_a_b); + fully_prepared_batch.append_statement(prepared_insert_a_7.clone()); + fully_prepared_batch.append_statement(prepared_insert_8_b); + + session + .batch(&fully_prepared_batch, ((50, 60), (50,), (60,))) + .await + .unwrap(); + + assert_test_batch_table_rows_contain(&session, &[(50, 60), (50, 7), (8, 60)]).await; + } + + { + let mut unprepared_batch1: Batch = Default::default(); + unprepared_batch1.append_statement(unprepared_insert_1_2); + unprepared_batch1.append_statement(unprepared_insert_2_3); + unprepared_batch1.append_statement(unprepared_insert_3_4); + + session + .batch(&unprepared_batch1, ((), (), ())) + .await + .unwrap(); + assert_test_batch_table_rows_contain(&session, &[(1, 2), (2, 3), (3, 4)]).await; + } + + { + let mut unprepared_batch2: Batch = Default::default(); + unprepared_batch2.append_statement(unprepared_insert_a_b); + unprepared_batch2.append_statement(unprepared_insert_a_7); + unprepared_batch2.append_statement(unprepared_insert_8_b); + + session + .batch(&unprepared_batch2, ((12, 22), (12,), (22,))) + .await + .unwrap(); + assert_test_batch_table_rows_contain(&session, &[(12, 22), (12, 7), (8, 22)]).await; + } + + { + let mut partially_prepared_batch: Batch = Default::default(); + partially_prepared_batch.append_statement(unprepared_insert_a_b); + partially_prepared_batch.append_statement(prepared_insert_a_7); + partially_prepared_batch.append_statement(unprepared_insert_4_5); + + session + .batch(&partially_prepared_batch, ((33, 43), (33,), ())) + .await + .unwrap(); + assert_test_batch_table_rows_contain(&session, &[(33, 43), (33, 7), (4, 5)]).await; + } +} + +async fn assert_test_batch_table_rows_contain(sess: &Session, expected_rows: &[(i32, i32)]) { + let selected_rows: BTreeSet<(i32, i32)> = sess + .query("SELECT a, b FROM test_batch_table", ()) + .await + .unwrap() + .rows_typed::<(i32, i32)>() + .unwrap() + .map(|r| r.unwrap()) + .collect(); + for expected_row in expected_rows.iter() { + if !selected_rows.contains(expected_row) { + panic!( + "Expected {:?} to contain row: {:?}, but they didnt", + selected_rows, expected_row + ); + } + } +}