Skip to content

Commit

Permalink
query_result: Add sequential iteration over map
Browse files Browse the repository at this point in the history
C++ driver enables to iterate over a map sequentially like lists or sets.
The `CassCollectionIterator` is modified to enum to contain either
`SequenceIterator` or `MapIterator` for supporting this feature.
  • Loading branch information
Gor027 committed May 27, 2023
1 parent 4d04864 commit 1256ce9
Showing 1 changed file with 87 additions and 22 deletions.
109 changes: 87 additions & 22 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ pub struct CassRowIterator {
position: Option<usize>,
}

pub struct CassCollectionIterator {
/// For sequential iteration over collection types
pub enum CassCollectionIterator {
SequenceIterator(CassSequenceIterator),
SeqMapIterator(CassMapIterator),
}

pub struct CassSequenceIterator {
sequence_iterator: SequenceIterator<'static, RawValue<'static>>,
value: Option<CassValue>,
count: usize,
Expand Down Expand Up @@ -244,26 +250,48 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass

(new_pos < row_iterator.row.columns.len()) as cass_bool_t
}
CassIterator::CassCollectionIterator(collection_iterator) => {
let new_pos: usize = collection_iterator
.position
.map_or(0, |prev_pos| prev_pos + 1);
CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator {
CassCollectionIterator::SequenceIterator(seq_iterator) => {
let new_pos: usize = seq_iterator.position.map_or(0, |prev_pos| prev_pos + 1);

collection_iterator.position = Some(new_pos);
seq_iterator.position = Some(new_pos);

if new_pos < collection_iterator.count {
let raw_value = collection_iterator.sequence_iterator.next().unwrap();
if let Ok(raw) = raw_value {
let raw_value_type = raw.spec;
let value = decode_value(raw, raw_value_type);
collection_iterator.value = value;
if new_pos < seq_iterator.count {
let raw_value = seq_iterator.sequence_iterator.next().unwrap();
if let Ok(raw) = raw_value {
let raw_value_type = raw.spec;
let value = decode_value(raw, raw_value_type);
seq_iterator.value = value;

return true as cass_bool_t;
return true as cass_bool_t;
}
}

false as cass_bool_t
}
CassCollectionIterator::SeqMapIterator(seq_map_iterator) => {
let new_pos: usize = seq_map_iterator.position.map_or(0, |prev_pos| prev_pos + 1);
seq_map_iterator.position = Some(new_pos);

if new_pos < seq_map_iterator.count {
if new_pos % 2 == 0 {
let raw_value = seq_map_iterator.map_iterator.next().unwrap();
if let Ok((raw_key, raw_value)) = raw_value {
let key_type = raw_key.spec;
let key = decode_value(raw_key, key_type);
let value_type = raw_value.spec;
let value = decode_value(raw_value, value_type);
seq_map_iterator.key = key;
seq_map_iterator.value = value;
}
}

false as cass_bool_t
}
return true as cass_bool_t;
}

false as cass_bool_t
}
},
CassIterator::CassTupleIterator(tuple_iterator) => {
let new_pos: usize = tuple_iterator.position.map_or(0, |prev_pos| prev_pos + 1);

Expand Down Expand Up @@ -439,11 +467,20 @@ pub unsafe extern "C" fn cass_iterator_get_value(

// Defined only for collections(list and set) or tuple iterator, for other types should return null
match iter {
CassIterator::CassCollectionIterator(collection_iterator)
if collection_iterator.value.is_some() =>
{
collection_iterator.value.as_ref().unwrap()
}
CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator {
CassCollectionIterator::SequenceIterator(seq_iter) => seq_iter.value.as_ref().unwrap(),
CassCollectionIterator::SeqMapIterator(seq_map_iter) => {
if let Some(pos) = seq_map_iter.position {
if pos % 2 == 0 && seq_map_iter.key.is_some() {
return seq_map_iter.key.as_ref().unwrap();
} else if pos % 2 == 0 && seq_map_iter.value.is_some() {
return seq_map_iter.value.as_ref().unwrap();
}
}

std::ptr::null()
}
},
CassIterator::CassTupleIterator(tuple_iterator) if tuple_iterator.value.is_some() => {
tuple_iterator.value.as_ref().unwrap()
}
Expand All @@ -465,6 +502,12 @@ pub unsafe extern "C" fn cass_iterator_get_map_key(
.is_some()); // assertion copied from c++ driver
map_iterator.key.as_ref().unwrap()
}
CassIterator::CassCollectionIterator(collection_iterator) => {
match collection_iterator {
CassCollectionIterator::SeqMapIterator(map_iter) => map_iter.key.as_ref().unwrap(),
CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator
}
}
_ => std::ptr::null(),
}
}
Expand All @@ -483,6 +526,14 @@ pub unsafe extern "C" fn cass_iterator_get_map_value(
.is_some()); // assertion copied from c++ driver
map_iterator.value.as_ref().unwrap()
}
CassIterator::CassCollectionIterator(collection_iterator) => {
match collection_iterator {
CassCollectionIterator::SeqMapIterator(map_iter) => {
map_iter.value.as_ref().unwrap()
}
CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator
}
}
_ => std::ptr::null(),
}
}
Expand Down Expand Up @@ -735,16 +786,30 @@ pub unsafe extern "C" fn cass_iterator_from_collection(
let item_count = collection.count;
let column_type = collection.column_type;
match column_type {
ColumnType::Map(_, _) => {
let map_iterator = MapIterator::deserialize(column_type, collection.frame_slice);
if let Ok(map_iter) = map_iterator {
let iterator = CassCollectionIterator::SeqMapIterator(CassMapIterator {
map_iterator: map_iter,
key: None,
value: None,
count: item_count * 2,
position: None,
});

return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator)));
}
}
ColumnType::Set(_) | ColumnType::List(_) => {
let sequence_iterator =
SequenceIterator::deserialize(column_type, collection.frame_slice);
if let Ok(seq_iterator) = sequence_iterator {
let iterator = CassCollectionIterator {
let iterator = CassCollectionIterator::SequenceIterator(CassSequenceIterator {
sequence_iterator: seq_iterator,
value: None,
count: item_count,
position: None,
};
});

return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator)));
}
Expand Down

0 comments on commit 1256ce9

Please sign in to comment.