diff --git a/src/map.rs b/src/map.rs index 30c0a4eb57..314a9a1a1a 100644 --- a/src/map.rs +++ b/src/map.rs @@ -4132,6 +4132,11 @@ impl<'a, K, V, S, A: Allocator + Clone> RawVacantEntryMut<'a, K, V, S, A> { hash_builder: self.hash_builder, } } + + #[inline] + pub(crate) fn hasher(&self) -> &S { + self.hash_builder + } } impl Debug for RawEntryBuilderMut<'_, K, V, S, A> { diff --git a/src/set.rs b/src/set.rs index ce794ce50a..51b73ef348 100644 --- a/src/set.rs +++ b/src/set.rs @@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator}; use core::mem; use core::ops::{BitAnd, BitOr, BitXor, Sub}; -use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys}; +use super::map::{ + self, make_hash, make_insert_hash, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, + HashMap, Keys, RawEntryMut, +}; use crate::raw::{Allocator, Global}; // Future Optimization (FIXME!) @@ -953,6 +956,12 @@ where /// Inserts a value computed from `f` into the set if the given `value` is /// not present, then returns a reference to the value in the set. /// + /// # Panics + /// + /// Panics if the value from the function and the provided lookup value + /// are not equivalent or have different hashes. See [`Equivalent`] + /// and [`Hash`] for more information. + /// /// # Examples /// /// ``` @@ -967,6 +976,7 @@ where /// assert_eq!(value, pet); /// } /// assert_eq!(set.len(), 4); // a new "fish" was inserted + /// assert!(set.contains("fish")); /// ``` #[cfg_attr(feature = "inline-more", inline)] pub fn get_or_insert_with(&mut self, value: &Q, f: F) -> &T @@ -974,13 +984,32 @@ where Q: Hash + Equivalent, F: FnOnce(&Q) -> T, { + #[cold] + #[inline(never)] + fn assert_failed() { + panic!( + "the value from the function and the lookup value \ + must be equivalent and have the same hash" + ); + } + // Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with // `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`. - self.map - .raw_entry_mut() - .from_key(value) - .or_insert_with(|| (f(value), ())) - .0 + let hash = make_hash::(&self.map.hash_builder, value); + let raw_entry_builder = self.map.raw_entry_mut(); + match raw_entry_builder.from_key_hashed_nocheck(hash, value) { + RawEntryMut::Occupied(entry) => entry.into_key(), + RawEntryMut::Vacant(entry) => { + let insert_value = f(value); + let insert_value_hash = make_insert_hash::(entry.hasher(), &insert_value); + if !(hash == insert_value_hash && value.equivalent(&insert_value)) { + assert_failed(); + } + entry + .insert_hashed_nocheck(insert_value_hash, insert_value, ()) + .0 + } + } } /// Gets the given value's corresponding entry in the set for in-place manipulation. @@ -2429,7 +2458,7 @@ fn assert_covariance() { #[cfg(test)] mod test_set { use super::super::map::DefaultHashBuilder; - use super::HashSet; + use super::{make_hash, Equivalent, HashSet}; use std::vec::Vec; #[test] @@ -2886,4 +2915,88 @@ mod test_set { set.insert(i); } } + + #[test] + fn duplicate_insert() { + let mut set = HashSet::new(); + set.insert(1); + set.get_or_insert_with(&1, |_| 1); + set.get_or_insert_with(&1, |_| 1); + assert!([1].iter().eq(set.iter())); + } + + #[test] + #[allow(clippy::derived_hash_with_manual_eq)] + #[should_panic] + fn some_invalid_hash() { + use core::hash::{Hash, Hasher}; + #[derive(Eq, PartialEq)] + struct Invalid { + count: u32, + } + + struct InvalidRef { + count: u32, + } + impl Equivalent for InvalidRef { + fn equivalent(&self, key: &Invalid) -> bool { + self.count == key.count + } + } + impl Hash for Invalid { + fn hash(&self, state: &mut H) { + self.count.hash(state); + } + } + impl Hash for InvalidRef { + fn hash(&self, state: &mut H) { + let double = self.count * 2; + double.hash(state); + } + } + let mut set: HashSet = HashSet::new(); + let key = InvalidRef { count: 1 }; + let value = Invalid { count: 1 }; + if key.equivalent(&value) { + set.get_or_insert_with(&key, |_| value); + } + } + + #[test] + #[allow(clippy::derived_hash_with_manual_eq)] + #[should_panic] + fn some_invalid_equivalent() { + use core::hash::{Hash, Hasher}; + #[derive(Eq, PartialEq)] + struct Invalid { + count: u32, + other: u32, + } + + struct InvalidRef { + count: u32, + other: u32, + } + impl Equivalent for InvalidRef { + fn equivalent(&self, key: &Invalid) -> bool { + self.count == key.count && self.other == key.other + } + } + impl Hash for Invalid { + fn hash(&self, state: &mut H) { + self.count.hash(state); + } + } + impl Hash for InvalidRef { + fn hash(&self, state: &mut H) { + self.count.hash(state); + } + } + let mut set: HashSet = HashSet::new(); + let key = InvalidRef { count: 1, other: 1 }; + let value = Invalid { count: 1, other: 2 }; + if make_hash(set.hasher(), &key) == make_hash(set.hasher(), &value) { + set.get_or_insert_with(&key, |_| value); + } + } }