diff --git a/common/sealable-trie/src/nodes.rs b/common/sealable-trie/src/nodes.rs index 29a09069..e714f9fd 100644 --- a/common/sealable-trie/src/nodes.rs +++ b/common/sealable-trie/src/nodes.rs @@ -208,11 +208,13 @@ impl<'a, P, S> Node<'a, P, S> { // tag = 0b100v_0000 where v indicates whether the child is // a value reference. let tag = 0x80 | (u8::from(child.is_value()) << 4); - // XXX if let Ok(len) = key.encode_into(key_buf, tag) { buf[len..len + 32].copy_from_slice(child.hash().as_slice()); len + 32 } else { + // We support Extension nodes with invalid keys so this + // method is infallible. It’s easier to handle it than have + // callers deal with errors. return hash_extension_slow_path(*key, child); } } diff --git a/common/sealable-trie/src/nodes/tests.rs b/common/sealable-trie/src/nodes/tests.rs index 99b750da..af13aa71 100644 --- a/common/sealable-trie/src/nodes/tests.rs +++ b/common/sealable-trie/src/nodes/tests.rs @@ -1,5 +1,3 @@ -use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; -use base64::Engine; use lib::hash::CryptoHash; use memory::Ptr; use pretty_assertions::assert_eq; @@ -48,17 +46,24 @@ fn check_node_encoding(node: Node, want: [u8; RawNode::SIZE], want_hash: &str) { assert_eq!(want, raw.0, "Unexpected raw representation"); assert_eq!(node, RawNode(want).decode(), "Bad Raw→Node conversion"); - let want_hash = BASE64_ENGINE.decode(want_hash).unwrap(); - let want_hash = <&[u8; 32]>::try_from(want_hash.as_slice()).unwrap(); - let want_hash = CryptoHash::from(*want_hash); + let want_hash = b64decode(want_hash); assert_eq!(want_hash, node.hash(), "Unexpected hash of {node:?}"); - if let Node::Extension { key, child } = node { let got = super::hash_extension_slow_path(key, &child); assert_eq!(want_hash, got, "Unexpected slow path hash of {node:?}"); } } +/// Decodes base64-encoded CryptoHash; panics on error. +fn b64decode(hash: &str) -> CryptoHash { + use base64::engine::general_purpose::STANDARD as BASE64_ENGINE; + use base64::Engine; + + let hash = BASE64_ENGINE.decode(hash).unwrap(); + let hash = <&[u8; 32]>::try_from(hash.as_slice()).unwrap(); + CryptoHash::from(*hash) +} + #[test] #[rustfmt::skip] fn test_branch_encoding() { @@ -214,6 +219,38 @@ fn test_extension_encoding() { ], "uU9GlH+fEQAnezn3HWuvo/ZSBIhuSkuE2IGjhUFdC04="); } +#[test] +fn test_extension_hash_bad_key() { + let empty_key = bits::Slice::new(&[], 0, 0).unwrap(); + let long_key = bits::Slice::new(&[0; 53], 2, 420).unwrap(); + for (child, sealed, if_empty, if_long) in [ + ( + Reference::node(Some(DEAD), &ONE), + Reference::node(None, &ONE), + b64decode("NXOo9QBg+AbJSM/zh4Rikg8R5otOByKUJfiWhKUiZ5Y="), + b64decode("Q/Er5mIa2gsawPOfF+Q/2XO5l29WZyknw/kyI53tGXo="), + ), + ( + Reference::value(false, &ONE), + Reference::value(true, &ONE), + b64decode("LtVLGf1mBecG3z2Lcq1IXOGo/zQpGxWbXN977zUZMpI="), + b64decode("BMi90/Oois7h3CPNz6y8BB/9agz2mkAePLFQt2hdluM="), + ), + ] { + assert_eq!(if_empty, Node::Extension { key: empty_key, child }.hash()); + assert_eq!(if_long, Node::Extension { key: long_key, child }.hash()); + + assert_eq!( + if_empty, + Node::Extension { key: empty_key, child: sealed }.hash() + ); + assert_eq!( + if_long, + Node::Extension { key: long_key, child: sealed }.hash() + ); + } +} + #[test] #[rustfmt::skip] fn test_value_encoding() {