diff --git a/src/map.rs b/src/map.rs index c373d5958..0101953d8 100644 --- a/src/map.rs +++ b/src/map.rs @@ -929,6 +929,53 @@ impl HashMap { } } + /// Retains only the elements specified by the predicate and breaks the iteration when + /// the predicate fails. Keeps the allocated memory for reuse. + /// + /// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Ok(false)` until + /// `f(&k, &mut v)` returns `Err(())` + /// The elements are visited in unsorted (and unspecified) order. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashMap; + /// + /// let mut map: HashMap = (0..8).map(|x|(x, x*10)).collect(); + /// assert_eq!(map.len(), 8); + /// let mut removed = 0; + /// map.retain_with_break(|&k, _| if removed < 3 { + /// if k % 2 == 0 { + /// Ok(true) + /// } else { + /// removed += 1; + /// Ok(false) + /// } + /// } else { + /// Err(()) + /// }); + /// + /// // We can see, that the number of elements inside map is changed and the + /// // length matches when we have aborted the `Err(())` + /// assert_eq!(map.len(), 5); + /// ``` + pub fn retain_with_break(&mut self, mut f: F) + where + F: FnMut(&K, &mut V) -> core::result::Result, + { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.table.iter() { + let &mut (ref key, ref mut value) = item.as_mut(); + match f(key, value) { + Ok(false) => self.table.erase(item), + Err(_) => break, + _ => continue, + } + } + } + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -5909,6 +5956,30 @@ mod test_map { assert_eq!(map[&6], 60); } + #[test] + fn test_retain_with_break() { + let mut map: HashMap = (0..100).map(|x| (x, x * 10)).collect(); + // looping and removing any key > 50, but stop after 40 iterations + let mut removed = 0; + map.retain_with_break(|&k, _| { + if removed < 40 { + if k > 50 { + removed += 1; + Ok(false) + } else { + Ok(true) + } + } else { + Err(()) + } + }); + assert_eq!(map.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert_eq!(map[&k], k * 10); + } + } + #[test] fn test_extract_if() { { diff --git a/src/set.rs b/src/set.rs index d57390f67..f498a64da 100644 --- a/src/set.rs +++ b/src/set.rs @@ -372,6 +372,36 @@ impl HashSet { self.map.retain(|k, _| f(k)); } + /// Retains only the elements specified by the predicate until the predicate fails. + /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)`. + /// + /// # Examples + /// + /// ``` + /// use hashbrown::HashSet; + /// + /// let xs = [1,2,3,4,5,6]; + /// let mut set: HashSet = xs.into_iter().collect(); + /// let mut count = 0; + /// set.retain_with_break(|&k| if count < 2 { + /// if k % 2 == 0 { + /// Ok(true) + /// } else { + /// Ok(false) + /// } + /// } else { + /// Err(()) + /// }); + /// assert_eq!(set.len(), 3); + /// ``` + pub fn retain_with_break(&mut self, mut f: F) + where + F: FnMut(&T) -> core::result::Result, + { + self.map.retain_with_break(|k, _| f(k)); + } + /// Drains elements which are true under the given predicate, /// and returns an iterator over the removed items. /// @@ -2980,6 +3010,30 @@ mod test_set { assert!(set.contains(&6)); } + #[test] + fn test_retain_with_break() { + let mut set: HashSet = (0..100).collect(); + // looping and removing any key > 50, but stop after 40 iterations + let mut removed = 0; + set.retain_with_break(|&k| { + if removed < 40 { + if k > 50 { + removed += 1; + Ok(false) + } else { + Ok(true) + } + } else { + Err(()) + } + }); + assert_eq!(set.len(), 60); + // check nothing up to 50 is removed + for k in 0..=50 { + assert!(set.contains(&k)); + } + } + #[test] fn test_extract_if() { { diff --git a/src/table.rs b/src/table.rs index 7f665b75a..83cf845a0 100644 --- a/src/table.rs +++ b/src/table.rs @@ -870,6 +870,59 @@ where } } + /// Retains only the elements specified by the predicate until the predicate fails. + /// + /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until + /// `f(&e)` returns `Err(())` + /// + /// # Examples + /// + /// ``` + /// # #[cfg(feature = "nightly")] + /// # fn test() { + /// use hashbrown::{HashTable, DefaultHashBuilder}; + /// use std::hash::BuildHasher; + /// + /// let mut table = HashTable::new(); + /// let hasher = DefaultHashBuilder::default(); + /// let hasher = |val: &_| hasher.hash_one(val); + /// let mut removed = 0; + /// for x in 1..=8 { + /// table.insert_unique(hasher(&x), x, hasher); + /// } + /// table.retain_with_break(|&mut v| if removed < 3 { + /// if v % 2 == 0 { + /// Ok(true) + /// } else { + /// removed += 1; + /// Ok(false) + /// } + /// } else { + /// Err(()) + /// }); + /// assert_eq!(table.len(), 5); + /// # } + /// # fn main() { + /// # #[cfg(feature = "nightly")] + /// # test() + /// # } + /// ``` + pub fn retain_with_break( + &mut self, + mut f: impl FnMut(&mut T) -> core::result::Result, + ) { + // Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in self.raw.iter() { + match f(item.as_mut()) { + Ok(false) => self.raw.erase(item), + Err(_) => break, + _ => continue, + } + } + } + } + /// Clears the set, returning all elements in an iterator. /// /// # Examples @@ -2372,12 +2425,44 @@ impl FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut #[cfg(test)] mod tests { + use crate::DefaultHashBuilder; + use super::HashTable; + use core::hash::BuildHasher; #[test] fn test_allocation_info() { assert_eq!(HashTable::<()>::new().allocation_size(), 0); assert_eq!(HashTable::::new().allocation_size(), 0); assert!(HashTable::::with_capacity(1).allocation_size() > core::mem::size_of::()); } + + #[test] + fn test_retain_with_break() { + let mut table = HashTable::new(); + let hasher = DefaultHashBuilder::default(); + let hasher = |val: &_| hasher.hash_one(val); + for x in 0..100 { + table.insert_unique(hasher(&x), x, hasher); + } + // looping and removing any value > 50, but stop after 40 iterations + let mut removed = 0; + table.retain_with_break(|&mut v| { + if removed < 40 { + if v > 50 { + removed += 1; + Ok(false) + } else { + Ok(true) + } + } else { + Err(()) + } + }); + assert_eq!(table.len(), 60); + // check nothing up to 50 is removed + for v in 0..=50 { + assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v)); + } + } }