From d8a92727ade4a7b29ca65b20756d81a1561307ae Mon Sep 17 00:00:00 2001 From: Rootul P Date: Wed, 5 Oct 2022 13:25:57 -0600 Subject: [PATCH] feat: throw error when parsing unsupported share version (#837) Closes https://github.com/celestiaorg/celestia-app/issues/830 specifically: > Alternatively, if we don't want this to be a ProcessProposal validity rule, we can enforce this type of check when we parse the data square back into block data (e.g. [here](https://github.com/celestiaorg/celestia-app/blob/6b86e91eea1063c27c1ae461eddbe505f778df08/pkg/shares/parse_compact_shares.go#L14)) Co-authored-by: CHAMI Rachid --- pkg/appconsts/appconsts.go | 3 +++ pkg/shares/compact_shares_test.go | 37 +++++++++++++++++++++++++----- pkg/shares/parse_compact_shares.go | 28 ++++++++++++++++------ pkg/shares/parse_sparse_shares.go | 33 ++++++++++++++++++-------- pkg/shares/share_merging.go | 6 ++--- pkg/shares/sparse_shares_test.go | 37 +++++++++++++++++++++++++++--- 6 files changed, 115 insertions(+), 29 deletions(-) diff --git a/pkg/appconsts/appconsts.go b/pkg/appconsts/appconsts.go index 82c9802e0a..386be34409 100644 --- a/pkg/appconsts/appconsts.go +++ b/pkg/appconsts/appconsts.go @@ -135,6 +135,9 @@ var ( // contains less space for data than a ContinuationCompactShare because the // first compact share includes a total data length varint. FirstCompactShareContentSize = ContinuationCompactShareContentSize - FirstCompactShareDataLengthBytes + + // SupportedShareVersions is a list of supported share versions. + SupportedShareVersions = []uint8{ShareVersion} ) // numberOfBytesVarint calculates the number of bytes needed to write a varint of n diff --git a/pkg/shares/compact_shares_test.go b/pkg/shares/compact_shares_test.go index 2b0664864d..08adca786b 100644 --- a/pkg/shares/compact_shares_test.go +++ b/pkg/shares/compact_shares_test.go @@ -23,7 +23,7 @@ func TestCompactShareWriter(t *testing.T) { } shares := w.Export() rawShares := ToBytes(shares) - rawResTxs, err := parseCompactShares(rawShares) + rawResTxs, err := parseCompactShares(rawShares, appconsts.SupportedShareVersions) resTxs := coretypes.ToTxs(rawResTxs) require.NoError(t, err) @@ -95,7 +95,7 @@ func Test_processCompactShares(t *testing.T) { shares := SplitTxs(txs) rawShares := ToBytes(shares) - parsedTxs, err := parseCompactShares(rawShares) + parsedTxs, err := parseCompactShares(rawShares, appconsts.SupportedShareVersions) if err != nil { t.Error(err) } @@ -113,7 +113,7 @@ func Test_processCompactShares(t *testing.T) { shares := SplitTxs(txs) rawShares := ToBytes(shares) - parsedTxs, err := parseCompactShares(rawShares) + parsedTxs, err := parseCompactShares(rawShares, appconsts.SupportedShareVersions) if err != nil { t.Error(err) } @@ -166,11 +166,36 @@ func TestContiguousCompactShareContainsInfoByte(t *testing.T) { assert.Equal(t, byte(want), infoByte) } -func Test_parseCompactSharesReturnsErrForShareWithStartIndicatorFalse(t *testing.T) { +func Test_parseCompactSharesErrors(t *testing.T) { + type testCase struct { + name string + rawShares [][]byte + } + txs := generateRandomTransaction(2, appconsts.ContinuationCompactShareContentSize*4) shares := SplitTxs(txs) rawShares := ToBytes(shares) - _, err := parseCompactShares(rawShares[1:]) // the second share has the message start indicator set to false - assert.Error(t, err) + unsupportedShareVersion := 5 + infoByte, _ := NewInfoByte(uint8(unsupportedShareVersion), true) + shareWithUnsupportedShareVersion := rawShares[0] + shareWithUnsupportedShareVersion[appconsts.NamespaceSize] = byte(infoByte) + + testCases := []testCase{ + { + "share with start indicator false", + rawShares[1:], // set the first share to the second share which has the start indicator set to false + }, + { + "share with unsupported share version", + [][]byte{shareWithUnsupportedShareVersion}, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + _, err := parseCompactShares(tt.rawShares, appconsts.SupportedShareVersions) + assert.Error(t, err) + }) + } } diff --git a/pkg/shares/parse_compact_shares.go b/pkg/shares/parse_compact_shares.go index 5e0974dc3b..b988a63300 100644 --- a/pkg/shares/parse_compact_shares.go +++ b/pkg/shares/parse_compact_shares.go @@ -1,22 +1,36 @@ package shares import ( + "bytes" "encoding/binary" "errors" + "fmt" "github.com/celestiaorg/celestia-app/pkg/appconsts" ) -// parseCompactShares takes raw shares and extracts out transactions, -// intermediate state roots, or evidence. The returned [][]byte do not have -// namespaces, info bytes, data length delimiter, or unit length -// delimiters and are ready to be unmarshalled -func parseCompactShares(shares [][]byte) (data [][]byte, err error) { - if len(shares) == 0 { +// parseCompactShares returns data (transactions, intermediate state roots, or +// evidence) based on the contents of rawShares and supportedShareVersions. If +// rawShares contains a share with a version that isn't present in +// supportedShareVersions, an error is returned. The returned data [][]byte does +// not have namespaces, info bytes, data length delimiter, or unit length +// delimiters and are ready to be unmarshalled. +func parseCompactShares(rawShares [][]byte, supportedShareVersions []uint8) (data [][]byte, err error) { + if len(rawShares) == 0 { return nil, nil } + shares := FromBytes(rawShares) + for _, share := range shares { + infoByte, err := share.InfoByte() + if err != nil { + return nil, err + } + if !bytes.Contains(supportedShareVersions, []byte{infoByte.Version()}) { + return nil, fmt.Errorf("unsupported share version %v is not present in the list of supported share versions %v", infoByte.Version(), supportedShareVersions) + } + } - ss := newShareStack(shares) + ss := newShareStack(rawShares) return ss.resolve() } diff --git a/pkg/shares/parse_sparse_shares.go b/pkg/shares/parse_sparse_shares.go index c048d707cf..4f6b9f1369 100644 --- a/pkg/shares/parse_sparse_shares.go +++ b/pkg/shares/parse_sparse_shares.go @@ -10,11 +10,24 @@ import ( "github.com/celestiaorg/celestia-app/pkg/appconsts" ) -// parseSparseShares iterates through raw shares and parses out individual messages. -func parseSparseShares(shares [][]byte) ([]coretypes.Message, error) { - if len(shares) == 0 { +// parseSparseShares iterates through rawShares and parses out individual +// messages. It returns an error if a rawShare contains a share version that +// isn't present in supportedShareVersions. +func parseSparseShares(rawShares [][]byte, supportedShareVersions []uint8) ([]coretypes.Message, error) { + if len(rawShares) == 0 { return nil, nil } + shares := FromBytes(rawShares) + for _, share := range shares { + infoByte, err := share.InfoByte() + if err != nil { + return nil, err + } + if !bytes.Contains(supportedShareVersions, []byte{infoByte.Version()}) { + return nil, fmt.Errorf("unsupported share version %v is not present in the list of supported share versions %v", infoByte.Version(), supportedShareVersions) + } + } + // msgs returned msgs := []coretypes.Message{} currentMsgLen := 0 @@ -31,21 +44,21 @@ func parseSparseShares(shares [][]byte) ([]coretypes.Message, error) { isNewMessage = true } // iterate through all the shares and parse out each msg - for i := 0; i < len(shares); i++ { + for i := 0; i < len(rawShares); i++ { dataLen = len(currentMsg.Data) + appconsts.SparseShareContentSize switch { case isNewMessage: - nextMsgChunk, nextMsgLen, err := ParseDelimiter(shares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:]) + nextMsgChunk, nextMsgLen, err := ParseDelimiter(rawShares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:]) if err != nil { return nil, err } // the current share is namespaced padding so we ignore it - if bytes.Equal(shares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:], appconsts.NameSpacedPaddedShareBytes) { + if bytes.Equal(rawShares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:], appconsts.NameSpacedPaddedShareBytes) { continue } currentMsgLen = int(nextMsgLen) - nid := shares[i][:appconsts.NamespaceSize] - infoByte, err := ParseInfoByte(shares[i][appconsts.NamespaceSize : appconsts.NamespaceSize+appconsts.ShareInfoBytes][0]) + nid := rawShares[i][:appconsts.NamespaceSize] + infoByte, err := ParseInfoByte(rawShares[i][appconsts.NamespaceSize : appconsts.NamespaceSize+appconsts.ShareInfoBytes][0]) if err != nil { panic(err) } @@ -66,12 +79,12 @@ func parseSparseShares(shares [][]byte) ([]coretypes.Message, error) { isNewMessage = false // this entire share contains a chunk of message that we need to save case currentMsgLen > dataLen: - currentMsg.Data = append(currentMsg.Data, shares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:]...) + currentMsg.Data = append(currentMsg.Data, rawShares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:]...) // this share contains the last chunk of data needed to complete the // message case currentMsgLen <= dataLen: remaining := currentMsgLen - len(currentMsg.Data) + appconsts.NamespaceSize + appconsts.ShareInfoBytes - currentMsg.Data = append(currentMsg.Data, shares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:remaining]...) + currentMsg.Data = append(currentMsg.Data, rawShares[i][appconsts.NamespaceSize+appconsts.ShareInfoBytes:remaining]...) saveMessage() } } diff --git a/pkg/shares/share_merging.go b/pkg/shares/share_merging.go index d44280cd0e..6602eedb58 100644 --- a/pkg/shares/share_merging.go +++ b/pkg/shares/share_merging.go @@ -79,7 +79,7 @@ func Merge(eds *rsmt2d.ExtendedDataSquare) (coretypes.Data, error) { // ParseTxs collects all of the transactions from the shares provided func ParseTxs(shares [][]byte) (coretypes.Txs, error) { // parse the sharse - rawTxs, err := parseCompactShares(shares) + rawTxs, err := parseCompactShares(shares, appconsts.SupportedShareVersions) if err != nil { return nil, err } @@ -97,7 +97,7 @@ func ParseTxs(shares [][]byte) (coretypes.Txs, error) { func ParseEvd(shares [][]byte) (coretypes.EvidenceData, error) { // the raw data returned does not have length delimiters or namespaces and // is ready to be unmarshaled - rawEvd, err := parseCompactShares(shares) + rawEvd, err := parseCompactShares(shares, appconsts.SupportedShareVersions) if err != nil { return coretypes.EvidenceData{}, err } @@ -125,7 +125,7 @@ func ParseEvd(shares [][]byte) (coretypes.EvidenceData, error) { // ParseMsgs collects all messages from the shares provided func ParseMsgs(shares [][]byte) (coretypes.Messages, error) { - msgList, err := parseSparseShares(shares) + msgList, err := parseSparseShares(shares, appconsts.SupportedShareVersions) if err != nil { return coretypes.Messages{}, err } diff --git a/pkg/shares/sparse_shares_test.go b/pkg/shares/sparse_shares_test.go index f3cc2659aa..ffa249f10d 100644 --- a/pkg/shares/sparse_shares_test.go +++ b/pkg/shares/sparse_shares_test.go @@ -1,10 +1,12 @@ package shares import ( + "bytes" "fmt" "testing" "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/nmt/namespace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" coretypes "github.com/tendermint/tendermint/types" @@ -49,7 +51,7 @@ func Test_parseSparseShares(t *testing.T) { shares, _ := SplitMessages(0, nil, msgs.MessagesList, false) rawShares := ToBytes(shares) - parsedMsgs, err := parseSparseShares(rawShares) + parsedMsgs, err := parseSparseShares(rawShares, appconsts.SupportedShareVersions) if err != nil { t.Error(err) } @@ -70,7 +72,7 @@ func Test_parseSparseShares(t *testing.T) { rawShares[i] = []byte(share) } - parsedMsgs, err := parseSparseShares(rawShares) + parsedMsgs, err := parseSparseShares(rawShares, appconsts.SupportedShareVersions) if err != nil { t.Error(err) } @@ -84,6 +86,35 @@ func Test_parseSparseShares(t *testing.T) { } } +func Test_parseSparseSharesErrors(t *testing.T) { + type testCase struct { + name string + rawShares [][]byte + } + + unsupportedShareVersion := 5 + infoByte, _ := NewInfoByte(uint8(unsupportedShareVersion), true) + + rawShare := []byte{} + rawShare = append(rawShare, namespace.ID{1, 1, 1, 1, 1, 1, 1, 1}...) + rawShare = append(rawShare, byte(infoByte)) + rawShare = append(rawShare, bytes.Repeat([]byte{0}, appconsts.ShareSize-len(rawShare))...) + + tests := []testCase{ + { + "share with unsupported share version", + [][]byte{rawShare}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(*testing.T) { + _, err := parseSparseShares(tt.rawShares, appconsts.SupportedShareVersions) + assert.Error(t, err) + }) + } +} + func TestParsePaddedMsg(t *testing.T) { msgWr := NewSparseShareSplitter() randomSmallMsg := generateRandomMessage(appconsts.SparseShareContentSize / 2) @@ -101,7 +132,7 @@ func TestParsePaddedMsg(t *testing.T) { msgWr.WriteNamespacedPaddedShares(10) shares := msgWr.Export() rawShares := ToBytes(shares) - pmsgs, err := parseSparseShares(rawShares) + pmsgs, err := parseSparseShares(rawShares, appconsts.SupportedShareVersions) require.NoError(t, err) require.Equal(t, msgs.MessagesList, pmsgs) }