Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cache and traverse nmt sub tree roots #549

Merged
merged 25 commits into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e4a450a
initial sub tree root traversal code
evan-forbes Jul 15, 2022
e4c2da5
use nmt wrapper when generating commitments
evan-forbes Jul 15, 2022
339370e
typo
evan-forbes Jul 15, 2022
8dc9003
fix doc typos
evan-forbes Jul 15, 2022
ee525f3
remove unused testutil code
evan-forbes Jul 15, 2022
1684c44
update hardcoded test
evan-forbes Jul 15, 2022
7fddd80
fix docs left <-> right
evan-forbes Jul 20, 2022
30af4bc
fix docs left <-> right
evan-forbes Jul 20, 2022
2e300fc
fix typo
evan-forbes Jul 20, 2022
c177384
fix typo
evan-forbes Jul 20, 2022
b335f2e
chore: move power of two code to an exported util package
evan-forbes Jul 27, 2022
d68a62e
add subtree root code
evan-forbes Jul 29, 2022
e0d45e0
Revert "chore: move power of two code to an exported util package"
evan-forbes Jul 29, 2022
c006232
add docs
evan-forbes Aug 17, 2022
aa55322
Merge branch 'main' into evan/msg-inclusion-api
evan-forbes Aug 17, 2022
2976947
move path code to a different PR
evan-forbes Aug 17, 2022
492597c
use proper name for nmt node visitor in docs
evan-forbes Aug 17, 2022
568e45d
consistent naming
evan-forbes Aug 17, 2022
770e06a
consistent name
evan-forbes Aug 17, 2022
e2a15f7
use normal error message formatting and not consts
evan-forbes Aug 17, 2022
38973a0
fix comment
evan-forbes Aug 17, 2022
fe7efd6
Merge branch 'main' into evan/msg-inclusion-api
evan-forbes Aug 17, 2022
0c8f884
Merge branch 'main' into evan/msg-inclusion-api
evan-forbes Aug 18, 2022
1518b82
remove comment
evan-forbes Sep 2, 2022
b6745f6
PR feedback
evan-forbes Sep 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions pkg/inclusion/nmt_caching.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package inclusion

import (
"fmt"

"github.com/celestiaorg/nmt"
"github.com/celestiaorg/rsmt2d"
"github.com/tendermint/tendermint/pkg/da"
"github.com/tendermint/tendermint/pkg/wrapper"
)

// WalkInstruction wraps the bool type to indicate the direction that should be
// used while traversing a binary tree
type WalkInstruction bool

const (
WalkLeft = false
WalkRight = true
)

// subTreeRootCacher keep track of all the inner nodes of an nmt using a simple
// map. Note: this cacher does not cache individual leaves or their hashes, only
// inner nodes.
type subTreeRootCacher struct {
cache map[string][2]string
}

func newSubTreeRootCacher() *subTreeRootCacher {
return &subTreeRootCacher{cache: make(map[string][2]string)}
}

// Visit fullfills the nmt.NodeVisitorFn function definition. It stores each inner
// node in a simple map, which can later be used to walk the tree. This function
// is called by the nmt when calculating the root.
func (strc *subTreeRootCacher) Visit(hash []byte, children ...[]byte) {
switch len(children) {
case 2:
strc.cache[string(hash)] = [2]string{string(children[0]), string(children[1])}
case 1:
return
default:
panic("unexpected visit")
}
}

// walk recursively traverses the subTreeRootCacher's internal tree by using the
// provided sub tree root and path. The provided path should be a []bool, false
// indicating that the first child node (left most node) should be used to find
// the next path, and the true indicating that the second (right) should be used.
// walk throws an error if the sub tree cannot be found.
func (strc subTreeRootCacher) walk(root []byte, path []WalkInstruction) ([]byte, error) {
// return if we've reached the end of the path
if len(path) == 0 {
return root, nil
}
// try to lookup the provided sub root
children, has := strc.cache[string(root)]
if !has {
// note: we might want to consider panicing here
return nil, fmt.Errorf("did not find sub tree root: %v", root)
}

// continue to traverse the tree by recursively calling this function on the next root
switch path[0] {
case WalkLeft:
return strc.walk([]byte(children[0]), path[1:])
case WalkRight:
return strc.walk([]byte(children[1]), path[1:])
default:
// this is unreachable code, but the compiler doesn't recognize this somehow
panic("bool other than true or false, computers were a mistake, everything is a lie, math is fake.")
}
}

