Skip to content

Commit

Permalink
refactor to better accomodate real world use cases (celestia node)
Browse files Browse the repository at this point in the history
Co-authored-by: rene <[email protected]>
  • Loading branch information
evan-forbes and renaynay committed Sep 21, 2021
1 parent 4beddda commit 5ee5b96
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 48 deletions.
3 changes: 3 additions & 0 deletions pkg/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,8 @@ var (
// NewBaseHashFunc change accordingly if another hash.Hash should be used as a base hasher in the NMT:
NewBaseHashFunc = sha256.New

// DefaultCodec is the defualt codec creator used for data erasure
// TODO(ismail): for better efficiency and a larger number shares
// we should switch to the rsmt2d.LeopardFF16 codec:
DefaultCodec = rsmt2d.NewRSGF8Codec
)
57 changes: 25 additions & 32 deletions pkg/da/data_availability_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

const (
maxDAHSize = consts.MaxSquareSize * 2
minDAHSize = consts.MinSquareSize * 2
maxExtendedSquareWidth = consts.MaxSquareSize * 2
minExtendedSquareWidth = consts.MinSquareSize * 2
)

// DataAvailabilityHeader (DAHeader) contains the row and column roots of the erasure
Expand All @@ -38,10 +38,23 @@ type DataAvailabilityHeader struct {
}

