From 63e1907eec684ddcff198a392f50cffca2681789 Mon Sep 17 00:00:00 2001 From: Gor Stepanyan Date: Sat, 24 Jun 2023 16:07:28 +0200 Subject: [PATCH] query_result: Refactor collection iterators Wrapped value and position fields in `CassIteratorStateInfo` to convey more information about the collection iterators' state. --- scylla-rust-wrapper/src/query_result.rs | 382 +++++++++++++++--------- 1 file changed, 236 insertions(+), 146 deletions(-) diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 341c6564..29e8e059 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -97,6 +97,31 @@ impl CassIteratorStateInfo { } }; } + + /// Update iterator's state and return new state info. + /// Sets or replaces the old value with the `new_value` without changing the position. + fn update_value(&mut self, new_value: T) { + // Store a dummy NoValue temporarily as we cannot move state_info fields. + let old_state_info = std::mem::replace(self, CassIteratorStateInfo::NoValue); + *self = match old_state_info { + CassIteratorStateInfo::Value { position, .. } => CassIteratorStateInfo::Value { + value: new_value, + position, + }, + CassIteratorStateInfo::PositionNoValue { position } => CassIteratorStateInfo::Value { + value: new_value, + position, + }, + CassIteratorStateInfo::ValueNoPosition { .. } => CassIteratorStateInfo::Value { + value: new_value, + position: 0, + }, + CassIteratorStateInfo::NoValue => CassIteratorStateInfo::Value { + value: new_value, + position: 0, + }, + }; + } } pub struct CassResultIterator { @@ -116,32 +141,26 @@ pub enum CassCollectionIterator { pub struct CassSequenceIterator { sequence_iterator: SequenceIterator<'static, RawValue<'static>>, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo, } pub struct CassTupleIterator { sequence_iterator: SequenceIterator<'static, RawValue<'static>>, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo, } pub struct CassMapIterator { map_iterator: MapIterator<'static, RawValue<'static>, RawValue<'static>>, - key: Option, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo<(CassValue, CassValue)>, // (key, value) } pub struct CassUdtIterator { udt_iterator: UdtIterator<'static>, - field_value: Option, - field_name: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo<(String, CassValue)>, // (field_name, field_value) } pub struct CassSchemaMetaIterator { @@ -282,113 +301,167 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass } CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { CassCollectionIterator::SequenceIterator(seq_iterator) => { - let new_pos: usize = seq_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - seq_iterator.position = Some(new_pos); - - if new_pos < seq_iterator.count { - let raw_value = seq_iterator.sequence_iterator.next().unwrap(); - if let Ok(raw) = raw_value { - let raw_value_type = raw.spec; - let value = decode_value(raw, raw_value_type); - seq_iterator.value = value; - - return true as cass_bool_t; + seq_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = seq_iterator.state_info + { + if position < seq_iterator.count { + let raw_value = seq_iterator.sequence_iterator.next().unwrap(); + if let Ok(raw) = raw_value { + let raw_value_type = raw.spec; + let new_value = decode_value(raw, raw_value_type); + + if let Some(val) = new_value { + seq_iterator.state_info.update_value(val); + true as cass_bool_t // Value on new position is deserialized + } else { + false as cass_bool_t // New value is empty + } + } else { + false as cass_bool_t // Raw value is empty + } + } else { + false as cass_bool_t // Iterator reached the end } + } else { + false as cass_bool_t // Iterator position is unknown } - - false as cass_bool_t } CassCollectionIterator::SeqMapIterator(seq_map_iterator) => { - let new_pos: usize = seq_map_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - seq_map_iterator.position = Some(new_pos); - - if new_pos < seq_map_iterator.count { - if new_pos % 2 == 0 { - let raw_value = seq_map_iterator.map_iterator.next().unwrap(); - if let Ok((raw_key, raw_value)) = raw_value { - let key_type = raw_key.spec; - let key = decode_value(raw_key, key_type); - let value_type = raw_value.spec; - let value = decode_value(raw_value, value_type); - seq_map_iterator.key = key; - seq_map_iterator.value = value; + seq_map_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = + seq_map_iterator.state_info + { + // Decoding in pair (key, value) on even positions + if position < seq_map_iterator.count { + if position % 2 == 0 { + let raw_value = seq_map_iterator.map_iterator.next().unwrap(); + if let Ok((raw_key, raw_value)) = raw_value { + let key_type = raw_key.spec; + let new_key = decode_value(raw_key, key_type); + let value_type = raw_value.spec; + let new_value = decode_value(raw_value, value_type); + + if let (Some(k), Some(v)) = (new_key, new_value) { + seq_map_iterator.state_info.update_value((k, v)); + true as cass_bool_t // (Key, Value) on new position is deserialized + } else { + false as cass_bool_t // New (key, value) is empty + } + } else { + false as cass_bool_t // Raw value is empty + } + } else { + true as cass_bool_t // Do not deserialize on odd positions } + } else { + false as cass_bool_t // Iterator reached the end } - - return true as cass_bool_t; + } else { + false as cass_bool_t // Iterator position is unknown } - - false as cass_bool_t } }, CassIterator::CassTupleIterator(tuple_iterator) => { - let new_pos: usize = tuple_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - tuple_iterator.position = Some(new_pos); - - if new_pos < tuple_iterator.count { - let raw_value = tuple_iterator.sequence_iterator.next().unwrap(); - if let Ok(raw) = raw_value { - let type_in_pos = match raw.spec { - ColumnType::Tuple(type_defs) => type_defs.get(new_pos), - _ => panic!("Cannot get tuple out of non-tuple column type"), - }; - if let Some(spec) = type_in_pos { - let value = decode_value(raw, spec); - tuple_iterator.value = value; - - return true as cass_bool_t; + tuple_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = tuple_iterator.state_info + { + if position < tuple_iterator.count { + let raw_value = tuple_iterator.sequence_iterator.next().unwrap(); + if let Ok(raw) = raw_value { + let type_in_pos = match raw.spec { + ColumnType::Tuple(type_defs) => type_defs.get(position), + _ => panic!("Cannot get tuple out of non-tuple column type"), + }; + if let Some(spec) = type_in_pos { + let new_value = decode_value(raw, spec); + + if let Some(val) = new_value { + tuple_iterator.state_info.update_value(val); + true as cass_bool_t // Value on new position is deserialized + } else { + false as cass_bool_t // New value is empty + } + } else { + false as cass_bool_t // Value type is not known + } + } else { + false as cass_bool_t // Raw value is empty } + } else { + false as cass_bool_t // Iterator reached the end } + } else { + false as cass_bool_t // Iterator position is unknown } - - false as cass_bool_t } CassIterator::CassMapIterator(map_iterator) => { - let new_pos: usize = map_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - map_iterator.position = Some(new_pos); - - if new_pos < map_iterator.count { - let raw_value = map_iterator.map_iterator.next().unwrap(); - if let Ok((raw_key, raw_value)) = raw_value { - let key_type = raw_key.spec; - let key = decode_value(raw_key, key_type); - let value_type = raw_value.spec; - let value = decode_value(raw_value, value_type); - map_iterator.key = key; - map_iterator.value = value; + map_iterator.state_info.advance(); - return true as cass_bool_t; + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = map_iterator.state_info + { + if position < map_iterator.count { + let raw_value = map_iterator.map_iterator.next().unwrap(); + if let Ok((raw_key, raw_value)) = raw_value { + let key_type = raw_key.spec; + let new_key = decode_value(raw_key, key_type); + let value_type = raw_value.spec; + let new_value = decode_value(raw_value, value_type); + + if let (Some(k), Some(v)) = (new_key, new_value) { + map_iterator.state_info.update_value((k, v)); + true as cass_bool_t // (Key, Value) on new position is deserialized + } else { + false as cass_bool_t // New (key, value) is empty + } + } else { + false as cass_bool_t // Raw value is empty + } + } else { + false as cass_bool_t // Iterator reached the end } + } else { + false as cass_bool_t // Iterator position is unknown } - - false as cass_bool_t } CassIterator::CassUdtIterator(udt_iterator) => { - let new_pos: usize = udt_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - udt_iterator.position = Some(new_pos); - - if new_pos < udt_iterator.count { - let raw_value = udt_iterator.udt_iterator.next().unwrap(); - if let Ok((name_type, Some(frame_slice))) = raw_value { - let name = &name_type.0; - let field_type = &name_type.1; - let raw = RawValue { - spec: field_type, - slice: frame_slice, - }; - let value = decode_value(raw, field_type); - udt_iterator.field_value = value; - udt_iterator.field_name = Some(name.clone()); - - return true as cass_bool_t; + udt_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = udt_iterator.state_info + { + if position < udt_iterator.count { + let raw_value = udt_iterator.udt_iterator.next().unwrap(); + if let Ok((name_type, Some(frame_slice))) = raw_value { + let name = &name_type.0; + let field_type = &name_type.1; + let raw = RawValue { + spec: field_type, + slice: frame_slice, + }; + let new_value = decode_value(raw, field_type); + + if let Some(val) = new_value { + udt_iterator.state_info.update_value((name.clone(), val)); + true as cass_bool_t // Value on new position is deserialized + } else { + false as cass_bool_t // New value is empty + } + } else { + false as cass_bool_t // Raw value is empty + } + } else { + false as cass_bool_t // Iterator reached the end } + } else { + false as cass_bool_t // Iterator position is unknown } - - false as cass_bool_t } CassIterator::CassSchemaMetaIterator(schema_meta_iterator) => { let new_pos: usize = schema_meta_iterator @@ -494,16 +567,18 @@ pub unsafe extern "C" fn cass_iterator_get_value( match iter { CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { CassCollectionIterator::SequenceIterator(CassSequenceIterator { - value: Some(value), + state_info: CassIteratorStateInfo::Value { value, .. }, .. }) => value, CassCollectionIterator::SeqMapIterator(CassMapIterator { - key: Some(key), - value: Some(value), - position: Some(pos), + state_info: + CassIteratorStateInfo::Value { + value: (key, value), + position, + }, .. }) => { - if pos % 2 == 0 { + if position % 2 == 0 { key } else { value @@ -512,7 +587,8 @@ pub unsafe extern "C" fn cass_iterator_get_value( _ => std::ptr::null(), }, CassIterator::CassTupleIterator(CassTupleIterator { - value: Some(value), .. + state_info: CassIteratorStateInfo::Value { value, .. }, + .. }) => value, _ => std::ptr::null(), // null is returned if value in iterator is not set } @@ -525,19 +601,28 @@ pub unsafe extern "C" fn cass_iterator_get_map_key( let iter = ptr_to_ref(iterator); match iter { - CassIterator::CassMapIterator(map_iterator) => { - assert!(map_iterator - .position - .map(|pos| pos < map_iterator.count) - .is_some()); // assertion copied from c++ driver - map_iterator.key.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded - } - CassIterator::CassCollectionIterator(collection_iterator) => { - match collection_iterator { - CassCollectionIterator::SeqMapIterator(map_iter) => map_iter.key.as_ref().unwrap(), - CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator - } + CassIterator::CassMapIterator(CassMapIterator { + count, + state_info: + CassIteratorStateInfo::Value { + value: (key, _value), + position, + }, + .. + }) => { + assert!(*position < *count); // assertion copied from c++ driver + key } + CassIterator::CassCollectionIterator(CassCollectionIterator::SeqMapIterator( + CassMapIterator { + state_info: + CassIteratorStateInfo::Value { + value: (key, _value), + .. + }, + .. + }, + )) => key, _ => std::ptr::null(), } } @@ -549,21 +634,28 @@ pub unsafe extern "C" fn cass_iterator_get_map_value( let iter = ptr_to_ref(iterator); match iter { - CassIterator::CassMapIterator(map_iterator) => { - assert!(map_iterator - .position - .map(|pos| pos < map_iterator.count) - .is_some()); // assertion copied from c++ driver - map_iterator.value.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded - } - CassIterator::CassCollectionIterator(collection_iterator) => { - match collection_iterator { - CassCollectionIterator::SeqMapIterator(map_iter) => { - map_iter.value.as_ref().unwrap() - } - CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator - } + CassIterator::CassMapIterator(CassMapIterator { + count, + state_info: + CassIteratorStateInfo::Value { + value: (_key, value), + position, + }, + .. + }) => { + assert!(*position < *count); // assertion copied from c++ driver + value } + CassIterator::CassCollectionIterator(CassCollectionIterator::SeqMapIterator( + CassMapIterator { + state_info: + CassIteratorStateInfo::Value { + value: (_key, value), + .. + }, + .. + }, + )) => value, _ => std::ptr::null(), } } @@ -578,12 +670,15 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( match iter { CassIterator::CassUdtIterator(CassUdtIterator { - field_name: Some(field_name), count, - position, + state_info: + CassIteratorStateInfo::Value { + value: (field_name, _field_value), + position, + }, .. }) => { - assert!(position.map(|pos| pos < *count).is_some()); // assertion copied from c++ driver + assert!(*position < *count); // assertion copied from c++ driver write_str_to_c(field_name.as_str(), name, name_length); CassError::CASS_OK } @@ -599,12 +694,15 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_value( match iter { CassIterator::CassUdtIterator(CassUdtIterator { - field_value: Some(field_value), count, - position, + state_info: + CassIteratorStateInfo::Value { + value: (_field_name, field_value), + position, + }, .. }) => { - assert!(position.map(|pos| pos < *count).is_some()); // assertion copied from c++ driver + assert!(*position < *count); // assertion copied from c++ driver field_value } _ => std::ptr::null(), @@ -821,10 +919,8 @@ pub unsafe extern "C" fn cass_iterator_from_collection( if let Ok(map_iter) = map_iterator { let iterator = CassCollectionIterator::SeqMapIterator(CassMapIterator { map_iterator: map_iter, - key: None, - value: None, count: item_count * 2, - position: None, + state_info: CassIteratorStateInfo::NoValue, }); return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); @@ -836,9 +932,8 @@ pub unsafe extern "C" fn cass_iterator_from_collection( if let Ok(seq_iterator) = sequence_iterator { let iterator = CassCollectionIterator::SequenceIterator(CassSequenceIterator { sequence_iterator: seq_iterator, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }); return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); @@ -862,9 +957,8 @@ pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *m let sequence_iterator = SequenceIterator::new(column_type, item_count, frame_slice); let iterator = CassTupleIterator { sequence_iterator, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassTupleIterator(iterator))); @@ -884,10 +978,8 @@ pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut if let Ok(map_iter) = map_iterator { let iterator = CassMapIterator { map_iterator: map_iter, - key: None, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassMapIterator(iterator))); @@ -913,10 +1005,8 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type( let udt_iterator = UdtIterator::new(fields, frame_slice); let iterator = CassUdtIterator { udt_iterator, - field_name: None, - field_value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassUdtIterator(iterator)));