// EDSSubTreeRootCacher caches the inner nodes for each row so that we can
// traverse it later to check for message inclusion. NOTE: Currently this has to
// use a leaky abstraction (see docs on counter field below), and is not
// threadsafe, but with a future refactor, we could simply read from rsmt2d and
// not use the tree constructor which would fix both of these issues.
type EDSSubTreeRootCacher struct {
caches []*subTreeRootCacher
squareSize uint64
// counter is used to ignore columns NOTE: this is a leaky abstraction that
// we make because rsmt2d is used to generate the roots for us, so we have
// to assume that it will generate a row root every other tree contructed.
// This is also one of the reasons this implementation is not thread safe.
// Please see note above on a better refactor.
counter int
}

func NewCachedSubtreeCacher(squareSize uint64) *EDSSubTreeRootCacher {
return &EDSSubTreeRootCacher{
caches: []*subTreeRootCacher{},
squareSize: squareSize,
}
}

// Constructor fullfills the rsmt2d.TreeCreatorFn by keeping a pointer to the
// cache and embedding it as a nmt.NodeVisitor into a new wrapped nmt.
func (stc *EDSSubTreeRootCacher) Constructor() rsmt2d.Tree {
// see docs of counter field for more
// info. if the counter is even or == 0, then we make the assumption that we
// are creating a tree for a row
var newTree wrapper.ErasuredNamespacedMerkleTree
switch stc.counter % 2 {
case 0:
strc := newSubTreeRootCacher()
stc.caches = append(stc.caches, strc)
newTree = wrapper.NewErasuredNamespacedMerkleTree(stc.squareSize, nmt.NodeVisitor(strc.Visit))
default:
newTree = wrapper.NewErasuredNamespacedMerkleTree(stc.squareSize)
}

stc.counter++
return &newTree
}

// GetSubTreeRoot traverses the nmt of the selected row and returns the
// subtree root. An error is thrown if the subtree cannot be found.
func (stc *EDSSubTreeRootCacher) GetSubTreeRoot(dah da.DataAvailabilityHeader, row int, path []WalkInstruction) ([]byte, error) {
if len(stc.caches) != len(dah.RowsRoots) {
return nil, fmt.Errorf("data availability header has unexpected number of row roots: expected %d got %d", len(stc.caches), len(dah.RowsRoots))
}
if row >= len(stc.caches) {
return nil, fmt.Errorf("row exceeds range of cache: max %d got %d", len(stc.caches), row)
}
return stc.caches[row].walk(dah.RowsRoots[row], path)
}
184 changes: 184 additions & 0 deletions pkg/inclusion/nmt_caching_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package inclusion

import (
"testing"

"github.com/celestiaorg/celestia-app/testutil/coretestutil"
"github.com/celestiaorg/nmt"
"github.com/celestiaorg/rsmt2d"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/pkg/consts"
"github.com/tendermint/tendermint/pkg/da"
"github.com/tendermint/tendermint/pkg/wrapper"
)

func TestWalkCachedSubTreeRoot(t *testing.T) {
// create the first main tree
strc := newSubTreeRootCacher()
oss := uint64(8)
tr := wrapper.NewErasuredNamespacedMerkleTree(oss, nmt.NodeVisitor(strc.Visit))
d := []byte{0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8}
for i := 0; i < 8; i++ {
tr.Push(d, rsmt2d.SquareIndex{
Axis: uint(rsmt2d.Row),
Cell: uint(i),
})
}
highestRoot := tr.Root()

// create a short sub tree
shortSubTree := wrapper.NewErasuredNamespacedMerkleTree(oss)
for i := 0; i < 2; i++ {
shortSubTree.Push(d, rsmt2d.SquareIndex{
Axis: uint(rsmt2d.Row),
Cell: uint(i),
})
}
shortSTR := shortSubTree.Root()

// create a tall sub tree root
tallSubTree := wrapper.NewErasuredNamespacedMerkleTree(oss)
for i := 0; i < 4; i++ {
tallSubTree.Push(d, rsmt2d.SquareIndex{
Axis: uint(rsmt2d.Row),
Cell: uint(i),
})
}
tallSTR := tallSubTree.Root()

type test struct {
name string
path []WalkInstruction
subTreeRoot []byte
expectedError string
}

tests := []test{
{
"left most short sub tree",
[]WalkInstruction{WalkLeft, WalkLeft},
shortSTR,
"",
},
{
"left middle short sub tree",
[]WalkInstruction{WalkLeft, WalkRight},
shortSTR,
"",
},
{
"right middle short sub tree",
[]WalkInstruction{WalkRight, WalkLeft},
shortSTR,
"",
},
{
"right most short sub tree",
[]WalkInstruction{WalkRight, WalkRight},
shortSTR,
"",
},
{
"left most tall sub tree",
[]WalkInstruction{WalkLeft},
tallSTR,
"",
},
{
"right most tall sub tree",
[]WalkInstruction{WalkRight},
tallSTR,
"",
},
rach-id marked this conversation as resolved.
Show resolved Hide resolved
{
"right most tall sub tree",
[]WalkInstruction{WalkRight, WalkRight, WalkRight, WalkRight},
tallSTR,
"did not find sub tree root",
},
}

for _, tt := range tests {
foundSubRoot, err := strc.walk(highestRoot, tt.path)
if tt.expectedError != "" {
require.Error(t, err, tt.name)
assert.Contains(t, err.Error(), tt.expectedError, tt.name)
continue
}

require.NoError(t, err)
require.Equal(t, tt.subTreeRoot, foundSubRoot, tt.name)
}
}

