diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 341c6564..527c195b 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -97,6 +97,31 @@ impl CassIteratorStateInfo { } }; } + + /// Update iterator's state and return new state info. + /// Sets or replaces the old value with the `new_value` without changing the position. + fn update_value(&mut self, new_value: T) { + // Store a dummy NoValue temporarily as we cannot move state_info fields. + let old_state_info = std::mem::replace(self, CassIteratorStateInfo::NoValue); + *self = match old_state_info { + CassIteratorStateInfo::Value { position, .. } => CassIteratorStateInfo::Value { + value: new_value, + position, + }, + CassIteratorStateInfo::PositionNoValue { position } => CassIteratorStateInfo::Value { + value: new_value, + position, + }, + CassIteratorStateInfo::ValueNoPosition { .. } => CassIteratorStateInfo::Value { + value: new_value, + position: 0, + }, + CassIteratorStateInfo::NoValue => CassIteratorStateInfo::Value { + value: new_value, + position: 0, + }, + }; + } } pub struct CassResultIterator { @@ -116,32 +141,26 @@ pub enum CassCollectionIterator { pub struct CassSequenceIterator { sequence_iterator: SequenceIterator<'static, RawValue<'static>>, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo, } pub struct CassTupleIterator { sequence_iterator: SequenceIterator<'static, RawValue<'static>>, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo, } pub struct CassMapIterator { map_iterator: MapIterator<'static, RawValue<'static>, RawValue<'static>>, - key: Option, - value: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo<(CassValue, CassValue)>, // (key, value) } pub struct CassUdtIterator { udt_iterator: UdtIterator<'static>, - field_value: Option, - field_name: Option, count: usize, - position: Option, + state_info: CassIteratorStateInfo<(String, CassValue)>, // (field_name, field_value) } pub struct CassSchemaMetaIterator { @@ -282,63 +301,82 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass } CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { CassCollectionIterator::SequenceIterator(seq_iterator) => { - let new_pos: usize = seq_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - seq_iterator.position = Some(new_pos); - - if new_pos < seq_iterator.count { - let raw_value = seq_iterator.sequence_iterator.next().unwrap(); - if let Ok(raw) = raw_value { - let raw_value_type = raw.spec; - let value = decode_value(raw, raw_value_type); - seq_iterator.value = value; - - return true as cass_bool_t; + seq_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = seq_iterator.state_info + { + if position < seq_iterator.count { + let raw_value = seq_iterator.sequence_iterator.next().unwrap(); + if let Ok(raw) = raw_value { + let raw_value_type = raw.spec; + let new_value = decode_value(raw, raw_value_type); + + if let Some(val) = new_value { + seq_iterator.state_info.update_value(val); + return true as cass_bool_t; + } + } } } false as cass_bool_t } CassCollectionIterator::SeqMapIterator(seq_map_iterator) => { - let new_pos: usize = seq_map_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - seq_map_iterator.position = Some(new_pos); - - if new_pos < seq_map_iterator.count { - if new_pos % 2 == 0 { - let raw_value = seq_map_iterator.map_iterator.next().unwrap(); - if let Ok((raw_key, raw_value)) = raw_value { - let key_type = raw_key.spec; - let key = decode_value(raw_key, key_type); - let value_type = raw_value.spec; - let value = decode_value(raw_value, value_type); - seq_map_iterator.key = key; - seq_map_iterator.value = value; + seq_map_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = + seq_map_iterator.state_info + { + // Decoding in pair (key, value) on even positions + if position < seq_map_iterator.count { + if position % 2 == 0 { + let raw_value = seq_map_iterator.map_iterator.next().unwrap(); + return if let Ok((raw_key, raw_value)) = raw_value { + let key_type = raw_key.spec; + let new_key = decode_value(raw_key, key_type); + let value_type = raw_value.spec; + let new_value = decode_value(raw_value, value_type); + + if let (Some(k), Some(v)) = (new_key, new_value) { + seq_map_iterator.state_info.update_value((k, v)); + true as cass_bool_t + } else { + false as cass_bool_t + } + } else { + false as cass_bool_t + }; } + return true as cass_bool_t; } - - return true as cass_bool_t; } false as cass_bool_t } }, CassIterator::CassTupleIterator(tuple_iterator) => { - let new_pos: usize = tuple_iterator.position.map_or(0, |prev_pos| prev_pos + 1); + tuple_iterator.state_info.advance(); - tuple_iterator.position = Some(new_pos); - - if new_pos < tuple_iterator.count { - let raw_value = tuple_iterator.sequence_iterator.next().unwrap(); - if let Ok(raw) = raw_value { - let type_in_pos = match raw.spec { - ColumnType::Tuple(type_defs) => type_defs.get(new_pos), - _ => panic!("Cannot get tuple out of non-tuple column type"), - }; - if let Some(spec) = type_in_pos { - let value = decode_value(raw, spec); - tuple_iterator.value = value; - - return true as cass_bool_t; + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = tuple_iterator.state_info + { + if position < tuple_iterator.count { + let raw_value = tuple_iterator.sequence_iterator.next().unwrap(); + if let Ok(raw) = raw_value { + let type_in_pos = match raw.spec { + ColumnType::Tuple(type_defs) => type_defs.get(position), + _ => panic!("Cannot get tuple out of non-tuple column type"), + }; + if let Some(spec) = type_in_pos { + let new_value = decode_value(raw, spec); + + if let Some(val) = new_value { + tuple_iterator.state_info.update_value(val); + return true as cass_bool_t; + } + } } } } @@ -346,45 +384,51 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass false as cass_bool_t } CassIterator::CassMapIterator(map_iterator) => { - let new_pos: usize = map_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - map_iterator.position = Some(new_pos); + map_iterator.state_info.advance(); - if new_pos < map_iterator.count { - let raw_value = map_iterator.map_iterator.next().unwrap(); - if let Ok((raw_key, raw_value)) = raw_value { - let key_type = raw_key.spec; - let key = decode_value(raw_key, key_type); - let value_type = raw_value.spec; - let value = decode_value(raw_value, value_type); - map_iterator.key = key; - map_iterator.value = value; - - return true as cass_bool_t; + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = map_iterator.state_info + { + if position < map_iterator.count { + let raw_value = map_iterator.map_iterator.next().unwrap(); + if let Ok((raw_key, raw_value)) = raw_value { + let key_type = raw_key.spec; + let new_key = decode_value(raw_key, key_type); + let value_type = raw_value.spec; + let new_value = decode_value(raw_value, value_type); + + if let (Some(k), Some(v)) = (new_key, new_value) { + map_iterator.state_info.update_value((k, v)); + return true as cass_bool_t; + } + } } } false as cass_bool_t } CassIterator::CassUdtIterator(udt_iterator) => { - let new_pos: usize = udt_iterator.position.map_or(0, |prev_pos| prev_pos + 1); - - udt_iterator.position = Some(new_pos); - - if new_pos < udt_iterator.count { - let raw_value = udt_iterator.udt_iterator.next().unwrap(); - if let Ok((name_type, Some(frame_slice))) = raw_value { - let name = &name_type.0; - let field_type = &name_type.1; - let raw = RawValue { - spec: field_type, - slice: frame_slice, - }; - let value = decode_value(raw, field_type); - udt_iterator.field_value = value; - udt_iterator.field_name = Some(name.clone()); - - return true as cass_bool_t; + udt_iterator.state_info.advance(); + + if let CassIteratorStateInfo::Value { position, .. } + | CassIteratorStateInfo::PositionNoValue { position } = udt_iterator.state_info + { + if position < udt_iterator.count { + let raw_value = udt_iterator.udt_iterator.next().unwrap(); + if let Ok((name_type, Some(frame_slice))) = raw_value { + let name = &name_type.0; + let field_type = &name_type.1; + let raw = RawValue { + spec: field_type, + slice: frame_slice, + }; + let new_value = decode_value(raw, field_type); + + if let Some(val) = new_value { + udt_iterator.state_info.update_value((name.clone(), val)); + return true as cass_bool_t; + } + } } } @@ -494,16 +538,18 @@ pub unsafe extern "C" fn cass_iterator_get_value( match iter { CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator { CassCollectionIterator::SequenceIterator(CassSequenceIterator { - value: Some(value), + state_info: CassIteratorStateInfo::Value { value, .. }, .. }) => value, CassCollectionIterator::SeqMapIterator(CassMapIterator { - key: Some(key), - value: Some(value), - position: Some(pos), + state_info: + CassIteratorStateInfo::Value { + value: (key, value), + position, + }, .. }) => { - if pos % 2 == 0 { + if position % 2 == 0 { key } else { value @@ -512,7 +558,8 @@ pub unsafe extern "C" fn cass_iterator_get_value( _ => std::ptr::null(), }, CassIterator::CassTupleIterator(CassTupleIterator { - value: Some(value), .. + state_info: CassIteratorStateInfo::Value { value, .. }, + .. }) => value, _ => std::ptr::null(), // null is returned if value in iterator is not set } @@ -525,19 +572,28 @@ pub unsafe extern "C" fn cass_iterator_get_map_key( let iter = ptr_to_ref(iterator); match iter { - CassIterator::CassMapIterator(map_iterator) => { - assert!(map_iterator - .position - .map(|pos| pos < map_iterator.count) - .is_some()); // assertion copied from c++ driver - map_iterator.key.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded - } - CassIterator::CassCollectionIterator(collection_iterator) => { - match collection_iterator { - CassCollectionIterator::SeqMapIterator(map_iter) => map_iter.key.as_ref().unwrap(), - CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator - } + CassIterator::CassMapIterator(CassMapIterator { + count, + state_info: + CassIteratorStateInfo::Value { + value: (key, _value), + position, + }, + .. + }) => { + assert!(*position < *count); // assertion copied from c++ driver + key } + CassIterator::CassCollectionIterator(CassCollectionIterator::SeqMapIterator( + CassMapIterator { + state_info: + CassIteratorStateInfo::Value { + value: (key, _value), + .. + }, + .. + }, + )) => key, _ => std::ptr::null(), } } @@ -549,21 +605,28 @@ pub unsafe extern "C" fn cass_iterator_get_map_value( let iter = ptr_to_ref(iterator); match iter { - CassIterator::CassMapIterator(map_iterator) => { - assert!(map_iterator - .position - .map(|pos| pos < map_iterator.count) - .is_some()); // assertion copied from c++ driver - map_iterator.value.as_ref().unwrap() // safe to unwrap if cass_iterator_next succeeded - } - CassIterator::CassCollectionIterator(collection_iterator) => { - match collection_iterator { - CassCollectionIterator::SeqMapIterator(map_iter) => { - map_iter.value.as_ref().unwrap() - } - CassCollectionIterator::SequenceIterator(_) => std::ptr::null(), // Cannot get map key from sequence iterator - } + CassIterator::CassMapIterator(CassMapIterator { + count, + state_info: + CassIteratorStateInfo::Value { + value: (_key, value), + position, + }, + .. + }) => { + assert!(*position < *count); // assertion copied from c++ driver + value } + CassIterator::CassCollectionIterator(CassCollectionIterator::SeqMapIterator( + CassMapIterator { + state_info: + CassIteratorStateInfo::Value { + value: (_key, value), + .. + }, + .. + }, + )) => value, _ => std::ptr::null(), } } @@ -578,12 +641,15 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( match iter { CassIterator::CassUdtIterator(CassUdtIterator { - field_name: Some(field_name), count, - position, + state_info: + CassIteratorStateInfo::Value { + value: (field_name, _field_value), + position, + }, .. }) => { - assert!(position.map(|pos| pos < *count).is_some()); // assertion copied from c++ driver + assert!(*position < *count); // assertion copied from c++ driver write_str_to_c(field_name.as_str(), name, name_length); CassError::CASS_OK } @@ -599,12 +665,15 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_value( match iter { CassIterator::CassUdtIterator(CassUdtIterator { - field_value: Some(field_value), count, - position, + state_info: + CassIteratorStateInfo::Value { + value: (_field_name, field_value), + position, + }, .. }) => { - assert!(position.map(|pos| pos < *count).is_some()); // assertion copied from c++ driver + assert!(*position < *count); // assertion copied from c++ driver field_value } _ => std::ptr::null(), @@ -821,10 +890,8 @@ pub unsafe extern "C" fn cass_iterator_from_collection( if let Ok(map_iter) = map_iterator { let iterator = CassCollectionIterator::SeqMapIterator(CassMapIterator { map_iterator: map_iter, - key: None, - value: None, count: item_count * 2, - position: None, + state_info: CassIteratorStateInfo::NoValue, }); return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); @@ -836,9 +903,8 @@ pub unsafe extern "C" fn cass_iterator_from_collection( if let Ok(seq_iterator) = sequence_iterator { let iterator = CassCollectionIterator::SequenceIterator(CassSequenceIterator { sequence_iterator: seq_iterator, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }); return Box::into_raw(Box::new(CassIterator::CassCollectionIterator(iterator))); @@ -862,9 +928,8 @@ pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *m let sequence_iterator = SequenceIterator::new(column_type, item_count, frame_slice); let iterator = CassTupleIterator { sequence_iterator, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassTupleIterator(iterator))); @@ -884,10 +949,8 @@ pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut if let Ok(map_iter) = map_iterator { let iterator = CassMapIterator { map_iterator: map_iter, - key: None, - value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassMapIterator(iterator))); @@ -913,10 +976,8 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type( let udt_iterator = UdtIterator::new(fields, frame_slice); let iterator = CassUdtIterator { udt_iterator, - field_name: None, - field_value: None, count: item_count, - position: None, + state_info: CassIteratorStateInfo::NoValue, }; return Box::into_raw(Box::new(CassIterator::CassUdtIterator(iterator)));