From a80897e2448935b5bc17a1c42470d4699880342a Mon Sep 17 00:00:00 2001 From: GFX9 Date: Mon, 18 Mar 2024 11:23:59 +0800 Subject: [PATCH] fix: prevent batch_index overflow in raw_curp Signed-off-by: GFX9 --- crates/curp/src/server/raw_curp/log.rs | 189 +++++++++++++++++++------ 1 file changed, 145 insertions(+), 44 deletions(-) diff --git a/crates/curp/src/server/raw_curp/log.rs b/crates/curp/src/server/raw_curp/log.rs index 4aee089d3..53c272cab 100644 --- a/crates/curp/src/server/raw_curp/log.rs +++ b/crates/curp/src/server/raw_curp/log.rs @@ -12,7 +12,7 @@ use bincode::serialized_size; use clippy_utilities::{NumericCast, OverflowArithmetic}; use itertools::Itertools; use tokio::sync::mpsc; -use tracing::error; +use tracing::{error, warn}; use crate::{ cmd::Command, @@ -82,9 +82,15 @@ impl FallbackContext { struct LogEntryVecDeque { /// A VecDeque to store log entries, it will be serialized and persisted entries: VecDeque>>, - /// The sum of serialized size of previous log entries - /// batch_index[i+1] = batch_index[i] + size(entries[i]) - batch_index: VecDeque, + /// entry size of each item in entries + entry_size: VecDeque, + /// the right index of the batch (offset) + /// batch_range: [i, i + batch_index[i]] + batch_index: VecDeque, + /// the first entry idx of the current batch window + first_entry_at_last_batch: usize, + /// the current batch window size + last_batch_size: u64, /// Batch size limit batch_limit: u64, } @@ -92,11 +98,12 @@ struct LogEntryVecDeque { impl LogEntryVecDeque { /// return a log entries with cap fn new(cap: usize, batch_limit: u64) -> Self { - let mut batch_index = VecDeque::with_capacity(cap.overflow_add(1)); - batch_index.push_back(0); Self { entries: VecDeque::with_capacity(cap), - batch_index, + entry_size: VecDeque::with_capacity(cap), + batch_index: VecDeque::with_capacity(cap), + first_entry_at_last_batch: 0, + last_batch_size: 0, batch_limit, } } @@ -111,21 +118,71 @@ impl LogEntryVecDeque { /// push a log entry into the back of queue fn push_back(&mut self, entry: Arc>) -> Result<(), bincode::Error> { + #![allow(clippy::indexing_slicing)] let entry_size = serialized_size(&entry)?; + if entry_size > self.batch_limit { + warn!("entry_size of an entry > batch_limit, which may be too small.",); + } + self.entries.push_back(entry); - let Some(&pre_entries_size) = self.batch_index.back() else { - unreachable!("batch_index cannot be None") - }; - self.batch_index - .push_back(pre_entries_size.overflow_add(entry_size)); + self.entry_size.push_back(entry_size); + self.batch_index.push_back(0); // placeholder + + if entry_size > self.batch_limit { + let entry_idx = self.batch_index.len() - 1; + for prev_idx in self.first_entry_at_last_batch..entry_idx { + self.batch_index[prev_idx] = entry_idx - prev_idx; // record offset but not absolute index + } + self.batch_index[entry_idx] = 1; + self.last_batch_size = 0; + self.first_entry_at_last_batch = entry_idx + 1; + return Ok(()); + } + + while self.last_batch_size + entry_size > self.batch_limit + && self.first_entry_at_last_batch < self.entries.len() + { + self.batch_index[self.first_entry_at_last_batch] = + self.entries.len() - 1 - self.first_entry_at_last_batch; // record offset but not absolute index + self.last_batch_size -= self.entry_size[self.first_entry_at_last_batch]; + self.first_entry_at_last_batch += 1; + } + + self.last_batch_size += entry_size; + + if self.first_entry_at_last_batch >= self.entries.len() { + self.batch_index[self.entries.len() - 1] = 1; + } + + if self.last_batch_size == self.batch_limit { + self.batch_index[self.first_entry_at_last_batch] = + self.entries.len() - self.first_entry_at_last_batch; // record offset but not absolute index + } + Ok(()) } /// pop a log entry from the front of queue fn pop_front(&mut self) -> Option>> { + #![allow(clippy::indexing_slicing)] if self.entries.front().is_some() { - _ = self.batch_index.pop_front(); + let front_size = self.entry_size[0]; + + if self.first_entry_at_last_batch == 0 { + self.last_batch_size -= front_size; + } else { + self.first_entry_at_last_batch -= 1; + } + + let _ = self + .batch_index + .pop_front() + .unwrap_or_else(|| unreachable!()); + let _ = self + .entry_size + .pop_front() + .unwrap_or_else(|| unreachable!()); self.entries.pop_front() } else { None @@ -134,36 +191,34 @@ impl LogEntryVecDeque { /// restore log entries from Vec fn restore(&mut self, entries: Vec>) { - let mut batch_index = VecDeque::with_capacity(entries.capacity()); - batch_index.push_back(0); - for entry in &entries { - #[allow(clippy::expect_used)] - let entry_size = - serialized_size(entry).expect("log entry {entry:?} cannot be serialized"); - if let Some(cur_size) = batch_index.back() { - batch_index.push_back(cur_size.overflow_add(entry_size)); - } - } + self.batch_index = VecDeque::with_capacity(entries.capacity()); + self.entries = VecDeque::with_capacity(entries.capacity()); + self.entry_size = VecDeque::with_capacity(entries.capacity()); - self.entries = entries.into_iter().map(Arc::new).collect(); - self.batch_index = batch_index; + self.last_batch_size = 0; + self.first_entry_at_last_batch = 0; + + for entry in entries { + let _unuse = self.push_back(Arc::from(entry)); + } } /// clear whole log entries fn clear(&mut self) { self.entries.clear(); + self.entry_size.clear(); self.batch_index.clear(); - self.batch_index.push_back(0); + self.last_batch_size = 0; + self.first_entry_at_last_batch = 0; } /// Get the range [left, right) of the log entry, whose size should be equal or smaller than `batch_limit` fn get_range_by_batch(&self, left: usize) -> Range { - #[allow(clippy::indexing_slicing)] - let target = self.batch_index[left].overflow_add(self.batch_limit); - // remove the fake index 0 in `batch_index` - match self.batch_index.binary_search(&target) { - Ok(right) => left..right, - Err(right) => left..right - 1, + #![allow(clippy::indexing_slicing)] + if self.batch_index[left] == 0 { + left..self.entries.len() + } else { + left..left + self.batch_index[left] } } @@ -175,15 +230,56 @@ impl LogEntryVecDeque { /// check whether the log entry range [li,..) exceeds the batch limit or not fn has_next_batch(&self, left: usize) -> bool { - if let (Some(&cur_size), Some(&last_size)) = - (self.batch_index.get(left), self.batch_index.back()) - { - let target_size = cur_size.overflow_add(self.batch_limit); - target_size <= last_size + if let Some(&offset) = self.batch_index.get(left) { + offset != 0 } else { false } } + + #[allow(unused)] + /// set batch limit and reconstruct `batch_index` + fn set_batch_limit(&mut self, batch_limit: u64) { + #![allow(clippy::indexing_slicing)] + self.batch_limit = batch_limit; + self.last_batch_size = 0; + self.first_entry_at_last_batch = 0; + self.batch_index.iter_mut().for_each(|val| *val = 0); + + for entry_idx in 0..self.entries.len() { + let entry_size = self.entry_size[entry_idx]; + + if entry_size > self.batch_limit { + for prev_idx in self.first_entry_at_last_batch..entry_idx { + self.batch_index[prev_idx] = entry_idx - prev_idx; // record offset but not absolute index + } + self.batch_index[entry_idx] = 1; + self.last_batch_size = 0; + self.first_entry_at_last_batch = entry_idx + 1; + continue; + } + + while self.last_batch_size + entry_size > self.batch_limit + && self.first_entry_at_last_batch < self.entries.len() + { + self.batch_index[self.first_entry_at_last_batch] = + entry_idx - self.first_entry_at_last_batch; // record offset but not absolute index + self.last_batch_size -= self.entry_size[self.first_entry_at_last_batch]; + self.first_entry_at_last_batch += 1; + } + + self.last_batch_size += entry_size; + + if self.first_entry_at_last_batch >= self.entries.len() { + self.batch_index[entry_idx] = 1; + } + + if entry_idx == self.entries.len() - 1 && self.last_batch_size == self.batch_limit { + self.batch_index[self.first_entry_at_last_batch] = + self.entries.len() - self.first_entry_at_last_batch; // record offset but not absolute index + } + } + } } impl std::ops::Deref for LogEntryVecDeque { @@ -448,6 +544,12 @@ impl Log { false }); } + + #[allow(unused)] + /// set batch limit and reconstruct `batch_index` + pub(super) fn set_batch_limit(&mut self, batch_limit: u64) { + self.entries.set_batch_limit(batch_limit); + } } #[cfg(test)] @@ -470,7 +572,7 @@ mod tests { } fn set_batch_limit(log: &mut Log, batch_limit: u64) { - log.entries.batch_limit = batch_limit; + log.set_batch_limit(batch_limit); } #[test] @@ -575,7 +677,7 @@ mod tests { .enumerate() .map(|(idx, cmd)| log.push(1, ProposeId(0, idx.numeric_cast()), cmd).unwrap()) .collect::>(); - let log_entry_size = log.entries.batch_index[1]; + let log_entry_size = log.entries.entry_size[0]; set_batch_limit(&mut log, 3 * log_entry_size - 1); let bound_1 = log.entries.get_range_by_batch(3); @@ -633,7 +735,7 @@ mod tests { let bound_5 = log.entries.get_range_by_batch(3); assert_eq!( bound_5, - 3..3, + 3..4, "batch_index = {:?}, batch = {}, log_entry_size = {}", log.entries.batch_index, log.entries.batch_limit, @@ -664,8 +766,7 @@ mod tests { log.restore_entries(entries); assert_eq!(log.entries.len(), 10); - assert_eq!(log.entries.batch_index.len(), 11); - assert_eq!(log.entries.batch_index[0], 0); + assert_eq!(log.entries.batch_index.len(), 10); let entry_size = log.entries.batch_index[1]; log.entries @@ -675,7 +776,7 @@ mod tests { .for_each(|(idx, &size)| { assert_eq!( size, - entry_size * idx.numeric_cast::(), + entry_size * idx, "batch_index = {:?}, batch = {}, entry_size = {}", log.entries.batch_index, log.entries.batch_limit, @@ -698,6 +799,6 @@ mod tests { log.compact(); assert_eq!(log.base_index, 12); assert_eq!(log.entries.front().unwrap().index, 13); - assert_eq!(log.entries.batch_index.len(), 19); + assert_eq!(log.entries.batch_index.len(), 18); } }