Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RawBatchValuesIteratorAdapter length equal to its internal BatchValuesIter length #1142

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions scylla-cql/src/types/serialize/raw_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ pub trait RawBatchValuesIterator<'a> {
where
Self: Sized,
{
let mut count = 0;
while self.skip_next().is_some() {
count += 1;
}
count
std::iter::from_fn(|| self.skip_next()).count()
}
}

Expand Down Expand Up @@ -145,19 +141,26 @@ where
{
#[inline]
fn serialize_next(&mut self, writer: &mut RowWriter) -> Option<Result<(), SerializationError>> {
let ctx = self.contexts.next()?;
// We do `unwrap_or` because we want the iterator length to be the same
// as the amount of values. Limiting to length of the amount of
// statements (contexts) causes the caller to not be able to correctly
// detect that amount of statements and values is different.
let ctx = self
.contexts
.next()
.unwrap_or(RowSerializationContext::empty());
self.batch_values_iterator.serialize_next(&ctx, writer)
}

fn is_empty_next(&mut self) -> Option<bool> {
self.contexts.next()?;
let _ = self.contexts.next();
let ret = self.batch_values_iterator.is_empty_next()?;
Some(ret)
}

#[inline]
fn skip_next(&mut self) -> Option<()> {
self.contexts.next()?;
let _ = self.contexts.next();
self.batch_values_iterator.skip_next()?;
Some(())
}
Expand Down
74 changes: 74 additions & 0 deletions scylla/tests/integration/batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use scylla::batch::Batch;
use scylla::batch::BatchType;
use scylla::frame::frame_errors::BatchSerializationError;
use scylla::frame::frame_errors::CqlRequestSerializationError;
use scylla::query::Query;
use scylla::transport::errors::QueryError;

use crate::utils::create_new_session_builder;
use crate::utils::setup_tracing;
use crate::utils::unique_keyspace_name;
use crate::utils::PerformDDL;

use assert_matches::assert_matches;

#[tokio::test]
#[ntest::timeout(60000)]
async fn batch_statements_and_values_mismatch_detected() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();
session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap();
session.use_keyspace(ks, false).await.unwrap();
session
.ddl("CREATE TABLE IF NOT EXISTS batch_serialization_test (p int PRIMARY KEY, val int)")
.await
.unwrap();

let mut batch = Batch::new(BatchType::Logged);
let stmt = session
.prepare("INSERT INTO batch_serialization_test (p, val) VALUES (?, ?)")
.await
.unwrap();
batch.append_statement(stmt.clone());
batch.append_statement(Query::new(
"INSERT INTO batch_serialization_test (p, val) VALUES (3, 4)",
));
batch.append_statement(stmt);

// Subtest 1: counts are correct
{
session.batch(&batch, &((1, 2), (), (5, 6))).await.unwrap();
}

// Subtest 2: not enough values
{
let err = session.batch(&batch, &((1, 2), ())).await.unwrap_err();
assert_matches!(
err,
QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization(
BatchSerializationError::ValuesAndStatementsLengthMismatch {
n_value_lists: 2,
n_statements: 3
}
))
)
}

// Subtest 3: too many values
{
let err = session
.batch(&batch, &((1, 2), (), (5, 6), (7, 8)))
.await
.unwrap_err();
assert_matches!(
err,
QueryError::CqlRequestSerialization(CqlRequestSerializationError::BatchSerialization(
BatchSerializationError::ValuesAndStatementsLengthMismatch {
n_value_lists: 4,
n_statements: 3
}
))
)
}
}
1 change: 1 addition & 0 deletions scylla/tests/integration/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod authenticate;
mod batch;
mod consistency;
mod cql_collections;
mod cql_types;
Expand Down
Loading