diff --git a/scylla-cql/src/types/serialize/raw_batch.rs b/scylla-cql/src/types/serialize/raw_batch.rs index eb3252e08..6c755c5e8 100644 --- a/scylla-cql/src/types/serialize/raw_batch.rs +++ b/scylla-cql/src/types/serialize/raw_batch.rs @@ -52,11 +52,7 @@ pub trait RawBatchValuesIterator<'a> { where Self: Sized, { - let mut count = 0; - while self.skip_next().is_some() { - count += 1; - } - count + std::iter::from_fn(|| self.skip_next()).count() } } @@ -145,19 +141,26 @@ where { #[inline] fn serialize_next(&mut self, writer: &mut RowWriter) -> Option> { - let ctx = self.contexts.next()?; + // We do `unwrap_or` because we want the iterator length to be the same + // as the amount of values. Limiting to length of the amount of + // statements (contexts) causes the caller to not be able to correctly + // detect that amount of statements and values is different. + let ctx = self + .contexts + .next() + .unwrap_or(RowSerializationContext::empty()); self.batch_values_iterator.serialize_next(&ctx, writer) } fn is_empty_next(&mut self) -> Option { - self.contexts.next()?; + let _ = self.contexts.next(); let ret = self.batch_values_iterator.is_empty_next()?; Some(ret) } #[inline] fn skip_next(&mut self) -> Option<()> { - self.contexts.next()?; + let _ = self.contexts.next(); self.batch_values_iterator.skip_next()?; Some(()) } diff --git a/scylla/tests/integration/batch.rs b/scylla/tests/integration/batch.rs new file mode 100644 index 000000000..d711cb501 --- /dev/null +++ b/scylla/tests/integration/batch.rs @@ -0,0 +1,74 @@ +use scylla::batch::Batch; +use scylla::batch::BatchType; +use scylla::frame::frame_errors::BatchSerializationError; +use scylla::frame::frame_errors::CqlRequestSerializationError; +use scylla::query::Query; +use scylla::transport::errors::QueryError; + +use crate::utils::create_new_session_builder; +use crate::utils::setup_tracing; +use crate::utils::unique_keyspace_name; +use crate::utils::PerformDDL; + +use assert_matches::assert_matches; + +#[tokio::test] +#[ntest::timeout(60000)] +async fn batch_statements_and_values_mismatch_detected() { + setup_tracing(); + let session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap(); + session.use_keyspace(ks, false).await.unwrap(); + session + .ddl("CREATE TABLE IF NOT EXISTS batch_serialization_test (p int PRIMARY KEY, val int)") + .await + .unwrap(); + + let mut batch = Batch::new(BatchType::Logged); + let stmt = session + .prepare("INSERT INTO batch_serialization_test (p, val) VALUES (?, ?)") + .await + .unwrap(); + batch.append_statement(stmt.clone()); + batch.append_statement(Query::new( + "INSERT INTO batch_serialization_test (p, val) VALUES (3, 4)", + )); + batch.append_statement(stmt); + + // Subtest 1: counts are correct + { + session.batch(&batch, &((1, 2), (), (5, 6))).await.unwrap(); + } + + // Subtest 2: not enough values + { + let err = session.batch(&batch, &((1, 2), ())).await.unwrap_err(); + assert_matches!( + err, + QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization( + BatchSerializationError::ValuesAndStatementsLengthMismatch { + n_value_lists: 2, + n_statements: 3 + } + )) + ) + } + + // Subtest 3: too many values + { + let err = session + .batch(&batch, &((1, 2), (), (5, 6), (7, 8))) + .await + .unwrap_err(); + assert_matches!( + err, + QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization( + BatchSerializationError::ValuesAndStatementsLengthMismatch { + n_value_lists: 4, + n_statements: 3 + } + )) + ) + } +} diff --git a/scylla/tests/integration/main.rs b/scylla/tests/integration/main.rs index d510dc6fd..86b529dc0 100644 --- a/scylla/tests/integration/main.rs +++ b/scylla/tests/integration/main.rs @@ -1,4 +1,5 @@ mod authenticate; +mod batch; mod consistency; mod cql_collections; mod cql_types;