Fix/batch length #824

Merged: 7 commits, Oct 17, 2023
Changes from 5 commits
2 changes: 1 addition & 1 deletion .github/workflows/cassandra.yml
@@ -29,7 +29,7 @@ jobs:
run: cargo build --verbose --tests
- name: Run tests on cassandra
run: |
CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info
CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info --skip test_large_batch_statements
- name: Stop the cluster
if: ${{ always() }}
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop
4 changes: 4 additions & 0 deletions scylla-cql/src/errors.rs
@@ -348,6 +348,10 @@ pub enum BadQuery {
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

/// Too many queries in the batch statement
#[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
TooManyQueriesInBatchStatement(usize),

/// Other reasons of bad query
#[error("{0}")]
Other(String),
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/frame_errors.rs
@@ -40,7 +40,7 @@ pub enum ParseError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("type not yet implemented, id: {0}")]
TypeNotImplemented(i16),
TypeNotImplemented(u16),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
#[error(transparent)]
31 changes: 21 additions & 10 deletions scylla-cql/src/frame/types.rs
@@ -5,8 +5,8 @@ use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut};
use num_enum::TryFromPrimitive;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::convert::{Infallible, TryFrom};
use std::net::IpAddr;
use std::net::SocketAddr;
use std::str;
@@ -16,7 +16,7 @@ use uuid::Uuid;
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))]
#[repr(i16)]
#[repr(u16)]
pub enum Consistency {
Any = 0x0000,
One = 0x0001,
@@ -98,6 +98,12 @@ impl From<std::str::Utf8Error> for ParseError {
}
}

impl From<Infallible> for ParseError {
fn from(_: Infallible) -> Self {
ParseError::BadIncomingData("Unexpected Infallible Error".to_string())
}
}
Contributor:

Could you explain what this trait implementation does?
AFAIU Infallible is for things that can never happen, so why do we want to convert it to a ParseError?

https://doc.rust-lang.org/std/convert/enum.Infallible.html

Contributor (author):

It's for the conversion of u16 to usize. I wanted to do a simple `as` cast but didn't want to modify the existing code too much. Converting from the previous i16 to usize could fail with a TryFromIntError, but with u16 -> usize it really is infallible.

Contributor:

Ah okay, I understand now.
So we have a few pieces of code like this one:

let a: i16 = 123;
let b: usize = a.try_into()?;

And after changing the i16 to u16 they no longer compile because try_into() has an Infallible error type.
We can implement a conversion from Infallible to ParseError to make it compile, but it's a bit hacky.

I think it would be better to replace the try_into()? with into(), like this:

let a: u16 = 123;
let b: usize = a.into();

There is an implementation of From<u16> for usize, so we can just use into() here.
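The difference between the two conversion styles can be seen in a standalone, std-only sketch (variable names here are illustrative, not from the PR):

```rust
use std::convert::TryInto;

fn main() {
    // Old style: i16 -> usize is fallible (negative values have no usize
    // representation), so try_into() with a TryFromIntError error type is needed.
    let a: i16 = 123;
    let b: usize = a.try_into().expect("a negative i16 would fail here");
    assert_eq!(b, 123);

    // New style: u16 -> usize is lossless on all supported platforms, so the
    // infallible From/Into impl applies and no error conversion is involved.
    let c: u16 = 123;
    let d: usize = c.into();
    assert_eq!(d, 123);
}
```

This is why the `From<Infallible> for ParseError` impl becomes unnecessary once the call sites use `into()` instead of `try_into()?`.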


impl From<std::array::TryFromSliceError> for ParseError {
fn from(_err: std::array::TryFromSliceError) -> Self {
ParseError::BadIncomingData("array try from slice failed".to_string())
@@ -169,13 +175,18 @@ fn type_long() {
}
}

pub fn read_short(buf: &mut &[u8]) -> Result<i16, ParseError> {
let v = buf.read_i16::<BigEndian>()?;
pub fn read_short(buf: &mut &[u8]) -> Result<u16, ParseError> {
let v = buf.read_u16::<BigEndian>()?;
Ok(v)
}

pub fn read_u16(buf: &mut &[u8]) -> Result<u16, ParseError> {
let v = buf.read_u16::<BigEndian>()?;
Ok(v)
}
Contributor:

I think the read_u16 isn't needed anymore, we have read_short that does the same thing. Let's remove read_u16.
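For context, the protocol's [short] is an unsigned 16-bit big-endian integer, which is why the two functions are now identical. A std-only sketch of the round-trip (function names mirror the driver's but this avoids the byteorder dependency):

```rust
// Minimal std-only mirror of the [short] encoding: an unsigned 16-bit
// integer in network byte order (big-endian).
fn write_short(v: u16, buf: &mut Vec<u8>) {
    buf.extend_from_slice(&v.to_be_bytes());
}

fn read_short(buf: &mut &[u8]) -> Option<u16> {
    if buf.len() < 2 {
        return None;
    }
    let v = u16::from_be_bytes([buf[0], buf[1]]);
    *buf = &buf[2..]; // advance the slice past the consumed bytes
    Some(v)
}

fn main() {
    let mut buf = Vec::new();
    write_short(u16::MAX, &mut buf);
    assert_eq!(buf, [0xFF, 0xFF]);
    let mut slice = buf.as_slice();
    assert_eq!(read_short(&mut slice), Some(u16::MAX));
    assert!(slice.is_empty());
}
```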


pub fn write_short(v: i16, buf: &mut impl BufMut) {
buf.put_i16(v);
pub fn write_short(v: u16, buf: &mut impl BufMut) {
buf.put_u16(v);
}

pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
@@ -185,14 +196,14 @@ pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
}

fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> {
let v: i16 = v.try_into()?;
let v: u16 = v.try_into()?;
write_short(v, buf);
Ok(())
}

#[test]
fn type_short() {
let vals = [i16::MIN, -1, 0, 1, i16::MAX];
let vals: [u16; 3] = [0, 1, u16::MAX];
for val in vals.iter() {
let mut buf = Vec::new();
write_short(*val, &mut buf);
@@ -464,11 +475,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
}

pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

#[test]
10 changes: 5 additions & 5 deletions scylla-cql/src/frame/value.rs
@@ -63,7 +63,7 @@ pub struct Time(pub Duration);
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SerializedValues {
serialized_values: Vec<u8>,
values_num: i16,
values_num: u16,
contains_names: bool,
}

@@ -134,7 +134,7 @@ impl SerializedValues {
if self.contains_names {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}
Contributor (on lines +137 to 139):

The message for TooManyValues has to be adjusted as well; currently it still refers to the old i16::MAX limit:

#[error("Too many values to add, max 32 767 values can be sent in a request")]
TooManyValues,
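One possible adjustment (a sketch only; the final wording is up to the PR author) derives the limit from u16::MAX instead of hard-coding the old signed maximum. Shown here as a std-only Display impl equivalent to the thiserror attribute:

```rust
use std::fmt;

// Std-only sketch of the adjusted message; the real code would keep the
// thiserror #[error(...)] attribute with u16::MAX as a format argument.
#[derive(Debug)]
enum SerializeValuesError {
    TooManyValues,
}

impl fmt::Display for SerializeValuesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SerializeValuesError::TooManyValues => write!(
                f,
                "Too many values to add, max {} values can be sent in a request",
                u16::MAX
            ),
        }
    }
}

fn main() {
    let msg = SerializeValuesError::TooManyValues.to_string();
    assert!(msg.contains("65535")); // the new unsigned limit, not 32 767
}
```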


@@ -158,7 +158,7 @@ impl SerializedValues {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
self.contains_names = true;
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

@@ -184,15 +184,15 @@ impl SerializedValues {
}

pub fn write_to_request(&self, buf: &mut impl BufMut) {
buf.put_i16(self.values_num);
buf.put_u16(self.values_num);
buf.put(&self.serialized_values[..]);
}

pub fn is_empty(&self) -> bool {
self.values_num == 0
}

pub fn len(&self) -> i16 {
pub fn len(&self) -> u16 {
self.values_num
}

2 changes: 1 addition & 1 deletion scylla/src/statement/prepared_statement.rs
@@ -339,7 +339,7 @@ impl PreparedStatement {
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum PartitionKeyExtractionError {
#[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
NoPkIndexValue(u16, i16),
NoPkIndexValue(u16, u16),
}

#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
67 changes: 67 additions & 0 deletions scylla/src/transport/large_batch_statements_test.rs
@@ -0,0 +1,67 @@
use assert_matches::assert_matches;

use scylla_cql::errors::{BadQuery, QueryError};

use crate::batch::BatchType;
use crate::query::Query;
use crate::{
batch::Batch,
test_utils::{create_new_session_builder, unique_keyspace_name},
QueryResult, Session,
};

#[tokio::test]
async fn test_large_batch_statements() {
let mut session = create_new_session_builder().build().await.unwrap();

let ks = unique_keyspace_name();
session = create_test_session(session, &ks).await;

let max_queries = u16::MAX as usize;
let batch_insert_result = write_batch(&session, max_queries, &ks).await;

assert!(batch_insert_result.is_ok());
Contributor:

Could you change it to batch_insert_result.unwrap()? unwrap() will print the error message when the test fails. assert! will only give us assertion failed, which isn't very helpful.


let too_many_queries = u16::MAX as usize + 1;
let batch_insert_result = write_batch(&session, too_many_queries, &ks).await;
assert_matches!(
batch_insert_result.unwrap_err(),
QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(_too_many_queries)) if _too_many_queries == too_many_queries
)
}

async fn create_test_session(session: Session, ks: &String) -> Session {
session
.query(
format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks),
Collaborator:

You can use CREATE KEYSPACE without IF NOT EXISTS; the ks name is guaranteed to be unique.

&[],
)
.await.unwrap();
session
.query(format!("DROP TABLE IF EXISTS {}.pairs;", ks), &[])
.await
.unwrap();
Collaborator:

This is not needed, the test will create a new keyspace because ks is guaranteed to be unique.

session
.query(
format!("CREATE TABLE IF NOT EXISTS {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", ks),
&[],
)
.await.unwrap();
session
}

async fn write_batch(session: &Session, n: usize, ks: &String) -> Result<QueryResult, QueryError> {
let mut batch_query = Batch::new(BatchType::Logged);
Contributor:

An Unlogged batch might be faster to execute, it could speed up the test a bit.

let mut batch_values = Vec::new();
for i in 0..n {
let mut key = vec![0];
key.extend(i.to_be_bytes().as_slice());
let value = key.clone();
let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks);
let values = vec![key, value];
batch_values.push(values);
let query = Query::new(query);
Collaborator:

How about using prepared statements instead? Prepare the statement once before creating the batch and then use it in append_statement: batch_query.append_statement(prepared.clone()). This should reduce the work needed by DB to process the batch.

batch_query.append_statement(query);
}
session.batch(&batch_query, batch_values).await
}
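The key construction in the loop above (a zero prefix byte followed by the big-endian bytes of the loop index) can be isolated in a std-only sketch; the helper name is hypothetical:

```rust
// Hypothetical helper mirroring the key layout used in write_batch:
// [0x00][big-endian bytes of i], so keys sort in index order within
// the partition and each key is unique per loop iteration.
fn make_key(i: usize) -> Vec<u8> {
    let mut key = vec![0];
    key.extend(i.to_be_bytes().as_slice());
    key
}

fn main() {
    let key = make_key(1);
    // 1 prefix byte + one byte per byte of usize (8 on 64-bit targets).
    assert_eq!(key.len(), 1 + std::mem::size_of::<usize>());
    assert_eq!(key[0], 0);
    assert_eq!(*key.last().unwrap(), 1);
    // Distinct indices yield distinct keys.
    assert_ne!(make_key(1), make_key(2));
}
```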
2 changes: 2 additions & 0 deletions scylla/src/transport/mod.rs
@@ -35,6 +35,8 @@ mod silent_prepare_batch_test;
mod cql_types_test;
#[cfg(test)]
mod cql_value_test;
#[cfg(test)]
mod large_batch_statements_test;

pub use cluster::ClusterData;
pub use node::{KnownNode, Node, NodeAddr, NodeRef};
8 changes: 8 additions & 0 deletions scylla/src/transport/session.rs
@@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize;
use crate::authentication::AuthenticatorProvider;
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
use scylla_cql::errors::BadQuery;

/// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses.
///
@@ -1143,6 +1144,13 @@ impl Session {
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
// If users batch statements by shard, they will be rewarded with full shard awareness

// check to ensure that we don't send a batch statement with more than u16::MAX queries
let batch_statements_length = batch.statements.len();
if batch_statements_length > u16::MAX as usize {
return Err(QueryError::BadQuery(
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
));
}
// Extract first serialized_value
let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?;
let first_serialized_value = first_serialized_value.as_deref();
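The boundary behavior of the guard above can be exercised in isolation with a std-only sketch (the function name and String error are hypothetical; the real code returns QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement)):

```rust
// Std-only mirror of the batch-length guard: the protocol encodes the
// statement count as a [short], so at most u16::MAX statements fit.
fn check_batch_len(batch_statements_length: usize) -> Result<(), String> {
    if batch_statements_length > u16::MAX as usize {
        return Err(format!(
            "Number of Queries in Batch Statement supplied is {} which has exceeded the max value of 65,535",
            batch_statements_length
        ));
    }
    Ok(())
}

fn main() {
    // Exactly 65 535 statements are allowed; 65 536 are rejected,
    // matching the two cases in test_large_batch_statements.
    assert!(check_batch_len(u16::MAX as usize).is_ok());
    assert!(check_batch_len(u16::MAX as usize + 1).is_err());
}
```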