Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ArrayString: Reducing reliance on unsafe blocks + adding some safety comments #288

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 91 additions & 38 deletions src/array_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::mem::MaybeUninit;
use std::ops::{Deref, DerefMut};
#[cfg(feature="std")]
use std::path::Path;
use std::ptr;
use std::slice;
use std::str;
use std::str::FromStr;
Expand Down Expand Up @@ -35,7 +34,8 @@ use serde::{Serialize, Deserialize, Serializer, Deserializer};
#[derive(Copy)]
#[repr(C)]
pub struct ArrayString<const CAP: usize> {
// the `len` first elements of the array are initialized
// the `len` first elements of the array are initialized and contain valid
// UTF-8
len: LenUint,
xs: [MaybeUninit<u8>; CAP],
}
Expand Down Expand Up @@ -64,8 +64,9 @@ impl<const CAP: usize> ArrayString<CAP>
/// ```
pub fn new() -> ArrayString<CAP> {
assert_capacity_limit!(CAP);
unsafe {
ArrayString { xs: MaybeUninit::uninit().assume_init(), len: 0 }
ArrayString {
xs: [MaybeUninit::uninit(); CAP],
len: 0,
}
}

Expand Down Expand Up @@ -124,11 +125,21 @@ impl<const CAP: usize> ArrayString<CAP>
let len = str::from_utf8(b)?.len();
debug_assert_eq!(len, CAP);
let mut vec = Self::new();

// This seems to result in the same, fast assembly code as some
// `unsafe` transmutes and a call to `copy_to_nonoverlapping`.
// See https://godbolt.org/z/vhM1WePTK for more details.
for (dst, src) in vec.xs.iter_mut().zip(b.iter()) {
*dst = MaybeUninit::new(*src);
}

// SAFETY: Copying `CAP` bytes in the `for` loop above initializes
// all the bytes in `vec`. `str::from_utf8` call above promises
// that the bytes are valid UTF-8.
unsafe {
(b as *const [u8; CAP] as *const [MaybeUninit<u8>; CAP])
.copy_to_nonoverlapping(&mut vec.xs as *mut [MaybeUninit<u8>; CAP], 1);
vec.set_len(CAP);
}

Ok(vec)
}

