Skip to content

Commit

Permalink
feat: Add retain_with_break to HashSet/Table/Map
Browse files Browse the repository at this point in the history
With the removal of the raw table, it is hard to implement an efficient
loop to conditionally remove/keep certain fields up to a limit. i.e. a
loop that can be aborted and does not require rehash the key for removal
of the entry.
  • Loading branch information
tugtugtug committed Nov 14, 2024
1 parent b74e3a7 commit 98d4756
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 0 deletions.
71 changes: 71 additions & 0 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,53 @@ impl<K, V, S, A: Allocator> HashMap<K, V, S, A> {
}
}

/// Retains only the elements specified by the predicate and breaks the iteration when
/// the predicate fails. Keeps the allocated memory for reuse.
///
/// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Ok(false)` until
/// `f(&k, &mut v)` returns `Err(())`
/// The elements are visited in unsorted (and unspecified) order.
///
/// # Examples
///
/// ```
/// use hashbrown::HashMap;
///
/// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect();
/// assert_eq!(map.len(), 8);
/// let mut removed = 0;
/// map.retain_with_break(|&k, _| if removed < 3 {
/// if k % 2 == 0 {
/// Ok(true)
/// } else {
/// removed += 1;
/// Ok(false)
/// }
/// } else {
/// Err(())
/// });
///
/// // We can see, that the number of elements inside map is changed and the
/// // length matches when we have aborted the `Err(())`
/// assert_eq!(map.len(), 5);
/// ```
pub fn retain_with_break<F>(&mut self, mut f: F)
where
F: FnMut(&K, &mut V) -> core::result::Result<bool, ()>,
{
// Here we only use `iter` as a temporary, preventing use-after-free
unsafe {
for item in self.table.iter() {
let &mut (ref key, ref mut value) = item.as_mut();
match f(key, value) {
Ok(false) => self.table.erase(item),
Err(_) => break,
_ => continue,
}
}
}
}

/// Drains elements which are true under the given predicate,
/// and returns an iterator over the removed items.
///
Expand Down Expand Up @@ -5909,6 +5956,30 @@ mod test_map {
assert_eq!(map[&6], 60);
}

#[test]
fn test_retain_with_break() {
let mut map: HashMap<i32, i32> = (0..100).map(|x| (x, x * 10)).collect();
// looping and removing any key > 50, but stop after 40 iterations
let mut removed = 0;
map.retain_with_break(|&k, _| {
if removed < 40 {
if k > 50 {
removed += 1;
Ok(false)
} else {
Ok(true)
}
} else {
Err(())
}
});
assert_eq!(map.len(), 60);
// check nothing up to 50 is removed
for k in 0..=50 {
assert_eq!(map[&k], k * 10);
}
}

#[test]
fn test_extract_if() {
{
Expand Down
54 changes: 54 additions & 0 deletions src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,36 @@ impl<T, S, A: Allocator> HashSet<T, S, A> {
self.map.retain(|k, _| f(k));
}

/// Retains only the elements specified by the predicate until the predicate fails.
///
/// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)`.
///
/// # Examples
///
/// ```
/// use hashbrown::HashSet;
///
/// let xs = [1,2,3,4,5,6];
/// let mut set: HashSet<i32> = xs.into_iter().collect();
/// let mut count = 0;
/// set.retain_with_break(|&k| if count < 2 {
/// if k % 2 == 0 {
/// Ok(true)
/// } else {
/// Ok(false)
/// }
/// } else {
/// Err(())
/// });
/// assert_eq!(set.len(), 3);
/// ```
pub fn retain_with_break<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> core::result::Result<bool, ()>,
{
self.map.retain_with_break(|k, _| f(k));
}

/// Drains elements which are true under the given predicate,
/// and returns an iterator over the removed items.
///
Expand Down Expand Up @@ -2980,6 +3010,30 @@ mod test_set {
assert!(set.contains(&6));
}

#[test]
fn test_retain_with_break() {
let mut set: HashSet<i32> = (0..100).collect();
// looping and removing any key > 50, but stop after 40 iterations
let mut removed = 0;
set.retain_with_break(|&k| {
if removed < 40 {
if k > 50 {
removed += 1;
Ok(false)
} else {
Ok(true)
}
} else {
Err(())
}
});
assert_eq!(set.len(), 60);
// check nothing up to 50 is removed
for k in 0..=50 {
assert!(set.contains(&k));
}
}

#[test]
fn test_extract_if() {
{
Expand Down
85 changes: 85 additions & 0 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,59 @@ where
}
}

/// Retains only the elements specified by the predicate until the predicate fails.
///
/// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
/// `f(&e)` returns `Err(())`
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "nightly")]
/// # fn test() {
/// use hashbrown::{HashTable, DefaultHashBuilder};
/// use std::hash::BuildHasher;
///
/// let mut table = HashTable::new();
/// let hasher = DefaultHashBuilder::default();
/// let hasher = |val: &_| hasher.hash_one(val);
/// let mut removed = 0;
/// for x in 1..=8 {
/// table.insert_unique(hasher(&x), x, hasher);
/// }
/// table.retain_with_break(|&mut v| if removed < 3 {
/// if v % 2 == 0 {
/// Ok(true)
/// } else {
/// removed += 1;
/// Ok(false)
/// }
/// } else {
/// Err(())
/// });
/// assert_eq!(table.len(), 5);
/// # }
/// # fn main() {
/// # #[cfg(feature = "nightly")]
/// # test()
/// # }
/// ```
pub fn retain_with_break(
&mut self,
mut f: impl FnMut(&mut T) -> core::result::Result<bool, ()>,
) {
// Here we only use `iter` as a temporary, preventing use-after-free
unsafe {
for item in self.raw.iter() {
match f(item.as_mut()) {
Ok(false) => self.raw.erase(item),
Err(_) => break,
_ => continue,
}
}
}
}

/// Clears the set, returning all elements in an iterator.
///
/// # Examples
Expand Down Expand Up @@ -2372,12 +2425,44 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut

#[cfg(test)]
mod tests {
use crate::DefaultHashBuilder;

use super::HashTable;

use core::hash::BuildHasher;
#[test]
fn test_allocation_info() {
assert_eq!(HashTable::<()>::new().allocation_size(), 0);
assert_eq!(HashTable::<u32>::new().allocation_size(), 0);
assert!(HashTable::<u32>::with_capacity(1).allocation_size() > core::mem::size_of::<u32>());
}

#[test]
fn test_retain_with_break() {
let mut table = HashTable::new();
let hasher = DefaultHashBuilder::default();
let hasher = |val: &_| hasher.hash_one(val);
for x in 0..100 {
table.insert_unique(hasher(&x), x, hasher);
}
// looping and removing any value > 50, but stop after 40 iterations
let mut removed = 0;
table.retain_with_break(|&mut v| {
if removed < 40 {
if v > 50 {
removed += 1;
Ok(false)
} else {
Ok(true)
}
} else {
Err(())
}
});
assert_eq!(table.len(), 60);
// check nothing up to 50 is removed
for v in 0..=50 {
assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v));
}
}
}

0 comments on commit 98d4756

Please sign in to comment.