diff --git a/commons/zenoh-shm/src/api/provider/types.rs b/commons/zenoh-shm/src/api/provider/types.rs index bb04dfa5fc..71d3753e26 100644 --- a/commons/zenoh-shm/src/api/provider/types.rs +++ b/commons/zenoh-shm/src/api/provider/types.rs @@ -52,7 +52,7 @@ impl Display for AllocAlignment { impl Default for AllocAlignment { fn default() -> Self { Self { - pow: (std::mem::align_of::() as f64).log2().round() as u8, + pow: std::mem::align_of::().ilog2() as _, } } } @@ -65,9 +65,10 @@ impl AllocAlignment { /// This function will return an error if provided alignment power cannot fit into usize. #[zenoh_macros::unstable_doc] pub const fn new(pow: u8) -> Result { - match pow { - pow if pow < usize::BITS as u8 => Ok(Self { pow }), - _ => Err(ZLayoutError::IncorrectLayoutArgs), + if pow < usize::BITS as u8 { + Ok(Self { pow }) + } else { + Err(ZLayoutError::IncorrectLayoutArgs) } } @@ -92,19 +93,45 @@ impl AllocAlignment { /// ``` #[zenoh_macros::unstable_doc] pub fn align_size(&self, size: NonZeroUsize) -> NonZeroUsize { - let alignment = self.get_alignment_value(); - match size.get() % alignment { - 0 => size, - // SAFETY: - // This unsafe block is always safe: - // 1. 0 < remainder < alignment - // 2. because of 1, the value of (alignment.get() - remainder) is always > 0 - // 3. because of 2, we add nonzero size to nonzero (alignment.get() - remainder) and it is always positive if no overflow - // 4. we make sure that there is no overflow condition in 3 by means of alignment limitation in `new` by limiting pow value - remainder => unsafe { - NonZeroUsize::new_unchecked(size.get() + (alignment.get() - remainder)) - }, - } + // Notations: + // - size to align S + // - usize::BITS B + // - pow P where 0 ≤ P < B + // - alignment value A = 2^P + // - return R = min{x | x ≥ S, x % A = 0} + // + // Example 1: A = 4 = (00100)₂, S = 4 = (00100)₂ ⇒ R = 4 = (00100)₂ + // Example 2: A = 4 = (00100)₂, S = 7 = (00111)₂ ⇒ R = 8 = (01000)₂ + // Example 3: A = 4 = (00100)₂, S = 8 = (01000)₂ ⇒ R = 8 = (01000)₂ + // Example 4: A = 4 = (00100)₂, S = 9 = (01001)₂ ⇒ R = 12 = (01100)₂ + // + // Algorithm: For any x = (bₙ, ⋯, b₂, b₁)₂ in binary representation, + // 1. x % A = 0 ⇔ ∀i < P, bᵢ = 0 + // 2. f(x) ≜ x & !(A-1) leads to ∀i < P, bᵢ = 0, hence f(x) % A = 0 + // (i.e. f zeros all bits before the P-th bit) + // 3. R = min{x | x ≥ S, x % A = 0} is equivalent to find the unique R where S ≤ R < S+A and R % A = 0 + // 4. x-A < f(x) ≤ x ⇒ S-1 < f(S+A-1) ≤ S+A-1 ⇒ S ≤ f(S+A-1) < S+A + // + // Hence R = f(S+A-1) = (S+(A-1)) & !(A-1) is the desired value + + // Compute A - 1 = 2^P - 1 + let a_minus_1 = self.get_alignment_value().get() - 1; + + // Overflow check: ensure S ≤ 2^B - 2^P = (2^B - 1) - (A - 1) + // so that R < S+A ≤ 2^B and hence it's a valid usize + let bound = usize::MAX - a_minus_1; + assert!( + size.get() <= bound, + "The given size {} exceeded the maximum {}", + size.get(), + bound + ); + + // Overflow never occurs due to the check above + let r = (size.get() + a_minus_1) & !a_minus_1; + + // SAFETY: R ≥ 0 since R ≥ S ≥ 0 + unsafe { NonZeroUsize::new_unchecked(r) } } }