diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index ef895bd7..6c2a8f2c 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -81,7 +81,13 @@ pub struct CassRowIterator { position: Option, } -pub struct CassCollectionIterator { +/// For sequential iteration over collection types +pub enum CassCollectionIterator { + SequenceIterator(CassSequenceIterator), + SeqMapIterator(CassMapIterator), +} + +pub struct CassSequenceIterator { sequence_iterator: SequenceIterator<'static, RawValue<'static>>, value: Option, count: usize, @@ -244,26 +250,48 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass (new_pos < row_iterator.row.columns.len()) as cass_bool_t } - CassIterator::CassCollectionIterator(collection_iterator) => { - let new_pos: usize = collection_iterator - .position - .map_or(0, |prev_pos| prev_pos + 1); + CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { + CassCollectionIterator::SequenceIterator(seq_iterator) => { + let new_pos: usize = seq_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - collection_iterator.position = Some(new_pos); + seq_iterator.position = Some(new_pos); - if new_pos < collection_iterator.count { - let raw_value = collection_iterator.sequence_iterator.next().unwrap(); - if let Ok(raw) = raw_value { - let raw_value_type = raw.spec; - let value = decode_value(raw, raw_value_type); - collection_iterator.value = value; + if new_pos < seq_iterator.count { + let raw_value = seq_iterator.sequence_iterator.next().unwrap(); + if let Ok(raw) = raw_value { + let raw_value_type = raw.spec; + let value = decode_value(raw, raw_value_type); + seq_iterator.value = value; - return true as cass_bool_t; + return true as cass_bool_t; + } } + + false as cass_bool_t } + CassCollectionIterator::SeqMapIterator(seq_map_iterator) => { + let new_pos: usize = seq_map_iterator.position.map_or(0, |prev_pos| prev_pos + 1); + seq_map_iterator.position = Some(new_pos); + + if new_pos < seq_map_iterator.count { + if new_pos % 2 == 0 { + let raw_value = seq_map_iterator.map_iterator.next().unwrap(); + if let Ok((raw_key, raw_value)) = raw_value { + let key_type = raw_key.spec; + let key = decode_value(raw_key, key_type); + let value_type = raw_value.spec; + let value = decode_value(raw_value, value_type); + seq_map_iterator.key = key; + seq_map_iterator.value = value; + } + } - false as cass_bool_t - } + return true as cass_bool_t; + } + + false as cass_bool_t + } + }, CassIterator::CassTupleIterator(tuple_iterator) => { let new_pos: usize = tuple_iterator.position.map_or(0, |prev_pos| prev_pos + 1); @@ -439,9 +467,25 @@ pub unsafe extern "C" fn cass_iterator_get_value( // Defined only for collections(list and set) or tuple iterator, for other types should return null match iter { - CassIterator::CassCollectionIterator(CassCollectionIterator { - value: Some(value), .. - }) => value, + CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { + CassCollectionIterator::SequenceIterator(CassSequenceIterator { + value: Some(value), + .. + }) => value, + CassCollectionIterator::SeqMapIterator(CassMapIterator { + key: Some(key), + value: Some(value), + position: Some(pos), + .. + }) => { + if pos % 2 == 0 { + key + } else { + value + } + } + _ => std::ptr::null(), + }, CassIterator::CassTupleIterator(CassTupleIterator { value: Some(value), .. }) => value, @@ -463,6 +507,12 @@ pub unsafe extern "C" fn cass_iterator_get_map_key( .is_some()); // assertion copied from c++ driver map_iterator.key.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded } + CassIterator::CassCollectionIterator(collection_iterator) => { + match collection_iterator { + CassCollectionIterator::SeqMapIterator(map_iter) => map_iter.key.as_ref().unwrap(), + CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator + } + } _ => std::ptr::null(), } } @@ -481,6 +531,14 @@ pub unsafe extern "C" fn cass_iterator_get_map_value( .is_some()); // assertion copied from c++ driver map_iterator.value.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded } + CassIterator::CassCollectionIterator(collection_iterator) => { + match collection_iterator { + CassCollectionIterator::SeqMapIterator(map_iter) => { + map_iter.value.as_ref().unwrap() + } + CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator + } + } _ => std::ptr::null(), } } @@ -737,16 +795,30 @@ pub unsafe extern "C" fn cass_iterator_from_collection( let item_count = collection.count; let column_type = collection.column_type; match column_type { + ColumnType::Map(_, _) => { + let map_iterator = MapIterator::deserialize(column_type, collection.frame_slice); + if let Ok(map_iter) = map_iterator { + let iterator = CassCollectionIterator::SeqMapIterator(CassMapIterator { + map_iterator: map_iter, + key: None, + value: None, + count: item_count * 2, + position: None, + }); + + return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); + } + } ColumnType::Set(_) | ColumnType::List(_) => { let sequence_iterator = SequenceIterator::deserialize(column_type, collection.frame_slice); if let Ok(seq_iterator) = sequence_iterator { - let iterator = CassCollectionIterator { + let iterator = CassCollectionIterator::SequenceIterator(CassSequenceIterator { sequence_iterator: seq_iterator, value: None, count: item_count, position: None, - }; + }); return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); }