Skip to content

Commit

Permalink
Use arc_swap to implement AtomicStr
Browse files Browse the repository at this point in the history
The previous implementation allowed use-after-free in a multi-threaded
context. This PR fixes the problem by using `ArcSwap`, which is
implements thread-safe swapping of `Arc`s, and is widely used.
  • Loading branch information
Kijewski committed Jan 19, 2024
1 parent fe72199 commit b740582
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 39 deletions.
1 change: 1 addition & 0 deletions crates/support/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repository = "https://github.com/longbridgeapp/rust-i18n"
version = "3.0.0"

[dependencies]
arc-swap = "1.6.0"
globwalk = "0.8.1"
once_cell = "1.10.0"
proc-macro2 = "1.0"
Expand Down
65 changes: 26 additions & 39 deletions crates/support/src/atomic_str.rs
Original file line number Diff line number Diff line change
@@ -1,60 +1,47 @@
use std::fmt;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;

use arc_swap::{ArcSwap, Guard};

/// A thread-safe atomically reference-counting string.
pub struct AtomicStr(AtomicPtr<String>);
pub struct AtomicStr(ArcSwap<String>);

/// A thread-safe view the string that was stored when `AtomicStr::as_str()` was called.
struct GuardedStr(Guard<Arc<String>>);

impl AsRef<str> for GuardedStr {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}

impl fmt::Display for GuardedStr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.0.as_ref())
}
}

impl AtomicStr {
/// Create a new `AtomicStr` with the given value.
pub fn new(value: impl Into<String>) -> Self {
let arced = Arc::new(value.into());
Self(AtomicPtr::new(Arc::into_raw(arced) as _))
Self(ArcSwap::new(arced))
}

/// Get the string slice.
pub fn as_str(&self) -> &str {
unsafe {
let arced_ptr = self.0.load(Ordering::SeqCst);
assert!(!arced_ptr.is_null());
&*arced_ptr
}
pub fn as_str(&self) -> impl AsRef<str> + fmt::Display {
GuardedStr(self.0.load())
}

/// Get the cloned inner `Arc<String>`.
pub fn clone_string(&self) -> Arc<String> {
unsafe {
let arced_ptr = self.0.load(Ordering::SeqCst);
assert!(!arced_ptr.is_null());
Arc::increment_strong_count(arced_ptr);
Arc::from_raw(arced_ptr)
}
}

/// Replaces the value at self with src, returning the old value, without dropping either.
pub fn replace(&self, src: impl Into<String>) -> Arc<String> {
unsafe {
let arced_new = Arc::new(src.into());
let arced_old_ptr = self.0.swap(Arc::into_raw(arced_new) as _, Ordering::SeqCst);
assert!(!arced_old_ptr.is_null());
Arc::from_raw(arced_old_ptr)
}
}
}

impl Drop for AtomicStr {
fn drop(&mut self) {
unsafe {
let arced_ptr = self.0.swap(std::ptr::null_mut(), Ordering::SeqCst);
assert!(!arced_ptr.is_null());
let _ = Arc::from_raw(arced_ptr);
}
Guard::into_inner(self.0.load())
}
}

impl AsRef<str> for AtomicStr {
fn as_ref(&self) -> &str {
self.as_str()
/// Replaces the value at self with src.
pub fn replace(&self, src: impl Into<String>) {
let arced = Arc::new(src.into());
self.0.store(arced);
}
}

Expand Down
34 changes: 34 additions & 0 deletions tests/multi_threading.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::ops::Add;
use std::thread::spawn;
use std::time::{Duration, Instant};

use rust_i18n::{set_locale, t};

rust_i18n::i18n!("locales", fallback = "en");

#[test]
fn test_load_and_store() {
let end = Instant::now().add(Duration::from_secs(3));
let store = spawn(move || {
let mut i = 0u32;
while Instant::now() < end {
for _ in 0..100 {
i = i.wrapping_add(1);
if i % 2 == 0 {
set_locale(&format!("en-{i}"));
} else {
set_locale(&format!("fr-{i}"));
}
}
}
});
let load = spawn(move || {
while Instant::now() < end {
for _ in 0..100 {
t!("hello");
}
}
});
store.join().unwrap();
load.join().unwrap();
}

0 comments on commit b740582

Please sign in to comment.