diff --git a/common/sealable-trie/src/bits.rs b/common/sealable-trie/src/bits.rs index ebeabe96..f87bbdd6 100644 --- a/common/sealable-trie/src/bits.rs +++ b/common/sealable-trie/src/bits.rs @@ -521,7 +521,7 @@ impl Owned { /// Prepends given slice by a specified bit. /// - /// Returns `None` if length (in bits) of the resulting slice would exceed + /// Panics if length (in bits) of the resulting slice would exceed /// `u16::MAX`. /// /// ## Example @@ -531,19 +531,20 @@ impl Owned { /// # use lib::u3::U3; /// /// let suffix = Slice::new(&[255], U3::_1, 5).unwrap(); - /// let got = Owned::unshift(false, suffix).unwrap(); + /// let got = Owned::unshift(false, suffix); /// assert_eq!(Slice::new(&[124], U3::_0, 6).unwrap(), got); /// /// let suffix = Slice::new(&[255], U3::_1, 5).unwrap(); - /// let got = Owned::unshift(true, suffix).unwrap(); + /// let got = Owned::unshift(true, suffix); /// assert_eq!(Slice::new(&[252], U3::_0, 6).unwrap(), got); /// /// let suffix = Slice::new(&[255], U3::_0, 5).unwrap(); - /// let got = Owned::unshift(true, suffix).unwrap(); + /// let got = Owned::unshift(true, suffix); /// assert_eq!(Slice::new(&[255, 255], U3::_7, 6).unwrap(), got); /// ``` - pub fn unshift(bit: bool, suffix: Slice) -> Option { - let length = suffix.length.checked_add(1)?; + // TODO(mina86): Add consistent handling of length > u16::MAX. + pub fn unshift(bit: bool, suffix: Slice) -> Self { + let length = suffix.length.checked_add(1).unwrap(); let (bytes, offset) = if suffix.is_empty() { let offset = suffix.offset.wrapping_dec(); let bytes = alloc::vec![255 * u8::from(bit)]; @@ -558,7 +559,99 @@ impl Owned { let bytes = [core::slice::from_ref(&bit), suffix.bytes()].concat(); (bytes, U3::MAX) }; - Some(Self { bytes, offset, length }) + Self { bytes, offset, length } + } + + /// Append given bit to the slice. + /// + /// Returns `None` if length (in bits) of the resulting slice would exceed + /// `u16::MAX`. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits::{Owned, Slice}; + /// # use lib::u3::U3; + /// + /// let bits = Slice::new(&[0b_0100_1101], U3::_1, 5).unwrap(); + /// let mut bits = Owned::from(bits); + /// + /// bits.push(true); + /// assert_eq!(Slice::new(&[0b_0100_1110], U3::_1, 6).unwrap(), bits); + /// + /// bits.push(false); + /// assert_eq!(Slice::new(&[0b_0100_1110], U3::_1, 7).unwrap(), bits); + /// + /// bits.push(true); + /// assert_eq!(Slice::new(&[0b_0100_1110, 0x80], U3::_1, 8).unwrap(), bits); + /// ``` + // TODO(mina86): Add consistent handling of length > u16::MAX. + pub fn push(&mut self, bit: bool) { + let off = self.underlying_bits_length() % 8; + self.length = self.length.checked_add(1).unwrap(); + let mask = 0x80 >> off; + match self.bytes.last_mut() { + Some(byte) if off != 0 => { + // If self.bytes is non-empty and we’re not adding msb of a new + // byte (i.e. off != 0), modify the last byte. + *byte = (*byte & !mask) | (mask * u8::from(bit)); + } + _ => { + // Otherwise, either self.bytes is empty (and thus we’re adding + // a new byte with given bit set) or we’re aligned at the byte + // boundary (and we’re adding a new byte with msb set). + self.bytes.push(mask * u8::from(bit)); + } + } + } + + /// Returns the last bit in the slice shrinking the slice by one bit. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits::{Owned, Slice}; + /// # use lib::u3::U3; + /// + /// let slice = Slice::new(&[0x60], U3::_0, 3).unwrap(); + /// let mut bits = Owned::from(slice); + /// assert_eq!(Some(true), bits.pop_back()); + /// assert_eq!(Some(true), bits.pop_back()); + /// assert_eq!(Some(false), bits.pop_back()); + /// assert_eq!(None, bits.pop_back()); + /// ``` + pub fn pop_back(&mut self) -> Option { + self.length = self.length.checked_sub(1)?; + let off = self.underlying_bits_length() % 8; + let bit = *self.bytes.last().unwrap() & (0x80 >> off); + if off == 0 { + self.bytes.pop(); + } + Some(bit != 0) + } + + /// Sets the last bit in the slice; panics if slice is empty. + /// + /// ## Example + /// + /// ``` + /// # use sealable_trie::bits::{Owned, Slice}; + /// # use lib::u3::U3; + /// + /// let slice = Slice::new(&[0x0f], U3::_4, 4).unwrap(); + /// let mut bits = Owned::from(slice); + /// + /// bits.set_last(false); + /// assert_eq!(Slice::new(&[0x0e], U3::_4, 4).unwrap(), bits); + /// + /// bits.set_last(true); + /// assert_eq!(Slice::new(&[0x0f], U3::_4, 4).unwrap(), bits); + /// ``` + pub fn set_last(&mut self, bit: bool) { + let bits = self.underlying_bits_length(); + let last = self.bytes.last_mut().unwrap(); + let mask = 0x80 >> ((bits - 1) % 8); + *last = (*last & !mask) | (mask * u8::from(bit)); } /// Concatenates a [`Slice`] with [`Owned`]. @@ -582,6 +675,7 @@ impl Owned { /// let got = Owned::concat(prefix, suffix).unwrap(); /// assert_eq!(Slice::new(&[0, 126], U3::_6, 9).unwrap(), got); /// ``` + // TODO(mina86): Add consistent handling of length > u16::MAX. pub fn concat( prefix: Slice, suffix: Slice, @@ -625,6 +719,12 @@ impl Owned { phantom: Default::default(), } } + + /// Returns total number of underlying bits, i.e. bits in the slice plus the + /// offset. + fn underlying_bits_length(&self) -> usize { + usize::from(self.offset) + usize::from(self.length) + } } impl core::cmp::PartialEq for Slice<'_> { @@ -908,8 +1008,7 @@ fn test_pop() { #[test] fn test_owned_unshift() { - for offset in 0..7u8 { - let offset = U3::try_from(offset).unwrap(); + for offset in U3::all() { let slice = Slice::new(&[255], offset, 1).unwrap(); let want = offset .checked_dec() @@ -918,8 +1017,38 @@ fn test_owned_unshift() { |offset| Slice::new(&[255], offset, 2), ) .unwrap(); - let got = Owned::unshift(true, slice).unwrap(); - assert_eq!(want, got, "offset: {offset}"); + assert_eq!(want, Owned::unshift(true, slice), "offset: {offset}"); + } +} + +#[test] +fn test_owned_push() { + let mut bits = Owned::from(Slice::new(&[255], U3::_1, 1).unwrap()); + + let mut push = |bit, want| { + let want = Slice::new(want, U3::_1, bits.length + 1).unwrap(); + bits.push(bit != 0); + assert_eq!(want, bits); + }; + + push(1, &[0b_0110_0000]); + push(1, &[0b_0111_0000]); + push(0, &[0b_0111_0000]); + push(0, &[0b_0111_0000]); + push(1, &[0b_0111_0010]); + push(1, &[0b_0111_0011]); + push(1, &[0b_0111_0011, 0b_1000_0000]); +} + +#[test] +fn test_owned_push_from_empty() { + for offset in U3::all() { + let mut bits = Owned::from(Slice::new(&[], offset, 0).unwrap()); + for length in 1..=16 { + let want = Slice::new(&[255, 255, 255], offset, length).unwrap(); + bits.push(true); + assert_eq!(want, bits); + } } } diff --git a/common/sealable-trie/src/trie/del.rs b/common/sealable-trie/src/trie/del.rs index 20066bcc..90b45e39 100644 --- a/common/sealable-trie/src/trie/del.rs +++ b/common/sealable-trie/src/trie/del.rs @@ -113,7 +113,7 @@ impl<'a, A: memory::Allocator> Context<'a, A> { let child = children[1 - side]; Ok(self .maybe_pop_extension(child, &|key| { - bits::Owned::unshift(side == 0, key.into()).unwrap() + bits::Owned::unshift(side == 0, key.into()) })? .unwrap_or_else(|| { Action::Ext(