func TestEDSSubRootCacher(t *testing.T) {
oss := uint64(8)
d := coretestutil.GenerateRandNamespacedRawData(uint32(oss*oss), consts.NamespaceSize, consts.ShareSize-consts.NamespaceSize)
stc := NewCachedSubtreeCacher(oss)

eds, err := rsmt2d.ComputeExtendedDataSquare(d, consts.DefaultCodec(), stc.Constructor)
require.NoError(t, err)

dah := da.NewDataAvailabilityHeader(eds)

for i := range dah.RowsRoots[:oss] {
expectedSubTreeRoots := calculateSubTreeRoots(eds.Row(uint(i))[:oss], 2)
require.NotNil(t, expectedSubTreeRoots)
// note: the depth is one greater than expected because we're dividing
// the row in half when we calculate the expected roots.
result, err := stc.GetSubTreeRoot(dah, i, []WalkInstruction{false, false, false})
require.NoError(t, err)
assert.Equal(t, expectedSubTreeRoots[0], result)
}
}

// calculateSubTreeRoots generates the subtree roots for a given row. If the
// selected depth is too deep for the tree, nil is returned. It relies on
// passing a row whose length is a power of 2 and assumes that the row is
// **NOT** extended since calculating subtree root for erasure data using the
// nmt wrapper makes this difficult.
func calculateSubTreeRoots(row [][]byte, depth int) [][]byte {
subLeafRange := len(row)
for i := 0; i < depth; i++ {
subLeafRange = subLeafRange / 2
}

if subLeafRange == 0 || subLeafRange%2 != 0 {
return nil
}

count := len(row) / subLeafRange
subTreeRoots := make([][]byte, count)
chunks := chunkSlice(row, subLeafRange)
for i, rowChunk := range chunks {
tr := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(row)))
for j, r := range rowChunk {
c := (i * subLeafRange) + j
tr.Push(r, rsmt2d.SquareIndex{
Axis: uint(rsmt2d.Row),
Cell: uint(c),
})
}
subTreeRoots[i] = tr.Root()
}

return subTreeRoots
}

func chunkSlice(slice [][]byte, chunkSize int) [][][]byte {
var chunks [][][]byte
for i := 0; i < len(slice); i += chunkSize {
end := i + chunkSize

// necessary check to avoid slicing beyond
// slice capacity
if end > len(slice) {
end = len(slice)
}

chunks = append(chunks, slice[i:end])
}

return chunks
}
28 changes: 28 additions & 0 deletions testutil/coretestutil/core.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package coretestutil

import (
"bytes"
"math/rand"
"sort"
)

func GenerateRandNamespacedRawData(total, nidSize, leafSize uint32) [][]byte {
data := make([][]byte, total)
for i := uint32(0); i < total; i++ {
nid := make([]byte, nidSize)
rand.Read(nid)
data[i] = nid
}
sortByteArrays(data)
for i := uint32(0); i < total; i++ {
d := make([]byte, leafSize)
rand.Read(d)
data[i] = append(data[i], d...)
}

return data
}

func sortByteArrays(src [][]byte) {
sort.Slice(src, func(i, j int) bool { return bytes.Compare(src[i], src[j]) < 0 })
}