diff --git a/types/shares.go b/types/shares.go index 55238a1db0..a96ff36969 100644 --- a/types/shares.go +++ b/types/shares.go @@ -62,12 +62,12 @@ func (m Message) MarshalDelimited() ([]byte, error) { // appendToShares appends raw data as shares. // Used for messages. func appendToShares(shares []NamespacedShare, nid namespace.ID, rawData []byte) []NamespacedShare { - if len(rawData) < consts.MsgShareSize { + if len(rawData) <= consts.MsgShareSize { rawShare := []byte(append(nid, rawData...)) paddedShare := zeroPadIfNecessary(rawShare, consts.ShareSize) share := NamespacedShare{paddedShare, nid} shares = append(shares, share) - } else { // len(rawData) >= MsgShareSize + } else { // len(rawData) > MsgShareSize shares = append(shares, split(rawData, nid)...) } return shares @@ -102,7 +102,9 @@ func split(rawData []byte, nid namespace.ID) []NamespacedShare { rawData = rawData[consts.MsgShareSize:] for len(rawData) > 0 { shareSizeOrLen := min(consts.MsgShareSize, len(rawData)) - rawShare := []byte(append(nid, rawData[:shareSizeOrLen]...)) + rawShare := make([]byte, consts.NamespaceSize) + copy(rawShare, nid) + rawShare = append(rawShare, rawData[:shareSizeOrLen]...) paddedShare := zeroPadIfNecessary(rawShare, consts.ShareSize) share := NamespacedShare{paddedShare, nid} shares = append(shares, share) diff --git a/types/shares_test.go b/types/shares_test.go index ee18db6863..5fed4814ee 100644 --- a/types/shares_test.go +++ b/types/shares_test.go @@ -2,10 +2,13 @@ package types import ( "bytes" + "crypto/rand" "reflect" + "sort" "testing" "github.com/celestiaorg/nmt/namespace" + "github.com/stretchr/testify/assert" "github.com/tendermint/tendermint/internal/libs/protoio" "github.com/tendermint/tendermint/pkg/consts" ) @@ -176,3 +179,53 @@ func Test_zeroPadIfNecessary(t *testing.T) { }) } } + +func Test_appendToSharesOverwrite(t *testing.T) { + var shares NamespacedShares + + // generate some arbitrary namespaced shares first share that must be split + newShare := generateRandomNamespacedShares(1, consts.MsgShareSize+1)[0] + + // make a copy of the portion of the share to check if it's overwritten later + extraCopy := make([]byte, consts.MsgShareSize) + copy(extraCopy, newShare.Share[:consts.MsgShareSize]) + + // use appendToShares to add our new share + appendToShares(shares, newShare.ID, newShare.Share) + + // check if the original share data has been overwritten. + assert.Equal(t, extraCopy, []byte(newShare.Share[:consts.MsgShareSize])) +} + +func generateRandomNamespacedShares(count, leafSize int) []NamespacedShare { + shares := generateRandNamespacedRawData(count, consts.NamespaceSize, leafSize) + nsShares := make(NamespacedShares, count) + for i, s := range shares { + nsShares[i] = NamespacedShare{ + Share: s[consts.NamespaceSize:], + ID: s[:consts.NamespaceSize], + } + } + return nsShares +} + +func generateRandNamespacedRawData(total, nidSize, leafSize int) [][]byte { + data := make([][]byte, total) + for i := 0; i < total; i++ { + nid := make([]byte, nidSize) + rand.Read(nid) + data[i] = nid + } + sortByteArrays(data) + for i := 0; i < total; i++ { + d := make([]byte, leafSize) + rand.Read(d) + data[i] = append(data[i], d...) + } + + return data +} + +func sortByteArrays(src [][]byte) { + sort.Slice(src, func(i, j int) bool { return bytes.Compare(src[i], src[j]) < 0 }) +}