Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sealable-trie: propagate decode error rather than panicking #53

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 48 additions & 45 deletions common/sealable-trie/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,17 @@ impl<'a, P, S> Node<'a, P, S> {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BadExtensionKey;
/// Error returned when a raw node or a raw reference cannot be decoded.
///
/// Produced by `RawNode::decode` and `Reference::from_raw` instead of
/// panicking on a malformed representation.
#[derive(Copy, Clone, PartialEq, Eq, Debug, derive_more::Display)]
pub enum DecodeError {
/// The key stored in an Extension node failed to decode.
#[display(fmt = "Invalid extension key")]
BadExtensionKey,
/// A Value node is malformed: its leading four bytes are not
/// `0xC000_0000` or its child is not a node reference.
#[display(fmt = "Invalid Value node")]
BadValueNode,
/// A node reference holds a pointer which `Ptr::new` rejected.
#[display(fmt = "Invalid node reference")]
BadNodeRef,
/// A value reference has bits set other than the value marker and the
/// sealed flag.
#[display(fmt = "Invalid value reference")]
BadValueRef,
}

impl<'a> Node<'a> {
/// Builds raw representation of given node.
Expand Down Expand Up @@ -311,36 +320,32 @@ impl RawNode {
///
/// In debug builds panics if `node` holds malformed representation, i.e. if
/// any unused bits (which must be cleared) are set.
// TODO(mina86): Convert debug_assertions to the method returning Result.
pub fn decode(&self) -> Node {
pub fn decode(&self) -> Result<Node, DecodeError> {
let (left, right) = self.halfs();
let right = Reference::from_raw(right);
let right = Reference::from_raw(right)?;
// `>> 6` to grab the two most significant bits only.
let tag = self.first() >> 6;
if tag == 0 || tag == 1 {
Ok(if tag == 0 || tag == 1 {
// Branch
Node::Branch { children: [Reference::from_raw(left), right] }
Node::Branch { children: [Reference::from_raw(left)?, right] }
} else if tag == 2 {
// Extension
let key = Slice::decode(left, 0x80).unwrap_or_else(|| {
panic!("Failed decoding raw: {self:?}");
});
let key = Slice::decode(left, 0x80)
.ok_or(DecodeError::BadExtensionKey)?;
Node::Extension { key, child: right }
} else {
// Value
let (num, value) = stdx::split_array_ref::<4, 32, 36>(left);
let num = u32::from_be_bytes(*num);
debug_assert_eq!(
0xC000_0000, num,
"Failed decoding raw node: {self:?}",
);
let value = ValueRef::new((), value.into());
let child = right.try_into().unwrap_or_else(|_| {
debug_assert!(false, "Failed decoding raw node: {self:?}");
NodeRef::new(None, &CryptoHash::DEFAULT)
});
Node::Value { value, child }
}
let (num, hash) = Reference::into_parts(left);
if num != 0xC000_0000 {
return Err(DecodeError::BadValueNode);
}
Node::Value {
value: ValueRef::new((), hash),
child: right
.try_into()
.map_err(|_| DecodeError::BadValueNode)?,
}
})
}

/// Returns the first byte in the raw representation.
Expand Down Expand Up @@ -395,41 +400,39 @@ impl<'a> Reference<'a> {
}
}

/// Breaks a raw reference into its leading 32-bit word and trailing hash.
///
/// The first four bytes are read as a big-endian `u32`; the remaining 32
/// bytes are reinterpreted as a hash.  This is an internal helper which
/// performs no validation whatsoever; callers are responsible for checking
/// the returned word.
fn into_parts(bytes: &'a [u8; 36]) -> (u32, &'a CryptoHash) {
    let (head, tail) = stdx::split_array_ref::<4, 32, 36>(bytes);
    let word = u32::from_be_bytes(*head);
    let hash: &'a CryptoHash = tail.into();
    (word, hash)
}

/// Parses bytes to form a raw node reference representation.
///
/// Assumes that the bytes are trusted.  I.e. doesn’t verify that the most
/// significant bit is zero or that, if the second bit is one, then the
/// pointer value must be zero.
///
/// In debug builds, panics if `bytes` has non-canonical representation,
/// i.e. any unused bits are set.
// TODO(mina86): Convert debug_assertions to the method returning Result.
fn from_raw(bytes: &'a [u8; 36]) -> Self {
let (ptr, hash) = stdx::split_array_ref::<4, 32, 36>(bytes);
let ptr = u32::from_be_bytes(*ptr);
let hash = hash.into();
if ptr & 0x4000_0000 == 0 {
// The two most significant bits must be zero.
debug_assert_eq!(
0,
ptr & 0xC000_0000,
"Failed decoding Reference: {bytes:?}"
);
let ptr = Ptr::new_truncated(ptr);
fn from_raw(bytes: &'a [u8; 36]) -> Result<Self, DecodeError> {
let (ptr, hash) = Self::into_parts(bytes);
Ok(if ptr & 0x4000_0000 == 0 {
// The two most significant bits must be zero. Ptr::new fails if
// they aren’t.
let ptr = Ptr::new(ptr).map_err(|_| DecodeError::BadNodeRef)?;
Self::Node(NodeRef { ptr, hash })
} else {
// * The second most significant bit (so 0x4000_0000) is always set.
// * The third most significant bit (so 0x2000_0000) specifies
// whether value is sealed.
// * All other bits are cleared.
debug_assert_eq!(
0x4000_0000,
ptr & !0x2000_0000,
"Failed decoding Reference: {bytes:?}"
);
if ptr & !0x2000_0000 != 0x4000_0000 {
return Err(DecodeError::BadValueRef);
}
let is_sealed = ptr & 0x2000_0000 != 0;
Self::Value(ValueRef { is_sealed, hash })
}
})
}

/// Encodes the node reference into the buffer.
Expand Down
4 changes: 2 additions & 2 deletions common/sealable-trie/src/nodes/stress_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn stress_test_raw_encoding_round_trip() {
let mut raw = RawNode([0; RawNode::SIZE]);
for _ in 0..get_iteration_count(1) {
gen_random_raw_node(&mut rng, &mut raw.0);
let node = raw.decode();
let node = raw.decode().unwrap();
// Test RawNode→Node→RawNode round trip conversion.
assert_eq!(Ok(raw), node.encode(), "node: {node:?}");
}
Expand Down Expand Up @@ -91,7 +91,7 @@ fn stress_test_node_encoding_round_trip() {
let node = gen_random_node(&mut rng, &mut buf);

let raw = super::tests::raw_from_node(&node);
assert_eq!(node, raw.decode(), "Failed decoding Raw: {raw:?}");
assert_eq!(Ok(node), raw.decode(), "Failed decoding Raw: {raw:?}");
}
}

Expand Down
4 changes: 2 additions & 2 deletions common/sealable-trie/src/nodes/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const TWO: CryptoHash = CryptoHash([2; 32]);
pub(super) fn raw_from_node(node: &Node) -> RawNode {
let raw = node.encode().unwrap_or_else(|err| panic!("{err:?}: {node:?}"));
assert_eq!(
*node,
Ok(*node),
raw.decode(),
"Node → RawNode → Node gave different result:\n Raw: {raw:?}"
);
Expand All @@ -44,7 +44,7 @@ pub(super) fn raw_from_node(node: &Node) -> RawNode {
fn check_node_encoding(node: Node, want: [u8; RawNode::SIZE], want_hash: &str) {
let raw = raw_from_node(&node);
assert_eq!(want, raw.0, "Unexpected raw representation");
assert_eq!(node, RawNode(want).decode(), "Bad Raw→Node conversion");
assert_eq!(Ok(node), RawNode(want).decode(), "Bad Raw→Node conversion");

let want_hash = b64decode(want_hash);
assert_eq!(want_hash, node.hash(), "Unexpected hash of {node:?}");
Expand Down
22 changes: 17 additions & 5 deletions common/sealable-trie/src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,20 @@ pub enum Error {
NotFound,
#[display(fmt = "Not enough space")]
OutOfMemory,
#[display(fmt = "Error decoding node: {}", "_0")]
BadRawNode(crate::nodes::DecodeError),
}

impl From<memory::OutOfMemory> for Error {
#[inline]
fn from(_: memory::OutOfMemory) -> Self { Self::OutOfMemory }
}

impl From<crate::nodes::DecodeError> for Error {
#[inline]
fn from(err: crate::nodes::DecodeError) -> Self { Self::BadRawNode(err) }
}

/// Result alias whose error type defaults to this module’s [`Error`].
type Result<T, E = Error> = ::core::result::Result<T, E>;
/// Byte array holding one raw node (`RawNode::SIZE` bytes); this is the value
/// type the trie’s allocator is required to store.
type Value = [u8; crate::nodes::RawNode::SIZE];

Expand Down Expand Up @@ -175,7 +183,7 @@ impl<A: memory::Allocator<Value = Value>> Trie<A> {
let mut node_hash = self.root_hash.clone();
loop {
let node = self.alloc.get(node_ptr.ok_or(Error::Sealed)?);
let node = <&RawNode>::from(node).decode();
let node = <&RawNode>::from(node).decode()?;
debug_assert_eq!(node_hash, node.hash());

let child = match node {
Expand Down Expand Up @@ -324,20 +332,24 @@ impl<A: memory::Allocator<Value = Value>> Trie<A> {
println!(" (sealed)");
return;
};
match <&RawNode>::from(self.alloc.get(ptr)).decode() {
Node::Branch { children } => {
let node = <&RawNode>::from(self.alloc.get(ptr));
match node.decode() {
Ok(Node::Branch { children }) => {
println!(" Branch");
print_ref(children[0], depth + 2);
print_ref(children[1], depth + 2);
}
Node::Extension { key, child } => {
Ok(Node::Extension { key, child }) => {
println!(" Extension {key}");
print_ref(child, depth + 2);
}
Node::Value { value, child } => {
Ok(Node::Value { value, child }) => {
println!(" Value {}", value.hash);
print_ref(Reference::from(child), depth + 2);
}
Err(err) => {
println!(" BadRawNode: {err}: {node:?}");
}
}
}
}
Expand Down
43 changes: 26 additions & 17 deletions common/sealable-trie/src/trie/del.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
fn handle(&mut self, nref: NodeRef, from_ext: bool) -> Result<Action> {
let ptr = nref.ptr.ok_or(Error::Sealed)?;
let node = RawNode(*self.wlog.allocator().get(ptr));
let node = node.decode();
let node = node.decode()?;
debug_assert_eq!(*nref.hash, node.hash());

match node {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
(children[0], child)
};
let node = RawNode::branch(left, right);
return Ok(Action::Ref(self.set_node(ptr, node)));
return self.set_node(ptr, node).map(Action::Ref);
}

// The child has been deleted. We need to convert this Branch into an
Expand All @@ -114,7 +114,7 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
Ok(self
.maybe_pop_extension(child, &|key| {
bits::Owned::unshift(side == 0, key).unwrap()
})
})?
.unwrap_or_else(|| {
Action::Ext(
bits::Owned::bit(side == 0, key_offset),
Expand Down Expand Up @@ -162,7 +162,7 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
let action = self
.maybe_pop_extension(Reference::Node(child), &|key| {
key.into()
});
})?;
if let Some(action) = action {
return Ok(action);
}
Expand All @@ -172,21 +172,29 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {

// Traverse into the child and handle that.
let action = self.handle(child, false)?;
Ok(match self.ref_from_action(action)? {
match self.ref_from_action(action)? {
None => {
// We’re deleting the child which means we need to delete the
// Value node and replace parent’s reference to ValueRef.
self.del_node(ptr);
let value = ValueRef::new(false, value.hash);
Action::Ref(value.into())
Ok(Action::Ref(value.into()))
}
Some(OwnedRef::Node(child_ptr, hash)) => {
let child = NodeRef::new(child_ptr, &hash);
let node = RawNode::value(value, child);
Action::Ref(self.set_node(ptr, node))
self.set_node(ptr, node).map(Action::Ref)
}
Some(OwnedRef::Value(..)) => unreachable!(),
})
Some(OwnedRef::Value(..)) => {
// The only possible way we’ve reached here is if the self.handle
// call above recursively called self.handle_value (since this
// method is the only one which may produce Value references).  But
// if that happens, it means that we had a Value node whose child
// was another Value node.  This is an invalid trie (since Value
// may only point at Branch or Extension) so we report an error.
Err(Error::BadRawNode(crate::nodes::DecodeError::BadValueNode))
}
}
}

/// If `child` is a node reference pointing at an Extension node, pops that
Expand All @@ -195,27 +203,28 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
&mut self,
child: Reference,
make_key: &dyn Fn(bits::Slice) -> bits::Owned,
) -> Option<Action> {
) -> Result<Option<Action>> {
if let Reference::Node(NodeRef { ptr: Some(ptr), hash }) = child {
let node = RawNode(*self.wlog.allocator().get(ptr));
let node = node.decode();
let node = node.decode()?;
debug_assert_eq!(*hash, node.hash());

if let Node::Extension { key, child } = node {
// Drop the child Extension and merge keys.
self.del_node(ptr);
return Some(Action::Ext(make_key(key), OwnedRef::from(child)));
let action = Action::Ext(make_key(key), OwnedRef::from(child));
return Ok(Some(action));
}
}
None
Ok(None)
}

/// Sets value of a node cell at given address and returns an [`OwnedRef`]
/// pointing at the node.
fn set_node(&mut self, ptr: Ptr, node: RawNode) -> OwnedRef {
let hash = node.decode().hash();
fn set_node(&mut self, ptr: Ptr, node: RawNode) -> Result<OwnedRef> {
let hash = node.decode()?.hash();
self.wlog.set(ptr, *node);
OwnedRef::Node(Some(ptr), hash)
Ok(OwnedRef::Node(Some(ptr), hash))
}

/// Frees a node.
Expand All @@ -236,7 +245,7 @@ impl<'a, A: memory::Allocator<Value = super::Value>> Context<'a, A> {
for chunk in key.as_slice().chunks().rev() {
let node = RawNode::extension(chunk, child.to_ref()).unwrap();
let ptr = self.wlog.alloc(node.0)?;
child = OwnedRef::Node(Some(ptr), node.decode().hash());
child = OwnedRef::Node(Some(ptr), node.decode()?.hash());
}

Ok(Some(child))
Expand Down
Loading
Loading