Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Fix proof direction and key bit order #22

Merged
merged 9 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,9 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
// Returns an array of sibling nodes, the leaf hash found at that path and the
// leaf data. If the leaf is a placeholder, the leaf data is nil.
func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byte, []byte, []byte, error) {
var sideNodes [][]byte
// Side nodes for the path. Nodes are inserted in reverse order, then the
// slice is reversed at the end.
sideNodes := make([][]byte, 0, smt.depth())

if bytes.Equal(root, smt.th.placeholder()) {
// If the root is a placeholder, there are no sidenodes to return.
Expand All @@ -345,20 +347,16 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt

// Get sidenode depending on whether the path bit is on or off.
if getBitAtFromMSB(path, i) == right {
sideNodes = append(sideNodes, nil)
copy(sideNodes[1:], sideNodes)
sideNodes[0] = leftNode
sideNodes = append(sideNodes, leftNode)
nodeHash = rightNode
} else {
sideNodes = append(sideNodes, nil)
copy(sideNodes[1:], sideNodes)
sideNodes[0] = rightNode
sideNodes = append(sideNodes, rightNode)
nodeHash = leftNode
}

if bytes.Equal(nodeHash, smt.th.placeholder()) {
// If the node is a placeholder, we've reached the end.
return sideNodes, nodeHash, nil, nil
return reverseSideNodes(sideNodes), nodeHash, nil, nil
}

currentData, err = smt.ms.Get(nodeHash)
Expand All @@ -370,7 +368,7 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt
}
}

return sideNodes, nodeHash, currentData, err
return reverseSideNodes(sideNodes), nodeHash, currentData, err
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
}

// Prove generates a Merkle proof for a key.
Expand Down
18 changes: 13 additions & 5 deletions smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,17 @@ func TestSparseMerkleTreeKnown(t *testing.T) {
}
}

// Test tree operations when two leafs are immediate neighbours.
// Test tree operations when two leafs are immediate neighbors.
func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
h := newDummyHasher(sha256.New())
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, h)
var value []byte
var err error

// Make two neighbouring keys.
// Make two neighboring keys.
//
// The dummy hash function excepts keys to prefixed with four bytes of 0,
// The dummy hash function expects keys to prefixed with four bytes of 0,
// which will cause it to return the preimage itself as the digest, without
// the first four bytes.
key1 := make([]byte, h.Size()+4)
Expand All @@ -297,7 +297,8 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
key1[h.Size()+4-1] = byte(0)
key2 := make([]byte, h.Size()+4)
copy(key2, key1)
setBitAtFromMSB(key2, (h.Size()+4)*8-1)
// We make key2's least significant bit different than key1's
key2[h.Size()+4-1] = byte(1)

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
Expand All @@ -315,14 +316,21 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
if !bytes.Equal([]byte("testValue1"), value) {
t.Error("did not get correct value when getting non-empty key")
}

value, err = smt.Get(key2)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}

proof, err := smt.Prove(key1)
if err != nil {
t.Errorf("returned error when proving key: %v", err)
}
if len(proof.SideNodes) != 256 {
t.Errorf("unexpected proof size")
}
}

// Test base case tree delete operations with a few keys.
Expand Down
8 changes: 8 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,11 @@ func emptyBytes(length int) []byte {
b := make([]byte, length)
return b
}

func reverseSideNodes(sideNodes [][]byte) [][]byte {
for left, right := 0, len(sideNodes)-1; left < right; left, right = left+1, right-1 {
sideNodes[left], sideNodes[right] = sideNodes[right], sideNodes[left]
}

return sideNodes
}