Skip to content

Commit

Permalink
Always check Guard collector
Browse files Browse the repository at this point in the history
Fixes #46.
  • Loading branch information
jonhoo committed Jan 30, 2020
1 parent 39f439f commit 32ff9b4
Showing 1 changed file with 82 additions and 9 deletions.
91 changes: 82 additions & 9 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,36 @@ pub struct HashMap<K, V, S = crate::DefaultHashBuilder> {
/// next element count value upon which to resize the table.
size_ctl: AtomicIsize,

/// Collector that all `Guard` references used for operations on this map must be tied to. It
/// is important that they all assocate with the _same_ `Collector`, otherwise you end up with
/// unsoundness as described in https://github.com/jonhoo/flurry/issues/46. Specifically, a
/// user can do:
///
/// ```rust,should_panic
/// # use flurry::HashMap;
/// use crossbeam_epoch as epoch;
/// let map: HashMap<_, _> = HashMap::default();
/// map.insert(42, String::from("hello"), &epoch::pin());
///
/// let evil = epoch::Collector::new();
/// let evil = evil.register();
/// let guard = evil.pin();
/// let oops = map.get(&42, &guard);
///
/// map.remove(&42, &epoch::pin());
/// // at this point, the default collector is allowed to free `"hello"`
/// // since no-one has the global epoch pinned as far as it is aware.
/// // `oops` is tied to the lifetime of a Guard that is not a part of
/// // the same epoch group, and so can now be dangling.
/// // but we can still access it!
/// assert_eq!(oops.unwrap(), "hello");
/// ```
///
/// We avoid that by checking that every external guard that is passed in is associated with
/// the `Collector` that was specified when the map was created (which may be the global
/// collector).
collector: epoch::Collector,

build_hasher: S,
}

Expand Down Expand Up @@ -131,6 +161,36 @@ where
count: AtomicUsize::new(0),
size_ctl: AtomicIsize::new(0),
build_hasher: hash_builder,
collector: epoch::default_collector().clone(),
}
}

/// Associate a custom [`epoch::Collector`] with this map.
///
/// By default, the global collector is used. With this method you can use a different
/// collector instead. This may be desireable if you want more control over when and how memory
/// reclamation happens.
///
/// Note that _all_ `Guard` references provided to access the returned map _must_ be
/// constructed using guards produced by `collector`. You can use [`HashMap::pin`] to get a
/// thread-local handle to the collector that lets you construct `Guard`s.
pub fn with_collector(mut self, collector: epoch::Collector) -> Self {
self.collector = collector;
self
}

/// Allocate a thread-local handle to the [`epoch::Collector`] associated with this map.
///
/// You can use the returned handle to produce [`epoch::Guard`] references.
pub fn register(&self) -> epoch::LocalHandle {
self.collector.register()
}

#[inline]
fn check_guard(&self, guard: &Guard) {
// guard.collector() may be `None` if it is unprotected
if let Some(c) = guard.collector() {
assert_eq!(c, &self.collector);
}
}

Expand Down Expand Up @@ -179,6 +239,7 @@ where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.check_guard(guard);
self.get(key, &guard).is_some()
}

Expand Down Expand Up @@ -238,7 +299,7 @@ where
///
/// Returns `None` if this map contains no mapping for the key.
///
/// To obtain a `Guard`, use [`epoch::pin`].
/// To obtain a `Guard`, use [`HashMap::register`].
///
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
/// form must match those for the key type.
Expand All @@ -248,6 +309,7 @@ where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.check_guard(guard);
let node = self.get_node(key, guard)?;

let v = node.value.load(Ordering::SeqCst, guard);
Expand Down Expand Up @@ -286,6 +348,7 @@ where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.check_guard(guard);
let node = self.get_node(key, guard)?;

let v = node.value.load(Ordering::SeqCst, guard);
Expand Down Expand Up @@ -340,6 +403,7 @@ where
///
/// The value can be retrieved by calling [`HashMap::get`] with a key that is equal to the original key.
pub fn insert<'g>(&'g self, key: K, value: V, guard: &'g Guard) -> Option<&'g V> {
self.check_guard(guard);
self.put(key, value, false, guard)
}