// NewDataAvailabilityHeader generates a DataAvailability header using the provided square size and shares
func NewDataAvailabilityHeader(squareSize uint64, shares [][]byte) (DataAvailabilityHeader, error) {
func NewDataAvailabilityHeader(eds *rsmt2d.ExtendedDataSquare) DataAvailabilityHeader {
// generate the row and col roots using the EDS
dah := DataAvailabilityHeader{
RowsRoots: eds.RowRoots(),
ColumnRoots: eds.ColRoots(),
}

// generate the hash of the data using the new roots
dah.Hash()

return dah
}

func ExtendShares(squareSize uint64, shares [][]byte) (*rsmt2d.ExtendedDataSquare, error) {
// Check that square size is with range
if squareSize < consts.MinSquareSize || squareSize > consts.MaxSquareSize {
return DataAvailabilityHeader{}, fmt.Errorf(
return nil, fmt.Errorf(
"invalid square size: min %d max %d provided %d",
consts.MinSquareSize,
consts.MaxSquareSize,
Expand All @@ -50,32 +63,14 @@ func NewDataAvailabilityHeader(squareSize uint64, shares [][]byte) (DataAvailabi
}
// check that valid number of shares have been provided
if squareSize*squareSize != uint64(len(shares)) {
return DataAvailabilityHeader{}, fmt.Errorf(
return nil, fmt.Errorf(
"must provide valid number of shares for square size: got %d wanted %d",
len(shares),
squareSize*squareSize,
)
}

tree := wrapper.NewErasuredNamespacedMerkleTree(squareSize)

// TODO(ismail): for better efficiency and a larger number shares
// we should switch to the rsmt2d.LeopardFF16 codec:
extendedDataSquare, err := rsmt2d.ComputeExtendedDataSquare(shares, rsmt2d.NewRSGF8Codec(), tree.Constructor)
if err != nil {
return DataAvailabilityHeader{}, err
}

// generate the row and col roots using the EDS
dah := DataAvailabilityHeader{
RowsRoots: extendedDataSquare.RowRoots(),
ColumnRoots: extendedDataSquare.ColRoots(),
}

// generate the hash of the data using the new roots
dah.Hash()

return dah, nil
return rsmt2d.ComputeExtendedDataSquare(shares, consts.DefaultCodec(), tree.Constructor)
}

// String returns hex representation of merkle hash of the DAHeader.
Expand Down Expand Up @@ -143,16 +138,16 @@ func (dah *DataAvailabilityHeader) ValidateBasic() error {
if dah == nil {
return errors.New("nil data availability header is not valid")
}
if len(dah.ColumnRoots) < minDAHSize || len(dah.RowsRoots) < minDAHSize {
if len(dah.ColumnRoots) < minExtendedSquareWidth || len(dah.RowsRoots) < minExtendedSquareWidth {
return fmt.Errorf(
"minimum valid DataAvailabilityHeader has at least %d row and column roots",
minDAHSize,
minExtendedSquareWidth,
)
}
if len(dah.ColumnRoots) > maxDAHSize || len(dah.RowsRoots) > maxDAHSize {
if len(dah.ColumnRoots) > maxExtendedSquareWidth || len(dah.RowsRoots) > maxExtendedSquareWidth {
return fmt.Errorf(
"maximum valid DataAvailabilityHeader has at most %d row and column roots",
maxDAHSize,
maxExtendedSquareWidth,
)
}
if len(dah.ColumnRoots) != len(dah.RowsRoots) {
Expand Down Expand Up @@ -190,13 +185,11 @@ func MinDataAvailabilityHeader() DataAvailabilityHeader {
for i := 0; i < consts.MinSharecount; i++ {
shares[i] = tailPaddingShare
}
dah, err := NewDataAvailabilityHeader(
consts.MinSquareSize,
shares,
)
eds, err := ExtendShares(consts.MinSquareSize, shares)
if err != nil {
panic(err)
}
dah := NewDataAvailabilityHeader(eds)
return dah
}

Expand Down
43 changes: 32 additions & 11 deletions pkg/da/data_availability_header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ func TestNewDataAvailabilityHeader(t *testing.T) {
type test struct {
name string
expectedHash []byte
expectedErr bool
squareSize uint64
shares [][]byte
}

tests := []test{
{
name: "typical",
expectedErr: false,
name: "typical",
expectedHash: []byte{
0xfe, 0x9c, 0x6b, 0xd8, 0xe5, 0x7c, 0xd1, 0x5d, 0x1f, 0xd6, 0x55, 0x7e, 0x87, 0x7d, 0xd9, 0x7d,
0xdb, 0xf2, 0x66, 0xfa, 0x60, 0x24, 0x2d, 0xb3, 0xa0, 0x9c, 0x4f, 0x4e, 0x5b, 0x2a, 0x2c, 0x2a,
Expand All @@ -54,15 +52,37 @@ func TestNewDataAvailabilityHeader(t *testing.T) {
shares: generateShares(4, 1),
},
{
name: "max square size",
expectedErr: false,
name: "max square size",
expectedHash: []byte{
0xe2, 0x87, 0x23, 0xd0, 0x2d, 0x54, 0x25, 0x5f, 0x79, 0x43, 0x8e, 0xfb, 0xb7, 0xe8, 0xfa, 0xf5,
0xbf, 0x93, 0x50, 0xb3, 0x64, 0xd0, 0x4f, 0xa7, 0x7b, 0xb1, 0x83, 0x3b, 0x8, 0xba, 0xd3, 0xa4,
},
squareSize: consts.MaxSquareSize,
shares: generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 99),
},
}

for _, tt := range tests {
tt := tt
eds, err := ExtendShares(tt.squareSize, tt.shares)
require.NoError(t, err)
resdah := NewDataAvailabilityHeader(eds)
require.Equal(t, tt.squareSize*2, uint64(len(resdah.ColumnRoots)), tt.name)
require.Equal(t, tt.squareSize*2, uint64(len(resdah.RowsRoots)), tt.name)
require.Equal(t, tt.expectedHash, resdah.hash, tt.name)
}
}

func TestExtendShares(t *testing.T) {
type test struct {
name string
expectedHash []byte
expectedErr bool
squareSize uint64
shares [][]byte
}

tests := []test{
{
name: "too large square size",
expectedErr: true,
Expand All @@ -79,15 +99,13 @@ func TestNewDataAvailabilityHeader(t *testing.T) {

for _, tt := range tests {
tt := tt
resdah, err := NewDataAvailabilityHeader(tt.squareSize, tt.shares)
eds, err := ExtendShares(tt.squareSize, tt.shares)
if tt.expectedErr {
require.NotNil(t, err)
continue
}
require.NoError(t, err)
require.Equal(t, tt.squareSize*2, uint64(len(resdah.ColumnRoots)), tt.name)
require.Equal(t, tt.squareSize*2, uint64(len(resdah.RowsRoots)), tt.name)
require.Equal(t, tt.expectedHash, resdah.hash, tt.name)
require.Equal(t, tt.squareSize*2, eds.Width(), tt.name)
}
}

Expand All @@ -98,8 +116,9 @@ func TestDataAvailabilityHeaderProtoConversion(t *testing.T) {
}

shares := generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 1)
bigdah, err := NewDataAvailabilityHeader(consts.MaxSquareSize, shares)
eds, err := ExtendShares(consts.MaxSquareSize, shares)
require.NoError(t, err)
bigdah := NewDataAvailabilityHeader(eds)

tests := []test{
{
Expand Down Expand Up @@ -133,8 +152,10 @@ func Test_DAHValidateBasic(t *testing.T) {
}

shares := generateShares(consts.MaxSquareSize*consts.MaxSquareSize, 1)
bigdah, err := NewDataAvailabilityHeader(consts.MaxSquareSize, shares)
eds, err := ExtendShares(consts.MaxSquareSize, shares)
require.NoError(t, err)
bigdah := NewDataAvailabilityHeader(eds)

// make a mutant dah that has too many roots
var tooBigDah DataAvailabilityHeader
tooBigDah.ColumnRoots = make([][]byte, consts.MaxSquareSize*consts.MaxSquareSize)
Expand Down
8 changes: 4 additions & 4 deletions pkg/wrapper/nmt_wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestPushErasuredNamespacedMerkleTree(t *testing.T) {
tree := n.Constructor()

// push test data to the tree
for i, d := range generateErasuredData(t, tc.squareSize, rsmt2d.NewRSGF8Codec()) {
for i, d := range generateErasuredData(t, tc.squareSize, consts.DefaultCodec()) {
// push will panic if there's an error
tree.Push(d, rsmt2d.SquareIndex{Axis: uint(0), Cell: uint(i)})
}
Expand Down Expand Up @@ -64,7 +64,7 @@ func TestErasureNamespacedMerkleTreePanics(t *testing.T) {
"push over square size",
assert.PanicTestFunc(
func() {
data := generateErasuredData(t, 16, rsmt2d.NewRSGF8Codec())
data := generateErasuredData(t, 16, consts.DefaultCodec())
n := NewErasuredNamespacedMerkleTree(uint64(15))
tree := n.Constructor()
for i, d := range data {
Expand All @@ -76,7 +76,7 @@ func TestErasureNamespacedMerkleTreePanics(t *testing.T) {
"push in incorrect lexigraphic order",
assert.PanicTestFunc(
func() {
data := generateErasuredData(t, 16, rsmt2d.NewRSGF8Codec())
data := generateErasuredData(t, 16, consts.DefaultCodec())
n := NewErasuredNamespacedMerkleTree(uint64(16))
tree := n.Constructor()
for i := len(data) - 1; i > 0; i-- {
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestExtendedDataSquare(t *testing.T) {

tree := NewErasuredNamespacedMerkleTree(uint64(squareSize))

_, err := rsmt2d.ComputeExtendedDataSquare(raw, rsmt2d.NewRSGF8Codec(), tree.Constructor)
_, err := rsmt2d.ComputeExtendedDataSquare(raw, consts.DefaultCodec(), tree.Constructor)
assert.NoError(t, err)
}

Expand Down
2 changes: 1 addition & 1 deletion types/shares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func TestDataFromSquare(t *testing.T) {
shares, _ := data.ComputeShares()
rawShares := shares.RawShares()

eds, err := rsmt2d.ComputeExtendedDataSquare(rawShares, rsmt2d.NewRSGF8Codec(), rsmt2d.NewDefaultTree)
eds, err := rsmt2d.ComputeExtendedDataSquare(rawShares, consts.DefaultCodec(), rsmt2d.NewDefaultTree)
if err != nil {
t.Error(err)
}
Expand Down

0 comments on commit 5ee5b96

Please sign in to comment.