-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Use `arc_swap` to implement `AtomicStr` The previous implementation allowed use-after-free in a multi-threaded context. This PR fixes the problem by using `ArcSwap`, which is implements thread-safe swapping of `Arc`s, and is widely used. * Use `triomphe::Arc` to mitigate performance losses This change replaces `std::sync::Arc` with `triomphe::Arc`. The latter has no weak references, and is a lot faster because of that. * Update atomic_str.rs * Update lib.rs * Update lib.rs --------- Co-authored-by: Jason Lee <[email protected]>
- Loading branch information
Showing
5 changed files
with
83 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,65 @@ | ||
use std::fmt; | ||
use std::sync::atomic::{AtomicPtr, Ordering}; | ||
use std::sync::Arc; | ||
use std::ops::Deref; | ||
|
||
use arc_swap::{ArcSwapAny, Guard}; | ||
use triomphe::Arc; | ||
|
||
/// A thread-safe atomically reference-counting string. | ||
pub struct AtomicStr(AtomicPtr<String>); | ||
pub struct AtomicStr(ArcSwapAny<Arc<String>>); | ||
|
||
/// A thread-safe view the string that was stored when `AtomicStr::as_str()` was called. | ||
struct GuardedStr(Guard<Arc<String>>); | ||
|
||
impl Deref for GuardedStr { | ||
type Target = str; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
self.0.as_str() | ||
} | ||
} | ||
|
||
impl AtomicStr { | ||
/// Create a new `AtomicStr` with the given value. | ||
pub fn new(value: impl Into<String>) -> Self { | ||
pub fn new(value: &str) -> Self { | ||
let arced = Arc::new(value.into()); | ||
Self(AtomicPtr::new(Arc::into_raw(arced) as _)) | ||
Self(ArcSwapAny::new(arced)) | ||
} | ||
|
||
/// Get the string slice. | ||
pub fn as_str(&self) -> &str { | ||
unsafe { | ||
let arced_ptr = self.0.load(Ordering::SeqCst); | ||
assert!(!arced_ptr.is_null()); | ||
&*arced_ptr | ||
} | ||
} | ||
|
||
/// Get the cloned inner `Arc<String>`. | ||
pub fn clone_string(&self) -> Arc<String> { | ||
unsafe { | ||
let arced_ptr = self.0.load(Ordering::SeqCst); | ||
assert!(!arced_ptr.is_null()); | ||
Arc::increment_strong_count(arced_ptr); | ||
Arc::from_raw(arced_ptr) | ||
} | ||
pub fn as_str(&self) -> impl Deref<Target = str> { | ||
GuardedStr(self.0.load()) | ||
} | ||
|
||
/// Replaces the value at self with src, returning the old value, without dropping either. | ||
pub fn replace(&self, src: impl Into<String>) -> Arc<String> { | ||
unsafe { | ||
let arced_new = Arc::new(src.into()); | ||
let arced_old_ptr = self.0.swap(Arc::into_raw(arced_new) as _, Ordering::SeqCst); | ||
assert!(!arced_old_ptr.is_null()); | ||
Arc::from_raw(arced_old_ptr) | ||
} | ||
/// Replaces the value at self with src. | ||
pub fn replace(&self, src: impl Into<String>) { | ||
let arced = Arc::new(src.into()); | ||
self.0.store(arced); | ||
} | ||
} | ||
|
||
impl Drop for AtomicStr { | ||
fn drop(&mut self) { | ||
unsafe { | ||
let arced_ptr = self.0.swap(std::ptr::null_mut(), Ordering::SeqCst); | ||
assert!(!arced_ptr.is_null()); | ||
let _ = Arc::from_raw(arced_ptr); | ||
} | ||
impl From<&str> for AtomicStr { | ||
fn from(value: &str) -> Self { | ||
Self::new(value) | ||
} | ||
} | ||
|
||
impl AsRef<str> for AtomicStr { | ||
fn as_ref(&self) -> &str { | ||
self.as_str() | ||
impl fmt::Display for AtomicStr { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.write_str(&self.as_str()) | ||
} | ||
} | ||
|
||
impl<T> From<T> for AtomicStr | ||
where | ||
T: Into<String>, | ||
{ | ||
fn from(value: T) -> Self { | ||
Self::new(value) | ||
} | ||
} | ||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
impl From<&AtomicStr> for Arc<String> { | ||
fn from(value: &AtomicStr) -> Self { | ||
value.clone_string() | ||
fn test_str(s: &str) { | ||
assert_eq!(s, "hello"); | ||
} | ||
} | ||
|
||
impl fmt::Display for AtomicStr { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
write!(f, "{}", self.as_str()) | ||
#[test] | ||
fn test_atomic_str() { | ||
let s = AtomicStr::from("hello"); | ||
test_str(&s.as_str()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
use std::ops::Add; | ||
use std::thread::spawn; | ||
use std::time::{Duration, Instant}; | ||
|
||
use rust_i18n::{set_locale, t}; | ||
|
||
rust_i18n::i18n!("locales", fallback = "en"); | ||
|
||
#[test] | ||
fn test_load_and_store() { | ||
let end = Instant::now().add(Duration::from_secs(3)); | ||
let store = spawn(move || { | ||
let mut i = 0u32; | ||
while Instant::now() < end { | ||
for _ in 0..100 { | ||
i = i.wrapping_add(1); | ||
if i % 2 == 0 { | ||
set_locale(&format!("en-{i}")); | ||
} else { | ||
set_locale(&format!("fr-{i}")); | ||
} | ||
} | ||
} | ||
}); | ||
let load = spawn(move || { | ||
while Instant::now() < end { | ||
for _ in 0..100 { | ||
t!("hello"); | ||
} | ||
} | ||
}); | ||
store.join().unwrap(); | ||
load.join().unwrap(); | ||
} |