From 61c9244abdb904902668c4e0961c1a36fb0faeee Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Mon, 23 Oct 2023 12:35:33 -0700 Subject: [PATCH] Implement wrapping_xxx; fix Shl/Shr semantics in debug builds - Implement `wrapping_add`, `wrapping_sub`, `wrapping_mul`, `wrapping_div`, `wrapping_shl`, `wrapping_shr` - In debug builds, `<<` (`Shl`, `ShlAssign`) and `>>` (`Shr`, `ShrAssign`) now bounds-check the shift amount using the same semantics as built-in shifts. For example, shifting a u5 by 5 or more bits will now panic as expected. --- CHANGELOG.md | 5 +++ src/lib.rs | 84 ++++++++++++++++++++++++++++++++++++ tests/tests.rs | 114 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 203 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e526ad..05016b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ - Support `Step` so that arbitrary-int can be used in a range expression, e.g. `for n in u3::MIN..=u3::MAX { println!("{n}") }`. Note this trait is currently unstable, and so is only usable in nightly. Enable this feature with `step_trait`. - Support formatting via [defmt](https://crates.io/crates/defmt). Enable the option `defmt` feature - Support serializing and deserializing via [serde](https://crates.io/crates/serde). Enable the option `serde` feature +- Implement `Mul`, `MulAssign`, `Div`, `DivAssign` +- Implement `wrapping_add`, `wrapping_sub`, `wrapping_mul`, `wrapping_div`, `wrapping_shl`, `wrapping_shr` + +### Changed +- In debug builds, `<<` (`Shl`, `ShlAssign`) and `>>` (`Shr`, `ShrAssign`) now bounds-check the shift amount using the same semantics as built-in shifts. For example, shifting a u5 by 5 or more bits will now panic as expected. ## arbitrary-int 1.2.6 diff --git a/src/lib.rs b/src/lib.rs index 63ebb27..541bc28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -340,6 +340,61 @@ macro_rules! uint_impl { UInt::<$type, BITS_RESULT> { value: self.value } } + pub const fn wrapping_add(&self, rhs: Self) -> Self { + let sum = self.value.wrapping_add(rhs.value); + Self { + value: sum & Self::MASK, + } + } + + pub const fn wrapping_sub(&self, rhs: Self) -> Self { + let sum = self.value.wrapping_sub(rhs.value); + Self { + value: sum & Self::MASK, + } + } + + pub const fn wrapping_mul(&self, rhs: Self) -> Self { + let sum = self.value.wrapping_mul(rhs.value); + Self { + value: sum & Self::MASK, + } + } + + pub const fn wrapping_div(&self, rhs: Self) -> Self { + let sum = self.value.wrapping_div(rhs.value); + Self { + // No need to mask here - divisions always produce a result that is <= self + value: sum, + } + } + + pub const fn wrapping_shl(&self, rhs: u32) -> Self { + // modulo is expensive on some platforms, so only do it when necessary + let shift_amount = if rhs >= (BITS as u32) { + rhs % (BITS as u32) + } else { + rhs + }; + + Self { + value: (self.value << shift_amount) & Self::MASK, + } + } + + pub const fn wrapping_shr(&self, rhs: u32) -> Self { + // modulo is expensive on some platforms, so only do it when necessary + let shift_amount = if rhs >= (BITS as u32) { + rhs % (BITS as u32) + } else { + rhs + }; + + Self { + value: (self.value >> shift_amount) & Self::MASK, + } + } + /// Reverses the order of bits in the integer. The least significant bit becomes the most significant bit, second least-significant bit becomes second most-significant bit, etc. pub const fn reverse_bits(self) -> Self { let shift_right = (core::mem::size_of::<$type>() << 3) - BITS; @@ -699,10 +754,18 @@ where + Shl + Shr + From, + TSHIFTBITS: TryInto + Copy, { type Output = UInt; fn shl(self, rhs: TSHIFTBITS) -> Self::Output { + // With debug assertions, the << and >> operators throw an exception if the shift amount + // is larger than the number of bits (in which case the result would always be 0) + #[cfg(debug_assertions)] + if rhs.try_into().unwrap_or(usize::MAX) >= BITS { + panic!("attempt to shift left with overflow") + } + Self { value: (self.value << rhs) & Self::MASK, } @@ -720,8 +783,15 @@ where + Shr + Shl + From, + TSHIFTBITS: TryInto + Copy, { fn shl_assign(&mut self, rhs: TSHIFTBITS) { + // With debug assertions, the << and >> operators throw an exception if the shift amount + // is larger than the number of bits (in which case the result would always be 0) + #[cfg(debug_assertions)] + if rhs.try_into().unwrap_or(usize::MAX) >= BITS { + panic!("attempt to shift left with overflow") + } self.value <<= rhs; self.value &= Self::MASK; } @@ -730,10 +800,17 @@ where impl Shr for UInt where T: Copy + Shr + Sub + Shl + From, + TSHIFTBITS: TryInto + Copy, { type Output = UInt; fn shr(self, rhs: TSHIFTBITS) -> Self::Output { + // With debug assertions, the << and >> operators throw an exception if the shift amount + // is larger than the number of bits (in which case the result would always be 0) + #[cfg(debug_assertions)] + if rhs.try_into().unwrap_or(usize::MAX) >= BITS { + panic!("attempt to shift left with overflow") + } Self { value: self.value >> rhs, } @@ -743,8 +820,15 @@ where impl ShrAssign for UInt where T: Copy + ShrAssign + Sub + Shl + From, + TSHIFTBITS: TryInto + Copy, { fn shr_assign(&mut self, rhs: TSHIFTBITS) { + // With debug assertions, the << and >> operators throw an exception if the shift amount + // is larger than the number of bits (in which case the result would always be 0) + #[cfg(debug_assertions)] + if rhs.try_into().unwrap_or(usize::MAX) >= BITS { + panic!("attempt to shift left with overflow") + } self.value >>= rhs; } } diff --git a/tests/tests.rs b/tests/tests.rs index fa1df3d..57ed664 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -338,6 +338,48 @@ fn shl() { assert_eq!(u9::new(0b11110000) << 3u64, u9::new(0b1_10000000)); } +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much8() { + let _ = u53::new(123) << 53u8; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much16() { + let _ = u53::new(123) << 53u16; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much32() { + let _ = u53::new(123) << 53u32; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much64() { + let _ = u53::new(123) << 53u64; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much128() { + let _ = u53::new(123) << 53u128; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shl_too_much_usize() { + let _ = u53::new(123) << 53usize; +} + #[test] fn shlassign() { let mut value = u9::new(0b11110000); @@ -345,6 +387,22 @@ fn shlassign() { assert_eq!(value, u9::new(0b1_10000000)); } +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shlassign_too_much() { + let mut value = u9::new(0b11110000); + value <<= 9; +} + +#[cfg(debug_assertions)] +#[test] +#[should_panic] +fn shlassign_too_much2() { + let mut value = u9::new(0b11110000); + value <<= 10; +} + #[test] fn shr() { assert_eq!(u17::new(0b100110) >> 5usize, u17::new(1)); @@ -1310,6 +1368,62 @@ fn simple_le_be() { } } +#[test] +fn wrapping_add() { + assert_eq!(u7::new(120).wrapping_add(u7::new(1)), u7::new(121)); + assert_eq!(u7::new(120).wrapping_add(u7::new(10)), u7::new(2)); + assert_eq!(u7::new(127).wrapping_add(u7::new(127)), u7::new(126)); +} + +#[test] +fn wrapping_sub() { + assert_eq!(u7::new(120).wrapping_sub(u7::new(1)), u7::new(119)); + assert_eq!(u7::new(10).wrapping_sub(u7::new(20)), u7::new(118)); + assert_eq!(u7::new(0).wrapping_sub(u7::new(1)), u7::new(127)); +} + +#[test] +fn wrapping_mul() { + assert_eq!(u7::new(120).wrapping_mul(u7::new(0)), u7::new(0)); + assert_eq!(u7::new(120).wrapping_mul(u7::new(1)), u7::new(120)); + + // Overflow u7 + assert_eq!(u7::new(120).wrapping_mul(u7::new(2)), u7::new(112)); + + // Overflow the underlying type + assert_eq!(u7::new(120).wrapping_mul(u7::new(3)), u7::new(104)); +} + +#[test] +fn wrapping_div() { + assert_eq!(u7::new(120).wrapping_div(u7::new(1)), u7::new(120)); + assert_eq!(u7::new(120).wrapping_div(u7::new(2)), u7::new(60)); + assert_eq!(u7::new(120).wrapping_div(u7::new(120)), u7::new(1)); + assert_eq!(u7::new(120).wrapping_div(u7::new(121)), u7::new(0)); +} + +#[test] +fn wrapping_shl() { + assert_eq!(u7::new(0b010_1101).wrapping_shl(0), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(1), u7::new(0b101_1010)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(6), u7::new(0b100_0000)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(7), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(8), u7::new(0b101_1010)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(14), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shl(15), u7::new(0b101_1010)); +} + +#[test] +fn wrapping_shr() { + assert_eq!(u7::new(0b010_1101).wrapping_shr(0), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(1), u7::new(0b001_0110)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(5), u7::new(0b000_0001)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(7), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(8), u7::new(0b001_0110)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(14), u7::new(0b010_1101)); + assert_eq!(u7::new(0b010_1101).wrapping_shr(15), u7::new(0b001_0110)); +} + #[test] fn reverse_bits() { const A: u5 = u5::new(0b11101);