diff --git a/src/map.rs b/src/map.rs index 428fac40..582ae6ad 100644 --- a/src/map.rs +++ b/src/map.rs @@ -15,7 +15,8 @@ mod tests; pub use self::core::raw_entry_v1::{self, RawEntryApiV1}; pub use self::core::{Entry, IndexedEntry, OccupiedEntry, VacantEntry}; pub use self::iter::{ - Drain, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Splice, Values, ValuesMut, + Drain, ExtractIf, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Splice, Values, + ValuesMut, }; pub use self::slice::Slice; pub use crate::mutable_keys::MutableKeys; @@ -33,7 +34,7 @@ use alloc::vec::Vec; #[cfg(feature = "std")] use std::collections::hash_map::RandomState; -use self::core::IndexMapCore; +pub(crate) use self::core::{ExtractCore, IndexMapCore}; use crate::util::{third, try_simplify_range}; use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError}; @@ -301,6 +302,44 @@ impl IndexMap { Drain::new(self.core.drain(range)) } + /// Creates an iterator which uses a closure to determine if an element should be removed. + /// + /// If the closure returns true, the element is removed from the map and yielded. + /// If the closure returns false, or panics, the element remains in the map and will not be + /// yielded. + /// + /// Note that `extract_if` lets you mutate every value in the filter closure, regardless of + /// whether you choose to keep or remove it. + /// + /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating + /// or the iteration short-circuits, then the remaining elements will be retained. + /// Use [`retain`] with a negated predicate if you do not need the returned iterator. + /// + /// [`retain`]: IndexMap::retain + /// + /// # Examples + /// + /// Splitting a map into even and odd keys, reusing the original map: + /// + /// ``` + /// use indexmap::IndexMap; + /// + /// let mut map: IndexMap = (0..8).map(|x| (x, x)).collect(); + /// let extracted: IndexMap = map.extract_if(|k, _v| k % 2 == 0).collect(); + /// + /// let evens = extracted.keys().copied().collect::>(); + /// let odds = map.keys().copied().collect::>(); + /// + /// assert_eq!(evens, vec![0, 2, 4, 6]); + /// assert_eq!(odds, vec![1, 3, 5, 7]); + /// ``` + pub fn extract_if(&mut self, pred: F) -> ExtractIf<'_, K, V, F> + where + F: FnMut(&K, &mut V) -> bool, + { + ExtractIf::new(&mut self.core, pred) + } + /// Splits the collection into two at the given index. /// /// Returns a newly allocated map containing the elements in the range diff --git a/src/map/core.rs b/src/map/core.rs index 2dca04a5..12e1c509 100644 --- a/src/map/core.rs +++ b/src/map/core.rs @@ -24,6 +24,7 @@ use crate::util::simplify_range; use crate::{Bucket, Entries, Equivalent, HashValue}; pub use entry::{Entry, IndexedEntry, OccupiedEntry, VacantEntry}; +pub(crate) use raw::ExtractCore; /// Core of the map that does not depend on S pub(crate) struct IndexMapCore { @@ -145,6 +146,7 @@ impl IndexMapCore { #[inline] pub(crate) fn len(&self) -> usize { + debug_assert_eq!(self.entries.len(), self.indices.len()); self.indices.len() } diff --git a/src/map/core/raw.rs b/src/map/core/raw.rs index 233e41e7..a862993b 100644 --- a/src/map/core/raw.rs +++ b/src/map/core/raw.rs @@ -96,6 +96,20 @@ impl IndexMapCore { // only the item references that are appropriately bound to `&mut self`. unsafe { self.indices.iter().map(|bucket| bucket.as_mut()) } } + + pub(crate) fn extract(&mut self) -> ExtractCore<'_, K, V> { + // SAFETY: We must have consistent lengths to start, so that's a hard assertion. + // Then the worst `set_len(0)` can do is leak items if `ExtractCore` doesn't drop. + assert_eq!(self.entries.len(), self.indices.len()); + unsafe { + self.entries.set_len(0); + } + ExtractCore { + map: self, + current: 0, + new_len: 0, + } + } } /// A view into an occupied raw entry in an `IndexMap`. @@ -143,3 +157,80 @@ impl<'a, K, V> RawTableEntry<'a, K, V> { (self.map, index) } } + +pub(crate) struct ExtractCore<'a, K, V> { + map: &'a mut IndexMapCore, + current: usize, + new_len: usize, +} + +impl Drop for ExtractCore<'_, K, V> { + fn drop(&mut self) { + let old_len = self.map.indices.len(); + let mut new_len = self.new_len; + debug_assert!(new_len <= self.current); + debug_assert!(self.current <= old_len); + debug_assert!(old_len <= self.map.entries.capacity()); + + // SAFETY: We assume `new_len` and `current` were correctly maintained by the iterator. + // So `entries[new_len..current]` were extracted, but the rest before and after are valid. + unsafe { + if new_len == self.current { + // Nothing was extracted, so any remaining items can be left in place. + new_len = old_len; + } else if self.current < old_len { + // Need to shift the remaining items down. + let tail_len = old_len - self.current; + let base = self.map.entries.as_mut_ptr(); + let src = base.add(self.current); + let dest = base.add(new_len); + src.copy_to(dest, tail_len); + new_len += tail_len; + } + self.map.entries.set_len(new_len); + } + + if new_len != old_len { + // We don't keep track of *which* items were extracted, so reindex everything. + self.map.rebuild_hash_table(); + } + } +} + +impl ExtractCore<'_, K, V> { + pub(crate) fn extract_if(&mut self, mut pred: F) -> Option> + where + F: FnMut(&mut Bucket) -> bool, + { + let old_len = self.map.indices.len(); + debug_assert!(old_len <= self.map.entries.capacity()); + + let base = self.map.entries.as_mut_ptr(); + while self.current < old_len { + // SAFETY: We're maintaining both indices within bounds of the original entries, so + // 0..new_len and current..old_len are always valid items for our Drop to keep. + unsafe { + let item = base.add(self.current); + if pred(&mut *item) { + // Extract it! + self.current += 1; + return Some(item.read()); + } else { + // Keep it, shifting it down if needed. + if self.new_len != self.current { + debug_assert!(self.new_len < self.current); + let dest = base.add(self.new_len); + item.copy_to_nonoverlapping(dest, 1); + } + self.current += 1; + self.new_len += 1; + } + } + } + None + } + + pub(crate) fn remaining(&self) -> usize { + self.map.indices.len() - self.current + } +} diff --git a/src/map/iter.rs b/src/map/iter.rs index 1ec3703c..422fe4e1 100644 --- a/src/map/iter.rs +++ b/src/map/iter.rs @@ -1,5 +1,4 @@ -use super::core::IndexMapCore; -use super::{Bucket, Entries, IndexMap, Slice}; +use super::{Bucket, Entries, ExtractCore, IndexMap, IndexMapCore, Slice}; use alloc::vec::{self, Vec}; use core::fmt; @@ -711,3 +710,56 @@ where .finish() } } + +/// An extracting iterator for `IndexMap`. +/// +/// This `struct` is created by [`IndexMap::extract_if()`]. +/// See its documentation for more. +pub struct ExtractIf<'a, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + inner: ExtractCore<'a, K, V>, + pred: F, +} + +impl ExtractIf<'_, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + pub(super) fn new(core: &mut IndexMapCore, pred: F) -> ExtractIf<'_, K, V, F> { + ExtractIf { + inner: core.extract(), + pred, + } + } +} + +impl Iterator for ExtractIf<'_, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + type Item = (K, V); + + fn next(&mut self) -> Option { + self.inner + .extract_if(|bucket| { + let (key, value) = bucket.ref_mut(); + (self.pred)(key, value) + }) + .map(Bucket::key_value) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.inner.remaining())) + } +} + +impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F> +where + F: FnMut(&K, &mut V) -> bool, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtractIf").finish_non_exhaustive() + } +} diff --git a/src/set.rs b/src/set.rs index e2560843..eb744d2d 100644 --- a/src/set.rs +++ b/src/set.rs @@ -7,7 +7,7 @@ mod slice; mod tests; pub use self::iter::{ - Difference, Drain, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union, + Difference, Drain, ExtractIf, Intersection, IntoIter, Iter, Splice, SymmetricDifference, Union, }; pub use self::slice::Slice; @@ -253,6 +253,41 @@ impl IndexSet { Drain::new(self.map.core.drain(range)) } + /// Creates an iterator which uses a closure to determine if a value should be removed. + /// + /// If the closure returns true, then the value is removed and yielded. + /// If the closure returns false, the value will remain in the list and will not be yielded + /// by the iterator. + /// + /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating + /// or the iteration short-circuits, then the remaining elements will be retained. + /// Use [`retain`] with a negated predicate if you do not need the returned iterator. + /// + /// [`retain`]: IndexSet::retain + /// + /// # Examples + /// + /// Splitting a set into even and odd values, reusing the original set: + /// + /// ``` + /// use indexmap::IndexSet; + /// + /// let mut set: IndexSet = (0..8).collect(); + /// let extracted: IndexSet = set.extract_if(|v| v % 2 == 0).collect(); + /// + /// let evens = extracted.into_iter().collect::>(); + /// let odds = set.into_iter().collect::>(); + /// + /// assert_eq!(evens, vec![0, 2, 4, 6]); + /// assert_eq!(odds, vec![1, 3, 5, 7]); + /// ``` + pub fn extract_if(&mut self, pred: F) -> ExtractIf<'_, T, F> + where + F: FnMut(&T) -> bool, + { + ExtractIf::new(&mut self.map.core, pred) + } + /// Splits the collection into two at the given index. /// /// Returns a newly allocated set containing the elements in the range diff --git a/src/set/iter.rs b/src/set/iter.rs index 3f8033c2..c893469d 100644 --- a/src/set/iter.rs +++ b/src/set/iter.rs @@ -1,3 +1,5 @@ +use crate::map::{ExtractCore, IndexMapCore}; + use super::{Bucket, Entries, IndexSet, Slice}; use alloc::vec::{self, Vec}; @@ -624,3 +626,53 @@ impl fmt::Debug for UnitValue { fmt::Debug::fmt(&self.0, f) } } + +/// An extracting iterator for `IndexSet`. +/// +/// This `struct` is created by [`IndexSet::extract_if()`]. +/// See its documentation for more. +pub struct ExtractIf<'a, T, F> +where + F: FnMut(&T) -> bool, +{ + inner: ExtractCore<'a, T, ()>, + pred: F, +} + +impl ExtractIf<'_, T, F> +where + F: FnMut(&T) -> bool, +{ + pub(super) fn new(core: &mut IndexMapCore, pred: F) -> ExtractIf<'_, T, F> { + ExtractIf { + inner: core.extract(), + pred, + } + } +} + +impl Iterator for ExtractIf<'_, T, F> +where + F: FnMut(&T) -> bool, +{ + type Item = T; + + fn next(&mut self) -> Option { + self.inner + .extract_if(|bucket| (self.pred)(bucket.key_ref())) + .map(Bucket::key) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.inner.remaining())) + } +} + +impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F> +where + F: FnMut(&T) -> bool, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtractIf").finish_non_exhaustive() + } +} diff --git a/tests/quick.rs b/tests/quick.rs index 4142e2d6..78d41afe 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -174,6 +174,47 @@ quickcheck_limit! { } } + fn extract_if_odd(insert: Vec) -> bool { + let mut map = IndexMap::new(); + for &x in &insert { + map.insert(x, x.to_string()); + } + + let (odd, even): (Vec<_>, Vec<_>) = map.keys().copied().partition(|k| k % 2 == 1); + + let extracted: Vec<_> = map + .extract_if(|k, _| k % 2 == 1) + .map(|(k, _)| k) + .collect(); + + even.iter().all(|k| map.contains_key(k)) + && map.keys().eq(&even) + && extracted == odd + } + + fn extract_if_odd_limit(insert: Vec, limit: usize) -> bool { + let mut map = IndexMap::new(); + for &x in &insert { + map.insert(x, x.to_string()); + } + let limit = limit % (map.len() + 1); + + let mut i = 0; + let (odd, other): (Vec<_>, Vec<_>) = map.keys().copied().partition(|k| { + k % 2 == 1 && i < limit && { i += 1; true } + }); + + let extracted: Vec<_> = map + .extract_if(|k, _| k % 2 == 1) + .map(|(k, _)| k) + .take(limit) + .collect(); + + other.iter().all(|k| map.contains_key(k)) + && map.keys().eq(&other) + && extracted == odd + } + fn shift_remove(insert: Vec, remove: Vec) -> bool { let mut map = IndexMap::new(); for &key in &insert {