Skip to content

Commit

Permalink
query_result: Add sequential iteration over map
Browse files Browse the repository at this point in the history
C++ driver enables to iterate over a map sequentially like lists or sets.
The `CassCollectionIterator` is modified to enum to contain either
`SequenceIterator` or `MapIterator` for supporting this feature.
  • Loading branch information
Gor027 committed Jun 6, 2023
1 parent 1c8e188 commit 77820bf
Showing 1 changed file with 92 additions and 20 deletions.
112 changes: 92 additions & 20 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ pub struct CassRowIterator {
position: Option<usize>,
}

pub struct CassCollectionIterator {
/// For sequential iteration over collection types
pub enum CassCollectionIterator {
SequenceIterator(CassSequenceIterator),
SeqMapIterator(CassMapIterator),
}

pub struct CassSequenceIterator {
sequence_iterator: SequenceIterator<'static, RawValue<'static>>,
value: Option<CassValue>,
count: usize,
Expand Down Expand Up @@ -244,26 +250,48 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass

(new_pos < row_iterator.row.columns.len()) as cass_bool_t
}
CassIterator::CassCollectionIterator(collection_iterator) => {
let new_pos: usize = collection_iterator
.position
.map_or(0, |prev_pos| prev_pos + 1);
CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator {
CassCollectionIterator::SequenceIterator(seq_iterator) => {
let new_pos: usize = seq_iterator.position.map_or(0, |prev_pos| prev_pos + 1);

collection_iterator.position = Some(new_pos);
seq_iterator.position = Some(new_pos);

if new_pos < collection_iterator.count {
let raw_value = collection_iterator.sequence_iterator.next().unwrap();
if let Ok(raw) = raw_value {
let raw_value_type = raw.spec;
let value = decode_value(raw, raw_value_type);
collection_iterator.value = value;
if new_pos < seq_iterator.count {
let raw_value = seq_iterator.sequence_iterator.next().unwrap();
if let Ok(raw) = raw_value {
let raw_value_type = raw.spec;
let value = decode_value(raw, raw_value_type);
seq_iterator.value = value;

return true as cass_bool_t;
return true as cass_bool_t;
}
}

false as cass_bool_t
}
CassCollectionIterator::SeqMapIterator(seq_map_iterator) => {
let new_pos: usize = seq_map_iterator.position.map_or(0, |prev_pos| prev_pos + 1);
seq_map_iterator.position = Some(new_pos);

if new_pos < seq_map_iterator.count {
if new_pos % 2 == 0 {
let raw_value = seq_map_iterator.map_iterator.next().unwrap();
if let Ok((raw_key, raw_value)) = raw_value {
let key_type = raw_key.spec;
let key = decode_value(raw_key, key_type);
let value_type = raw_value.spec;
let value = decode_value(raw_value, value_type);
seq_map_iterator.key = key;
seq_map_iterator.value = value;
}
}

false as cass_bool_t
}
return true as cass_bool_t;
}

false as cass_bool_t
}
},
CassIterator::CassTupleIterator(tuple_iterator) => {
let new_pos: usize = tuple_iterator.position.map_or(0, |prev_pos| prev_pos + 1);

Expand Down Expand Up @@ -439,9 +467,25 @@ pub unsafe extern "C" fn cass_iterator_get_value(

// Defined only for collections(list and set) or tuple iterator, for other types should return null
match iter {
CassIterator::CassCollectionIterator(CassCollectionIterator {
value: Some(value), ..
}) => value,
CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator {
CassCollectionIterator::SequenceIterator(CassSequenceIterator {
value: Some(value),
..
}) => value,
CassCollectionIterator::SeqMapIterator(CassMapIterator {
key: Some(key),
value: Some(value),
position: Some(pos),
..
}) => {
if pos % 2 == 0 {
key
} else {
value
}
}
_ => std::ptr::null(),
},
CassIterator::CassTupleIterator(CassTupleIterator {
value: Some(value), ..
}) => value,
Expand All @@ -463,6 +507,12 @@ pub unsafe extern "C" fn cass_iterator_get_map_key(
.is_some()); // assertion copied from c++ driver
map_iterator.key.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded
}
CassIterator::CassCollectionIterator(collection_iterator) => {
match collection_iterator {
CassCollectionIterator::SeqMapIterator(map_iter) => map_iter.key.as_ref().unwrap(),
CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator
}
}
_ => std::ptr::null(),
}
}
Expand All @@ -481,6 +531,14 @@ pub unsafe extern "C" fn cass_iterator_get_map_value(
.is_some()); // assertion copied from c++ driver
map_iterator.value.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded
}
CassIterator::CassCollectionIterator(collection_iterator) => {
match collection_iterator {
CassCollectionIterator::SeqMapIterator(map_iter) => {
map_iter.value.as_ref().unwrap()
}
CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator
}
}
_ => std::ptr::null(),
}
}
Expand Down Expand Up @@ -737,16 +795,30 @@ pub unsafe extern "C" fn cass_iterator_from_collection(
let item_count = collection.count;
let column_type = collection.column_type;
match column_type {
ColumnType::Map(_, _) => {
let map_iterator = MapIterator::deserialize(column_type, collection.frame_slice);
if let Ok(map_iter) = map_iterator {
let iterator = CassCollectionIterator::SeqMapIterator(CassMapIterator {
map_iterator: map_iter,
key: None,
value: None,
count: item_count * 2,
position: None,
});

return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator)));
}
}
ColumnType::Set(_) | ColumnType::List(_) => {
let sequence_iterator =
SequenceIterator::deserialize(column_type, collection.frame_slice);
if let Ok(seq_iterator) = sequence_iterator {
let iterator = CassCollectionIterator {
let iterator = CassCollectionIterator::SequenceIterator(CassSequenceIterator {
sequence_iterator: seq_iterator,
value: None,
count: item_count,
position: None,
};
});

return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator)));
}
Expand Down

0 comments on commit 77820bf

Please sign in to comment.