Expand All @@ -144,13 +155,9 @@ impl<const CAP: usize> ArrayString<CAP>
#[inline]
pub fn zero_filled() -> Self {
assert_capacity_limit!(CAP);
// SAFETY: `assert_capacity_limit` asserts that `len` won't overflow and
// `zeroed` fully fills the array with nulls.
unsafe {
ArrayString {
xs: MaybeUninit::zeroed().assume_init(),
len: CAP as _
}
ArrayString {
xs: [MaybeUninit::zeroed(); CAP],
len: CAP as _,
}
}

Expand Down Expand Up @@ -229,16 +236,21 @@ impl<const CAP: usize> ArrayString<CAP>
/// ```
pub fn try_push(&mut self, c: char) -> Result<(), CapacityError<char>> {
let len = self.len();
unsafe {
let ptr = self.as_mut_ptr().add(len);
let remaining_cap = self.capacity() - len;
match encode_utf8(c, ptr, remaining_cap) {
Ok(n) => {
let ptr: *mut MaybeUninit<u8> = self.xs[len..].as_mut_ptr();
let ptr = ptr as *mut u8;
let remaining_cap = self.capacity() - len;

// SAFETY: `ptr` points to `remaining_cap` bytes.
match unsafe { encode_utf8(c, ptr, remaining_cap) } {
Ok(n) => {
// SAFETY: `encode_utf8` promises that it initialized `n` bytes
// and that it wrote valid UTF-8.
unsafe {
self.set_len(len + n);
Ok(())
}
Err(_) => Err(CapacityError::new(c)),
Ok(())
}
Err(_) => Err(CapacityError::new(c)),
}
}

Expand Down Expand Up @@ -285,13 +297,25 @@ impl<const CAP: usize> ArrayString<CAP>
if s.len() > self.capacity() - self.len() {
return Err(CapacityError::new(s));
}
let old_len = self.len();
let new_len = old_len + s.len();

// This loop is similar to the one in `from_byte_string` and therefore
// it is expected to result in the same, fast assembly code as some
// `unsafe` transmutes and a call to `copy_to_nonoverlapping`.
let dst = &mut self.xs[old_len..new_len];
let src = s.as_bytes();
for (dst, src) in dst.iter_mut().zip(src.iter()) {
*dst = MaybeUninit::new(*src);
}

// SAFETY: Copying `s.len()` bytes in the `for` loop above initializes
// all the bytes in `self.xs[old_len..new_len]`. We copy the bytes
// from `s: &'a str` so the bytes must be valid UTF-8.
unsafe {
let dst = self.as_mut_ptr().add(self.len());
let src = s.as_ptr();
ptr::copy_nonoverlapping(src, dst, s.len());
let newl = self.len() + s.len();
self.set_len(newl);
self.set_len(new_len);
}

Ok(())
}

Expand All @@ -316,9 +340,17 @@ impl<const CAP: usize> ArrayString<CAP>
None => return None,
};
let new_len = self.len() - ch.len_utf8();

// SAFETY: Type invariant guarantees that `self.len()` bytes are
// initialized and valid UTF-8. Therefore `new_len` bytes (fewer bytes)
// are also initialized. And they are still valid UTF-8 because we cut
// on a char boundary.
unsafe {
debug_assert!(new_len <= self.len());
debug_assert!(self.is_char_boundary(new_len));
self.set_len(new_len);
}

Some(ch)
}

Expand All @@ -341,11 +373,17 @@ impl<const CAP: usize> ArrayString<CAP>
pub fn truncate(&mut self, new_len: usize) {
if new_len <= self.len() {
assert!(self.is_char_boundary(new_len));

// SAFETY: Type invariant guarantees that `self.len()` bytes are
// initialized and form valid UTF-8. `new_len` bytes are also
// initialized, because we checked above that `new_len <=
// self.len()`. And `new_len` bytes are valid UTF-8, because we
// `assert!` above that `new_len` is at a char boundary.
//
// In libstd truncate is called on the underlying vector, which in
// turn drops each element. Here we work with `u8` bytes, so we
// don't have to worry about Drop, and we can just set the length.
unsafe {
// In libstd truncate is called on the underlying vector,
// which in turns drops each element.
// As we know we don't have to worry about Drop,
// we can just set the length (a la clear.)
self.set_len(new_len);
}
}
Expand Down Expand Up @@ -375,36 +413,51 @@ impl<const CAP: usize> ArrayString<CAP>
};

let next = idx + ch.len_utf8();
self.xs.copy_within(next.., idx);

// SAFETY: Type invariant guarantees that `self.len()` bytes are
// initialized and form valid UTF-8. Therefore `new_len` bytes (fewer
// bytes) are also initialized. We remove a whole UTF-8 char, so
// `new_len` bytes remain valid UTF-8.
let len = self.len();
let ptr = self.as_mut_ptr();
let new_len = len - (next - idx);
unsafe {
ptr::copy(
ptr.add(next),
ptr.add(idx),
len - next);
self.set_len(len - (next - idx));
debug_assert!(new_len <= self.len());
self.set_len(new_len);
}
ch
}

/// Reset the string's length to zero, making it empty.
pub fn clear(&mut self) {
    // SAFETY: A zero-length prefix contains no bytes that could be
    // uninitialized, and the empty string is valid UTF-8.
    unsafe { self.set_len(0) }
}

/// Set the string's length.
///
/// This function is `unsafe` because it changes the notion of the
/// number of “valid” bytes in the string. Use with care.
///
/// This method uses *debug assertions* to check the validity of `length`
/// and may use other debug assertions.
///
/// # Safety
///
/// The caller needs to guarantee that the first `length` bytes of the
/// underlying storage:
///
/// * have been initialized
/// * encode valid UTF-8
pub unsafe fn set_len(&mut self, length: usize) {
// type invariant that capacity always fits in LenUint
debug_assert!(length <= self.capacity());

self.len = length as LenUint;

// type invariant that we contain a valid UTF-8 string
// (this is just an O(1) heuristic - full check would require O(N)).
// NOTE(review): `self.len` has just been set to `length`. If
// `is_char_boundary` resolves (presumably via `Deref`) to
// `str::is_char_boundary` on a `str` of `self.len` bytes, the index
// equals the string length and the check is always true -- confirm
// whether this assertion can ever fire.
debug_assert!(self.is_char_boundary(length));
}

/// Return a string slice of the whole `ArrayString`.
Expand Down