From 5ee5b964e7ac88bd3ff03de1cb96db2e0d2f05f8 Mon Sep 17 00:00:00 2001 From: evan-forbes Date: Tue, 21 Sep 2021 10:00:53 -0500 Subject: [PATCH] refactor to better accomodate real world use cases (celestia node) Co-authored-by: rene <41963722+renaynay@users.noreply.github.com> --- pkg/consts/consts.go | 3 ++ pkg/da/data_availability_header.go | 57 +++++++++++-------------- pkg/da/data_availability_header_test.go | 43 ++++++++++++++----- pkg/wrapper/nmt_wrapper_test.go | 8 ++-- types/shares_test.go | 2 +- 5 files changed, 65 insertions(+), 48 deletions(-) diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index c7d9025fb2..59acaf4b08 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -63,5 +63,8 @@ var ( // NewBaseHashFunc change accordingly if another hash.Hash should be used as a base hasher in the NMT: NewBaseHashFunc = sha256.New + // DefaultCodec is the defualt codec creator used for data erasure + // TODO(ismail): for better efficiency and a larger number shares + // we should switch to the rsmt2d.LeopardFF16 codec: DefaultCodec = rsmt2d.NewRSGF8Codec ) diff --git a/pkg/da/data_availability_header.go b/pkg/da/data_availability_header.go index b9293f7204..61e1f3c0c3 100644 --- a/pkg/da/data_availability_header.go +++ b/pkg/da/data_availability_header.go @@ -14,8 +14,8 @@ import ( ) const ( - maxDAHSize = consts.MaxSquareSize * 2 - minDAHSize = consts.MinSquareSize * 2 + maxExtendedSquareWidth = consts.MaxSquareSize * 2 + minExtendedSquareWidth = consts.MinSquareSize * 2 ) // DataAvailabilityHeader (DAHeader) contains the row and column roots of the erasure @@ -38,10 +38,23 @@ type DataAvailabilityHeader struct { } // NewDataAvailabilityHeader generates a DataAvailability header using the provided square size and shares -func NewDataAvailabilityHeader(squareSize uint64, shares [][]byte) (DataAvailabilityHeader, error) { +func NewDataAvailabilityHeader(eds *rsmt2d.ExtendedDataSquare) DataAvailabilityHeader { + // generate the row and col roots using the EDS + dah := DataAvailabilityHeader{ + RowsRoots: eds.RowRoots(), + ColumnRoots: eds.ColRoots(), + } + + // generate the hash of the data using the new roots + dah.Hash() + + return dah +} + +func ExtendShares(squareSize uint64, shares [][]byte) (*rsmt2d.ExtendedDataSquare, error) { // Check that square size is with range if squareSize < consts.MinSquareSize || squareSize > consts.MaxSquareSize { - return DataAvailabilityHeader{}, fmt.Errorf( + return nil, fmt.Errorf( "invalid square size: min %d max %d provided %d", consts.MinSquareSize, consts.MaxSquareSize, @@ -50,32 +63,14 @@ func NewDataAvailabilityHeader(squareSize uint64, shares [][]byte) (DataAvailabi } // check that valid number of shares have been provided if squareSize*squareSize != uint64(len(shares)) { - return DataAvailabilityHeader{}, fmt.Errorf( + return nil, fmt.Errorf( "must provide valid number of shares for square size: got %d wanted %d", len(shares), squareSize*squareSize, ) } - tree := wrapper.NewErasuredNamespacedMerkleTree(squareSize) - - // TODO(ismail): for better efficiency and a larger number shares - // we should switch to the rsmt2d.LeopardFF16 codec: - extendedDataSquare, err := rsmt2d.ComputeExtendedDataSquare(shares, rsmt2d.NewRSGF8Codec(), tree.Constructor) - if err != nil { - return DataAvailabilityHeader{}, err - } - - // generate the row and col roots using the EDS - dah := DataAvailabilityHeader{ - RowsRoots: extendedDataSquare.RowRoots(), - ColumnRoots: extendedDataSquare.ColRoots(), - } - - // generate the hash of the data using the new roots - dah.Hash() - - return dah, nil + return rsmt2d.ComputeExtendedDataSquare(shares, consts.DefaultCodec(), tree.Constructor) } // String returns hex representation of merkle hash of the DAHeader. @@ -143,16 +138,16 @@ func (dah *DataAvailabilityHeader) ValidateBasic() error { if dah == nil { return errors.New("nil data availability header is not valid") } - if len(dah.ColumnRoots) < minDAHSize || len(dah.RowsRoots) < minDAHSize { + if len(dah.ColumnRoots) < minExtendedSquareWidth || len(dah.RowsRoots) < minExtendedSquareWidth { return fmt.Errorf( "minimum valid DataAvailabilityHeader has at least %d row and column roots", - minDAHSize, + minExtendedSquareWidth, ) } - if len(dah.ColumnRoots) > maxDAHSize || len(dah.RowsRoots) > maxDAHSize { + if len(dah.ColumnRoots) > maxExtendedSquareWidth || len(dah.RowsRoots) > maxExtendedSquareWidth { return fmt.Errorf( "maximum valid DataAvailabilityHeader has at most %d row and column roots", - maxDAHSize, + maxExtendedSquareWidth, ) } if len(dah.ColumnRoots) != len(dah.RowsRoots) { @@ -190,13 +185,11 @@ func MinDataAvailabilityHeader() DataAvailabilityHeader { for i := 0; i < consts.MinSharecount; i++ { shares[i] = tailPaddingShare } - dah, err := NewDataAvailabilityHeader( - consts.MinSquareSize, - shares, - ) + eds, err := ExtendShares(consts.MinSquareSize, shares) if err != nil { panic(err) } + dah := NewDataAvailabilityHeader(eds) return dah } diff --git a/pkg/da/data_availability_header_test.go b/pkg/da/data_availability_header_test.go index 3b16e5ac39..b5540e09d4 100644 --- a/pkg/da/data_availability_header_test.go +++ b/pkg/da/data_availability_header_test.go @@ -37,15 +37,13 @@ func TestNewDataAvailabilityHeader(t *testing.T) { type test struct { name string expectedHash []byte - expectedErr bool squareSize uint64 shares [][]byte } tests := []test{ { - name: "typical", - expectedErr: false, + name: "typical", expectedHash: []byte{ 0xfe, 0x9c, 0x6b, 0xd8, 0xe5, 0x7c, 0xd1, 0x5d, 0x1f, 0xd6, 0x55, 0x7e, 0x87, 0x7d, 0xd9, 0x7d, 0xdb, 0xf2, 0x66, 0xfa, 0x60, 0x24, 0x2d, 0xb3, 0xa0, 0x9c, 0x4f, 0x4e, 0x5b, 0x2a, 0x2c, 0x2a, @@ -54,8 +52,7 @@ func TestNewDataAvailabilityHeader(t *testing.T) { shares: generateShares(4, 1), }, { - name: "max square size", - expectedErr: false, + name: "max square size", expectedHash: []byte{ 0xe2, 0x87, 0x23, 0xd0, 0x2d, 0x54, 0x25, 0x5f, 0x79, 0x43, 0x8e, 0xfb, 0xb7, 0xe8, 0xfa, 0xf5, 0xbf, 0x93, 0x50, 0xb3, 0x64, 0xd0, 0x4f, 0xa7, 0x7b, 0xb1, 0x83, 0x3b, 0x8, 0xba, 0xd3, 0xa4, @@ -63,6 +60,29 @@ func TestNewDataAvailabilityHeader(t *testing.T) { squareSize: consts.MaxSquareSize, shares: generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 99), }, + } + + for _, tt := range tests { + tt := tt + eds, err := ExtendShares(tt.squareSize, tt.shares) + require.NoError(t, err) + resdah := NewDataAvailabilityHeader(eds) + require.Equal(t, tt.squareSize*2, uint64(len(resdah.ColumnRoots)), tt.name) + require.Equal(t, tt.squareSize*2, uint64(len(resdah.RowsRoots)), tt.name) + require.Equal(t, tt.expectedHash, resdah.hash, tt.name) + } +} + +func TestExtendShares(t *testing.T) { + type test struct { + name string + expectedHash []byte + expectedErr bool + squareSize uint64 + shares [][]byte + } + + tests := []test{ { name: "too large square size", expectedErr: true, @@ -79,15 +99,13 @@ func TestNewDataAvailabilityHeader(t *testing.T) { for _, tt := range tests { tt := tt - resdah, err := NewDataAvailabilityHeader(tt.squareSize, tt.shares) + eds, err := ExtendShares(tt.squareSize, tt.shares) if tt.expectedErr { require.NotNil(t, err) continue } require.NoError(t, err) - require.Equal(t, tt.squareSize*2, uint64(len(resdah.ColumnRoots)), tt.name) - require.Equal(t, tt.squareSize*2, uint64(len(resdah.RowsRoots)), tt.name) - require.Equal(t, tt.expectedHash, resdah.hash, tt.name) + require.Equal(t, tt.squareSize*2, eds.Width(), tt.name) } } @@ -98,8 +116,9 @@ func TestDataAvailabilityHeaderProtoConversion(t *testing.T) { } shares := generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 1) - bigdah, err := NewDataAvailabilityHeader(consts.MaxSquareSize, shares) + eds, err := ExtendShares(consts.MaxSquareSize, shares) require.NoError(t, err) + bigdah := NewDataAvailabilityHeader(eds) tests := []test{ { @@ -133,8 +152,10 @@ func Test_DAHValidateBasic(t *testing.T) { } shares := generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 1) - bigdah, err := NewDataAvailabilityHeader(consts.MaxSquareSize, shares) + eds, err := ExtendShares(consts.MaxSquareSize, shares) require.NoError(t, err) + bigdah := NewDataAvailabilityHeader(eds) + // make a mutant dah that has too many roots var tooBigDah DataAvailabilityHeader tooBigDah.ColumnRoots = make([][]byte, consts.MaxSquareSize*consts.MaxSquareSize) diff --git a/pkg/wrapper/nmt_wrapper_test.go b/pkg/wrapper/nmt_wrapper_test.go index 8bd4e83eb8..a1cd7580b1 100644 --- a/pkg/wrapper/nmt_wrapper_test.go +++ b/pkg/wrapper/nmt_wrapper_test.go @@ -27,7 +27,7 @@ func TestPushErasuredNamespacedMerkleTree(t *testing.T) { tree := n.Constructor() // push test data to the tree - for i, d := range generateErasuredData(t, tc.squareSize, rsmt2d.NewRSGF8Codec()) { + for i, d := range generateErasuredData(t, tc.squareSize, consts.DefaultCodec()) { // push will panic if there's an error tree.Push(d, rsmt2d.SquareIndex{Axis: uint(0), Cell: uint(i)}) } @@ -64,7 +64,7 @@ func TestErasureNamespacedMerkleTreePanics(t *testing.T) { "push over square size", assert.PanicTestFunc( func() { - data := generateErasuredData(t, 16, rsmt2d.NewRSGF8Codec()) + data := generateErasuredData(t, 16, consts.DefaultCodec()) n := NewErasuredNamespacedMerkleTree(uint64(15)) tree := n.Constructor() for i, d := range data { @@ -76,7 +76,7 @@ func TestErasureNamespacedMerkleTreePanics(t *testing.T) { "push in incorrect lexigraphic order", assert.PanicTestFunc( func() { - data := generateErasuredData(t, 16, rsmt2d.NewRSGF8Codec()) + data := generateErasuredData(t, 16, consts.DefaultCodec()) n := NewErasuredNamespacedMerkleTree(uint64(16)) tree := n.Constructor() for i := len(data) - 1; i > 0; i-- { @@ -104,7 +104,7 @@ func TestExtendedDataSquare(t *testing.T) { tree := NewErasuredNamespacedMerkleTree(uint64(squareSize)) - _, err := rsmt2d.ComputeExtendedDataSquare(raw, rsmt2d.NewRSGF8Codec(), tree.Constructor) + _, err := rsmt2d.ComputeExtendedDataSquare(raw, consts.DefaultCodec(), tree.Constructor) assert.NoError(t, err) } diff --git a/types/shares_test.go b/types/shares_test.go index ddf7c29b07..e5cd4abe3e 100644 --- a/types/shares_test.go +++ b/types/shares_test.go @@ -252,7 +252,7 @@ func TestDataFromSquare(t *testing.T) { shares, _ := data.ComputeShares() rawShares := shares.RawShares() - eds, err := rsmt2d.ComputeExtendedDataSquare(rawShares, rsmt2d.NewRSGF8Codec(), rsmt2d.NewDefaultTree) + eds, err := rsmt2d.ComputeExtendedDataSquare(rawShares, consts.DefaultCodec(), rsmt2d.NewDefaultTree) if err != nil { t.Error(err) }