Skip to content

Commit

Permalink
Use arc_swap to implement AtomicStr (#72)
Browse files Browse the repository at this point in the history
* Use `arc_swap` to implement `AtomicStr`

The previous implementation allowed use-after-free in a multi-threaded
context. This PR fixes the problem by using `ArcSwap`, which is
implements thread-safe swapping of `Arc`s, and is widely used.

* Use `triomphe::Arc` to mitigate performance losses

This change replaces `std::sync::Arc` with `triomphe::Arc`. The latter
has no weak references, and is a lot faster because of that.

* Update atomic_str.rs

* Update lib.rs

* Update lib.rs

---------

Co-authored-by: Jason Lee <[email protected]>
  • Loading branch information
Kijewski and huacnlee authored Jan 22, 2024
1 parent 37aa93a commit 22e0609
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 61 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ You can use `rust_i18n::set_locale` to set the global locale at runtime, so that
rust_i18n::set_locale("zh-CN");
let locale = rust_i18n::locale();
assert_eq!(*locale, "zh-CN");
assert_eq!(&*locale, "zh-CN");
```

### Extend Backend
Expand Down
2 changes: 2 additions & 0 deletions crates/support/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repository = "https://github.com/longbridgeapp/rust-i18n"
version = "3.0.0"

[dependencies]
arc-swap = "1.6.0"
globwalk = "0.8.1"
once_cell = "1.10.0"
proc-macro2 = "1.0"
Expand All @@ -18,3 +19,4 @@ toml = "0.7.4"
normpath = "1.1.1"
lazy_static = "1"
regex = "1"
triomphe = { version = "0.1.11", features = ["arc-swap"] }
93 changes: 39 additions & 54 deletions crates/support/src/atomic_str.rs
Original file line number Diff line number Diff line change
@@ -1,80 +1,65 @@
use std::fmt;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::ops::Deref;

use arc_swap::{ArcSwapAny, Guard};
use triomphe::Arc;

/// A thread-safe atomically reference-counting string.
pub struct AtomicStr(AtomicPtr<String>);
pub struct AtomicStr(ArcSwapAny<Arc<String>>);

/// A thread-safe view the string that was stored when `AtomicStr::as_str()` was called.
struct GuardedStr(Guard<Arc<String>>);

impl Deref for GuardedStr {
type Target = str;

fn deref(&self) -> &Self::Target {
self.0.as_str()
}
}

impl AtomicStr {
/// Create a new `AtomicStr` with the given value.
pub fn new(value: impl Into<String>) -> Self {
pub fn new(value: &str) -> Self {
let arced = Arc::new(value.into());
Self(AtomicPtr::new(Arc::into_raw(arced) as _))
Self(ArcSwapAny::new(arced))
}

/// Get the string slice.
pub fn as_str(&self) -> &str {
unsafe {
let arced_ptr = self.0.load(Ordering::SeqCst);
assert!(!arced_ptr.is_null());
&*arced_ptr
}
}

/// Get the cloned inner `Arc<String>`.
pub fn clone_string(&self) -> Arc<String> {
unsafe {
let arced_ptr = self.0.load(Ordering::SeqCst);
assert!(!arced_ptr.is_null());
Arc::increment_strong_count(arced_ptr);
Arc::from_raw(arced_ptr)
}
pub fn as_str(&self) -> impl Deref<Target = str> {
GuardedStr(self.0.load())
}

/// Replaces the value at self with src, returning the old value, without dropping either.
pub fn replace(&self, src: impl Into<String>) -> Arc<String> {
unsafe {
let arced_new = Arc::new(src.into());
let arced_old_ptr = self.0.swap(Arc::into_raw(arced_new) as _, Ordering::SeqCst);
assert!(!arced_old_ptr.is_null());
Arc::from_raw(arced_old_ptr)
}
/// Replaces the value at self with src.
pub fn replace(&self, src: impl Into<String>) {
let arced = Arc::new(src.into());
self.0.store(arced);
}
}

impl Drop for AtomicStr {
fn drop(&mut self) {
unsafe {
let arced_ptr = self.0.swap(std::ptr::null_mut(), Ordering::SeqCst);
assert!(!arced_ptr.is_null());
let _ = Arc::from_raw(arced_ptr);
}
impl From<&str> for AtomicStr {
fn from(value: &str) -> Self {
Self::new(value)
}
}

impl AsRef<str> for AtomicStr {
fn as_ref(&self) -> &str {
self.as_str()
impl fmt::Display for AtomicStr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.as_str())
}
}

impl<T> From<T> for AtomicStr
where
T: Into<String>,
{
fn from(value: T) -> Self {
Self::new(value)
}
}
#[cfg(test)]
mod tests {
use super::*;

impl From<&AtomicStr> for Arc<String> {
fn from(value: &AtomicStr) -> Self {
value.clone_string()
fn test_str(s: &str) {
assert_eq!(s, "hello");
}
}

impl fmt::Display for AtomicStr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
#[test]
fn test_atomic_str() {
let s = AtomicStr::from("hello");
test_str(&s.as_str());
}
}
13 changes: 7 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![doc = include_str!("../README.md")]

use std::ops::Deref;

use once_cell::sync::Lazy;
use std::sync::Arc;

#[doc(hidden)]
pub use once_cell;
Expand All @@ -16,8 +17,8 @@ pub fn set_locale(locale: &str) {
}

/// Get current locale
pub fn locale() -> Arc<String> {
CURRENT_LOCALE.clone_string()
pub fn locale() -> impl Deref<Target = str> {
CURRENT_LOCALE.as_str()
}

/// Replace patterns and return a new string.
Expand Down Expand Up @@ -110,7 +111,7 @@ pub fn replace_patterns(input: &str, patterns: &[&str], values: &[String]) -> St
macro_rules! t {
// t!("foo")
($key:expr) => {
crate::_rust_i18n_translate(rust_i18n::locale().as_str(), $key)
crate::_rust_i18n_translate(&rust_i18n::locale(), $key)
};

// t!("foo", locale = "en")
Expand Down Expand Up @@ -184,7 +185,7 @@ mod tests {

#[test]
fn test_locale() {
assert_locale_type(locale().as_str(), CURRENT_LOCALE.as_str());
assert_locale_type(&locale(), CURRENT_LOCALE.as_str());
assert_locale_type(&locale(), &CURRENT_LOCALE.as_str());
assert_eq!(locale().deref(), "en");
}
}
34 changes: 34 additions & 0 deletions tests/multi_threading.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::ops::Add;
use std::thread::spawn;
use std::time::{Duration, Instant};

use rust_i18n::{set_locale, t};

rust_i18n::i18n!("locales", fallback = "en");

#[test]
fn test_load_and_store() {
let end = Instant::now().add(Duration::from_secs(3));
let store = spawn(move || {
let mut i = 0u32;
while Instant::now() < end {
for _ in 0..100 {
i = i.wrapping_add(1);
if i % 2 == 0 {
set_locale(&format!("en-{i}"));
} else {
set_locale(&format!("fr-{i}"));
}
}
}
});
let load = spawn(move || {
while Instant::now() < end {
for _ in 0..100 {
t!("hello");
}
}
});
store.join().unwrap();
load.join().unwrap();
}

0 comments on commit 22e0609

Please sign in to comment.