diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index c99be91e..4a1400d6 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -103,9 +103,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.2.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0b3de4a0c5e67e16066a0715723abd91edc2f9001d09c46e1dca929351e130e" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "cc" @@ -991,6 +991,7 @@ dependencies = [ "assert_matches", "bigdecimal", "bindgen", + "bytes", "chrono", "lazy_static", "libc", diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index f941e0d6..a8b06a32 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -34,6 +34,7 @@ chrono = "0.4.20" assert_matches = "1.5.0" ntest = "0.9.0" rusty-fork = "0.3.0" +bytes = "1.4.0" scylla-proxy = { git = "https://github.com/Gor027/scylla-rust-driver.git", rev = "6585f06"} [lib] diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 34dccf53..2a5e8a65 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -1544,6 +1544,304 @@ pub unsafe extern "C" fn cass_result_paging_state_token( CassError::CASS_OK } +#[cfg(test)] +mod tests { + use crate::cass_error::CassError; + use crate::query_result::{ + cass_iterator_get_value, cass_iterator_next, cass_value_get_bool, CassCollectionIterator, + CassIterator, CassIteratorStateInfo, CassMapIterator, CassSequenceIterator, + CassTupleIterator, CassUdtIterator, + }; + use crate::testing::assert_cass_error_eq; + use crate::types::cass_bool_t; + use bytes::{BufMut, Bytes, BytesMut}; + use num_traits::Zero; + use scylla::frame::response::result::ColumnType; + use scylla::types::deserialize::value::{DeserializeCql, MapIterator}; + use scylla::types::deserialize::value::{SequenceIterator, UdtIterator}; + use scylla::types::deserialize::FrameSlice; + + #[test] + #[ntest::timeout(100)] + fn test_collection_seq_iterator_empty_raw_value() { + unsafe { + let mut bytes_mut = BytesMut::new(); + bytes_mut.put_i32(1); // Number of values + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = &ColumnType::List(Box::new(ColumnType::Text)); + let sequence_iterator = + SequenceIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)) + .unwrap(); + let mut collection_iterator = CassIterator::CassCollectionIterator( + CassCollectionIterator::SequenceIterator(CassSequenceIterator { + sequence_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }), + ); + let has_next = cass_iterator_next(&mut collection_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_seq_iterator_reached_the_end() { + unsafe { + let mut bytes_mut = BytesMut::new(); + let text = "test"; + bytes_mut.put_i32(1); // Number of values + bytes_mut.put_i32(text.len() as i32); + bytes_mut.put_slice(text.as_bytes()); + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = &ColumnType::List(Box::new(ColumnType::Text)); + let sequence_iterator = + SequenceIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)) + .unwrap(); + let mut collection_iterator = CassIterator::CassCollectionIterator( + CassCollectionIterator::SequenceIterator(CassSequenceIterator { + sequence_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }), + ); + let has_next = cass_iterator_next(&mut collection_iterator); + assert_ne!(has_next, 0); + // Reached the end + let has_next = cass_iterator_next(&mut collection_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_map_iterator_empty_raw_value() { + unsafe { + let mut bytes_mut = BytesMut::new(); + bytes_mut.put_i32(1); // Number of values + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = + &ColumnType::Map(Box::new(ColumnType::Text), Box::new(ColumnType::Text)); + let map_iterator = + MapIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)).unwrap(); + let mut collection_iterator = CassIterator::CassCollectionIterator( + CassCollectionIterator::SeqMapIterator(CassMapIterator { + map_iterator, + count: 2, + state_info: CassIteratorStateInfo::NoValue, + }), + ); + let has_next = cass_iterator_next(&mut collection_iterator); + assert!(has_next.is_zero()); + + let map_iterator = + MapIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)).unwrap(); + let mut map_iterator = CassIterator::CassMapIterator(CassMapIterator { + map_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut map_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_map_iterator_reached_the_end() { + unsafe { + let mut bytes_mut = BytesMut::new(); + let text = "key"; + let true_bytes: &[u8] = &[0x01]; + bytes_mut.put_i32(1); // Number of values + + // Put serialized (string, bool) pair in bytes + bytes_mut.put_i32(text.len() as i32); + bytes_mut.put_slice(text.as_bytes()); + bytes_mut.put_i32(1); + bytes_mut.put(true_bytes); + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = + &ColumnType::Map(Box::new(ColumnType::Text), Box::new(ColumnType::Boolean)); + let map_iterator = + MapIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)).unwrap(); + let mut collection_iterator = CassIterator::CassCollectionIterator( + CassCollectionIterator::SeqMapIterator(CassMapIterator { + map_iterator, + count: 2, + state_info: CassIteratorStateInfo::NoValue, + }), + ); + let has_next = cass_iterator_next(&mut collection_iterator); // Position on key + assert_ne!(has_next, 0); + let has_next = cass_iterator_next(&mut collection_iterator); // Position on value + assert_ne!(has_next, 0); + // Reached the end + let has_next = cass_iterator_next(&mut collection_iterator); + assert!(has_next.is_zero()); + + let map_iterator = + MapIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)).unwrap(); + let mut map_iterator = CassIterator::CassMapIterator(CassMapIterator { + map_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut map_iterator); + assert_ne!(has_next, 0); + // Reached the end + let has_next = cass_iterator_next(&mut map_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_seq_map_iterator_deserialize_pair() { + // To test that sequential iterator deserializes in pairs, + // after the first `cass_iterator_next` call the value of the first pair + // will be retrieved. + unsafe { + let mut bytes_mut = BytesMut::new(); + let key1 = "key1"; + let true_bytes: &[u8] = &[0x01]; + bytes_mut.put_i32(1); // Number of values + + // Put serialized (string, bool) pairs in bytes + bytes_mut.put_i32(key1.len() as i32); + bytes_mut.put_slice(key1.as_bytes()); + bytes_mut.put_i32(1); + bytes_mut.put(true_bytes); + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = + &ColumnType::Map(Box::new(ColumnType::Text), Box::new(ColumnType::Boolean)); + let map_iterator = + MapIterator::deserialize(column_type.as_ref().unwrap(), Some(slice_frame)).unwrap(); + let mut collection_iterator = CassIterator::CassCollectionIterator( + CassCollectionIterator::SeqMapIterator(CassMapIterator { + map_iterator, + count: 2, + state_info: CassIteratorStateInfo::NoValue, + }), + ); + + let has_next = cass_iterator_next(&mut collection_iterator); + assert_ne!(has_next, 0); + + // Hacking the iterator to not change the value but increment the position + if let CassIterator::CassCollectionIterator(CassCollectionIterator::SeqMapIterator( + CassMapIterator { state_info, .. }, + )) = &mut collection_iterator + { + state_info.advance(); + } + + // Reached the end + let value = cass_iterator_get_value(&collection_iterator); + let mut output = cass_bool_t::default(); + assert_cass_error_eq!(cass_value_get_bool(value, &mut output), CassError::CASS_OK); + assert_ne!(output, 0); // Value should be true + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_tuple_iterator_empty_raw_value() { + unsafe { + let mut bytes_mut = BytesMut::new(); + bytes_mut.put_i32(1); // Number of values + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = &ColumnType::Tuple(vec![ColumnType::Text]); + let sequence_iterator = + SequenceIterator::new(column_type.as_ref().unwrap(), 1, slice_frame); + let mut tuple_iterator = CassIterator::CassTupleIterator(CassTupleIterator { + sequence_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut tuple_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_tuple_iterator_reached_the_end() { + unsafe { + let mut bytes_mut = BytesMut::new(); + let text = "test"; + bytes_mut.put_i32(text.len() as i32); + bytes_mut.put_slice(text.as_bytes()); + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let column_type: *const ColumnType = &ColumnType::Tuple(vec![ColumnType::Text]); + let sequence_iterator = + SequenceIterator::new(column_type.as_ref().unwrap(), 1, slice_frame); + let mut tuple_iterator = CassIterator::CassTupleIterator(CassTupleIterator { + sequence_iterator, + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut tuple_iterator); + assert_ne!(has_next, 0); + // Reached the end + let has_next = cass_iterator_next(&mut tuple_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_udt_iterator_empty_raw_value() { + unsafe { + let mut bytes_mut = BytesMut::new(); + bytes_mut.put_i32(1); // Number of values + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let fields: *const Vec<(String, ColumnType)> = + &vec![("string".to_owned(), ColumnType::Text)]; + let mut udt_iterator = CassIterator::CassUdtIterator(CassUdtIterator { + udt_iterator: UdtIterator::new(fields.as_ref().unwrap(), slice_frame), + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut udt_iterator); + assert!(has_next.is_zero()); + } + } + + #[test] + #[ntest::timeout(100)] + fn test_collection_udt_iterator_reached_the_end() { + unsafe { + let mut bytes_mut = BytesMut::new(); + let text = "test"; + bytes_mut.put_i32(text.len() as i32); + bytes_mut.put_slice(text.as_bytes()); + let bytes: *const Bytes = &Bytes::from(bytes_mut); + let slice_frame = FrameSlice::new(bytes.as_ref().unwrap()); + let fields: *const Vec<(String, ColumnType)> = + &vec![("string".to_owned(), ColumnType::Text)]; + let mut udt_iterator = CassIterator::CassUdtIterator(CassUdtIterator { + udt_iterator: UdtIterator::new(fields.as_ref().unwrap(), slice_frame), + count: 1, + state_info: CassIteratorStateInfo::NoValue, + }); + let has_next = cass_iterator_next(&mut udt_iterator); + assert_ne!(has_next, 0); + // Reached the end + let has_next = cass_iterator_next(&mut udt_iterator); + assert!(has_next.is_zero()); + } + } +} + // CassResult functions: /* extern "C" {