From 82c1c99f0ff86509f9dd1e649ecdaddc5a3660cf Mon Sep 17 00:00:00 2001 From: Samuel Orji Date: Tue, 17 Oct 2023 15:55:39 +0100 Subject: [PATCH] Fix/batch length (#824) * changed the type of the maximum number of statements in a batch query from an i16 to a u16 according to the CQL protocol spec * add a guard when doing batch statements to prevent making calls to the server when the number of batch queries is greater than u16::MAX, as well as adding some tests --- .github/workflows/cassandra.yml | 2 +- scylla-cql/src/errors.rs | 4 ++ scylla-cql/src/frame/frame_errors.rs | 2 +- scylla-cql/src/frame/request/batch.rs | 2 +- scylla-cql/src/frame/response/result.rs | 4 +- scylla-cql/src/frame/types.rs | 20 +++--- scylla-cql/src/frame/value.rs | 12 ++-- scylla/src/statement/prepared_statement.rs | 2 +- .../transport/large_batch_statements_test.rs | 68 +++++++++++++++++++ scylla/src/transport/mod.rs | 2 + scylla/src/transport/session.rs | 8 +++ 11 files changed, 104 insertions(+), 22 deletions(-) create mode 100644 scylla/src/transport/large_batch_statements_test.rs diff --git a/.github/workflows/cassandra.yml b/.github/workflows/cassandra.yml index 5cc118f1f4..e22b915d46 100644 --- a/.github/workflows/cassandra.yml +++ b/.github/workflows/cassandra.yml @@ -29,7 +29,7 @@ jobs: run: cargo build --verbose --tests - name: Run tests on cassandra run: | - CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info + CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info --skip test_large_batch_statements - name: Stop the cluster if: ${{ always() }} run: docker compose -f test/cluster/cassandra/docker-compose.yml stop diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index 40587cfef6..9e80247e20 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -348,6 +348,10 @@ pub enum BadQuery { #[error("Passed invalid keyspace name to use: {0}")] BadKeyspaceName(#[from] BadKeyspaceName), + /// Too many queries in the batch statement + #[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")] + TooManyQueriesInBatchStatement(usize), + /// Other reasons of bad query #[error("{0}")] Other(String), diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 403b6ab5fd..3da4e26d01 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -40,7 +40,7 @@ pub enum ParseError { #[error(transparent)] IoError(#[from] std::io::Error), #[error("type not yet implemented, id: {0}")] - TypeNotImplemented(i16), + TypeNotImplemented(u16), #[error(transparent)] SerializeValuesError(#[from] SerializeValuesError), #[error(transparent)] diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 3c0bad3931..35dd8c3c3b 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Result { let batch_type = buf.get_u8().try_into()?; - let statements_count: usize = types::read_short(buf)?.try_into()?; + let statements_count: usize = types::read_short(buf)?.into(); let statements_with_values = (0..statements_count) .map(|_| { let batch_statement = BatchStatement::deserialize(buf)?; diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 288baf91eb..5ade677343 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -437,7 +437,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { 0x0030 => { let keyspace_name: String = types::read_string(buf)?.to_string(); let type_name: String = types::read_string(buf)?.to_string(); - let fields_size: usize = types::read_short(buf)?.try_into()?; + let fields_size: usize = types::read_short(buf)?.into(); let mut field_types: Vec<(String, ColumnType)> = Vec::with_capacity(fields_size); @@ -455,7 +455,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult { } } 0x0031 => { - let len: usize = types::read_short(buf)?.try_into()?; + let len: usize = types::read_short(buf)?.into(); let mut types = Vec::with_capacity(len); for _ in 0..len { types.push(deser_type(buf)?); diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index fd2254c8b0..672fe2f97e 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -16,7 +16,7 @@ use uuid::Uuid; #[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] #[cfg_attr(feature = "serde", derive(serde::Deserialize))] #[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))] -#[repr(i16)] +#[repr(u16)] pub enum Consistency { Any = 0x0000, One = 0x0001, @@ -169,30 +169,30 @@ fn type_long() { } } -pub fn read_short(buf: &mut &[u8]) -> Result { - let v = buf.read_i16::()?; +pub fn read_short(buf: &mut &[u8]) -> Result { + let v = buf.read_u16::()?; Ok(v) } -pub fn write_short(v: i16, buf: &mut impl BufMut) { - buf.put_i16(v); +pub fn write_short(v: u16, buf: &mut impl BufMut) { + buf.put_u16(v); } pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { let v = read_short(buf)?; - let v: usize = v.try_into()?; + let v: usize = v.into(); Ok(v) } fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> { - let v: i16 = v.try_into()?; + let v: u16 = v.try_into()?; write_short(v, buf); Ok(()) } #[test] fn type_short() { - let vals = [i16::MIN, -1, 0, 1, i16::MAX]; + let vals: [u16; 3] = [0, 1, u16::MAX]; for val in vals.iter() { let mut buf = Vec::new(); write_short(*val, &mut buf); @@ -464,11 +464,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result { } pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); + write_short(c as u16, buf); } pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) { - write_short(c as i16, buf); + write_short(c as u16, buf); } #[test] diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index e9164f2531..17b75ea855 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -63,7 +63,7 @@ pub struct Time(pub Duration); #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct SerializedValues { serialized_values: Vec, - values_num: i16, + values_num: u16, contains_names: bool, } @@ -77,7 +77,7 @@ pub struct CqlDuration { #[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum SerializeValuesError { - #[error("Too many values to add, max 32 767 values can be sent in a request")] + #[error("Too many values to add, max 65,535 values can be sent in a request")] TooManyValues, #[error("Mixing named and not named values is not allowed")] MixingNamedAndNotNamedValues, @@ -134,7 +134,7 @@ impl SerializedValues { if self.contains_names { return Err(SerializeValuesError::MixingNamedAndNotNamedValues); } - if self.values_num == i16::MAX { + if self.values_num == u16::MAX { return Err(SerializeValuesError::TooManyValues); } @@ -158,7 +158,7 @@ impl SerializedValues { return Err(SerializeValuesError::MixingNamedAndNotNamedValues); } self.contains_names = true; - if self.values_num == i16::MAX { + if self.values_num == u16::MAX { return Err(SerializeValuesError::TooManyValues); } @@ -184,7 +184,7 @@ impl SerializedValues { } pub fn write_to_request(&self, buf: &mut impl BufMut) { - buf.put_i16(self.values_num); + buf.put_u16(self.values_num); buf.put(&self.serialized_values[..]); } @@ -192,7 +192,7 @@ impl SerializedValues { self.values_num == 0 } - pub fn len(&self) -> i16 { + pub fn len(&self) -> u16 { self.values_num } diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index b57d5d4b23..9814e7350d 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -339,7 +339,7 @@ impl PreparedStatement { #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)] pub enum PartitionKeyExtractionError { #[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")] - NoPkIndexValue(u16, i16), + NoPkIndexValue(u16, u16), } #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)] diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs new file mode 100644 index 0000000000..29482e31ce --- /dev/null +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -0,0 +1,68 @@ +use assert_matches::assert_matches; + +use scylla_cql::errors::{BadQuery, QueryError}; + +use crate::batch::BatchType; +use crate::query::Query; +use crate::{ + batch::Batch, + test_utils::{create_new_session_builder, unique_keyspace_name}, + QueryResult, Session, +}; + +#[tokio::test] +async fn test_large_batch_statements() { + let mut session = create_new_session_builder().build().await.unwrap(); + + let ks = unique_keyspace_name(); + session = create_test_session(session, &ks).await; + + let max_queries = u16::MAX as usize; + let batch_insert_result = write_batch(&session, max_queries, &ks).await; + + batch_insert_result.unwrap(); + + let too_many_queries = u16::MAX as usize + 1; + let batch_insert_result = write_batch(&session, too_many_queries, &ks).await; + assert_matches!( + batch_insert_result.unwrap_err(), + QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(_too_many_queries)) if _too_many_queries == too_many_queries + ) +} + +async fn create_test_session(session: Session, ks: &String) -> Session { + session + .query( + format!("CREATE KEYSPACE {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks), + &[], + ) + .await.unwrap(); + session + .query( + format!( + "CREATE TABLE {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", + ks + ), + &[], + ) + .await + .unwrap(); + session +} + +async fn write_batch(session: &Session, n: usize, ks: &String) -> Result { + let mut batch_query = Batch::new(BatchType::Unlogged); + let mut batch_values = Vec::new(); + let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks); + let query = Query::new(query); + let prepared_statement = session.prepare(query).await.unwrap(); + for i in 0..n { + let mut key = vec![0]; + key.extend(i.to_be_bytes().as_slice()); + let value = key.clone(); + let values = vec![key, value]; + batch_values.push(values); + batch_query.append_statement(prepared_statement.clone()); + } + session.batch(&batch_query, batch_values).await +} diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 939983cfc4..a33943645d 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -35,6 +35,8 @@ mod silent_prepare_batch_test; mod cql_types_test; #[cfg(test)] mod cql_value_test; +#[cfg(test)] +mod large_batch_statements_test; pub use cluster::ClusterData; pub use node::{KnownNode, Node, NodeAddr, NodeRef}; diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 35ff25475f..2f67874f8c 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize; use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; +use scylla_cql::errors::BadQuery; /// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses. /// @@ -1143,6 +1144,13 @@ impl Session { // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard // If users batch statements by shard, they will be rewarded with full shard awareness + // check to ensure that we don't send a batch statement with more than u16::MAX queries + let batch_statements_length = batch.statements.len(); + if batch_statements_length > u16::MAX as usize { + return Err(QueryError::BadQuery( + BadQuery::TooManyQueriesInBatchStatement(batch_statements_length), + )); + } // Extract first serialized_value let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?; let first_serialized_value = first_serialized_value.as_deref();