Expand Down Expand Up @@ -564,6 +628,7 @@ where
Q: ?Sized + Hash + Eq,
F: FnOnce(&K, &V) -> Option<V>,
{
self.check_guard(guard);
let h = self.hash(&key);

let mut table = self.table.load(Ordering::SeqCst, guard);
Expand Down Expand Up @@ -1273,6 +1338,7 @@ where
/// Tries to reserve capacity for at least additional more elements.
/// The collection may reserve more space to avoid frequent reallocations.
pub fn reserve(&self, additional: usize, guard: &Guard) {
self.check_guard(guard);
let absolute = self.len() + additional;
self.try_presize(absolute, guard);
}
Expand All @@ -1288,6 +1354,7 @@ where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.check_guard(guard);
self.replace_node(key, None, None, guard)
}

Expand Down Expand Up @@ -1471,6 +1538,7 @@ where
where
F: FnMut(&K, &V) -> bool,
{
self.check_guard(guard);
// removed selected keys
for (k, v) in self.iter(&guard) {
if !f(k, v) {
Expand All @@ -1490,6 +1558,7 @@ where
where
F: FnMut(&K, &V) -> bool,
{
self.check_guard(guard);
// removed selected keys
for (k, v) in self.iter(&guard) {
if !f(k, v) {
Expand All @@ -1503,6 +1572,7 @@ where
///
/// To obtain a `Guard`, use [`epoch::pin`].
pub fn iter<'g>(&'g self, guard: &'g Guard) -> Iter<'g, K, V> {
self.check_guard(guard);
let table = self.table.load(Ordering::SeqCst, guard);
let node_iter = NodeIter::new(table, guard);
Iter { node_iter, guard }
Expand All @@ -1513,6 +1583,7 @@ where
///
/// To obtain a `Guard`, use [`epoch::pin`].
pub fn keys<'g>(&'g self, guard: &'g Guard) -> Keys<'g, K, V> {
self.check_guard(guard);
let table = self.table.load(Ordering::SeqCst, guard);
let node_iter = NodeIter::new(table, guard);
Keys { node_iter }
Expand All @@ -1523,6 +1594,7 @@ where
///
/// To obtain a `Guard`, use [`epoch::pin`].
pub fn values<'g>(&'g self, guard: &'g Guard) -> Values<'g, K, V> {
self.check_guard(guard);
let table = self.table.load(Ordering::SeqCst, guard);
let node_iter = NodeIter::new(table, guard);
Values { node_iter, guard }
Expand All @@ -1538,6 +1610,7 @@ where
#[cfg(test)]
/// Returns the capacity of the map.
fn capacity(&self, guard: &Guard) -> usize {
self.check_guard(guard);
let table = self.table.load(Ordering::Relaxed, &guard);

if table.is_null() {
Expand Down Expand Up @@ -1567,9 +1640,10 @@ where
return false;
}

let guard = epoch::pin();
self.iter(&guard)
.all(|(key, value)| other.get(key, &guard).map_or(false, |v| *value == *v))
let our_guard = self.collector.register().pin();
let their_guard = other.collector.register().pin();
self.iter(&our_guard)
.all(|(key, value)| other.get(key, &their_guard).map_or(false, |v| *value == *v))
}
}

Expand All @@ -1588,7 +1662,7 @@ where
S: BuildHasher,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let guard = epoch::pin();
let guard = self.collector.register().pin();
f.debug_map().entries(self.iter(&guard)).finish()
}
}
Expand Down Expand Up @@ -1637,8 +1711,7 @@ where
(iter.size_hint().0 + 1) / 2
};

let guard = epoch::pin();

let guard = self.collector.register().pin();
self.reserve(reserve, &guard);
(*self).put_all(iter.into_iter(), &guard);
}
Expand Down Expand Up @@ -1715,12 +1788,12 @@ where
fn clone(&self) -> HashMap<K, V, S> {
let cloned_map = Self::with_capacity_and_hasher(self.build_hasher.clone(), self.len());
{
let guard = epoch::pin();
let guard = self.collector.register().pin();
for (k, v) in self.iter(&guard) {
cloned_map.insert(k.clone(), v.clone(), &guard);
}
}
cloned_map
cloned_map.with_collector(self.collector.clone())
}
}

Expand Down

0 comments on commit 32ff9b4

Please sign in to comment.